diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ + diff --git a/CHANGELOG.md b/CHANGELOG.md index 49d2aae..e6a7cd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,113 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.1.10] - 2025-07-20 + +### Added +- Ray distributed processing support for parallel symbol table generation (addresses [#16](https://github.com/codellm-devkit/codeanalyzer-python/issues/16)) +- `--ray/--no-ray` CLI flag to enable/disable Ray-based distributed analysis +- `--skip-tests/--include-tests` CLI flag to control whether test files are analyzed (improves analysis performance) +- `--file-name` CLI flag for single file analysis (addresses part of [#16](https://github.com/codellm-devkit/codeanalyzer-python/issues/16)) +- Incremental caching system with SHA256-based file change detection + - Automatic caching of analysis results to `analysis_cache.json` + - File-level caching with content hash validation to avoid re-analyzing unchanged files + - Significant performance improvements for subsequent analysis runs + - Cache reuse statistics logging +- Custom exception classes for better error handling in symbol table building: + - `SymbolTableBuilderException` (base exception) + - `SymbolTableBuilderFileNotFoundError` (file not found errors) + - `SymbolTableBuilderParsingError` (parsing errors) + - `SymbolTableBuilderRayError` (Ray processing errors) +- Enhanced PyModule schema with metadata fields for caching: + - `last_modified` timestamp tracking + - `content_hash` for precise change detection +- Progress bar support for both serial and parallel processing modes +- Enhanced test fixtures including xarray project for comprehensive testing +- Comprehensive `__init__.py` exports for syntactic analysis module +- Smart dependency installation with conditional logic: + - Only installs requirements files when they exist (requirements.txt, requirements-dev.txt, dev-requirements.txt, test-requirements.txt) + - Only performs editable installation when package definition files are present (pyproject.toml, setup.py, setup.cfg) + - Improved virtual environment setup with better dependency detection and installation logic + +### Changed +- **BREAKING CHANGE**: Updated Python version requirement from `>=3.10` to `>=3.9` for broader compatibility (closes [#17](https://github.com/codellm-devkit/codeanalyzer-python/issues/17)) +- **BREAKING CHANGE**: Updated dependency versions with more conservative constraints for better stability: + - `pydantic` downgraded from `>=2.11.7` to `>=1.8.0,<2.0.0` for stability + - `pandas` constrained to `>=1.3.0,<2.0.0` + - `numpy` constrained to `>=1.21.0,<1.24.0` + - `rich` constrained to `>=12.6.0,<14.0.0` + - `typer` constrained to `>=0.9.0,<1.0.0` + - Other dependencies updated with conservative version ranges for better compatibility +- Major Architecture Enhancement: Complete rewrite of analysis caching system + - `analyze()` method now implements intelligent caching with PyApplication serialization + - Symbol table building redesigned to support incremental updates and cache reuse + - File change detection using SHA256 content hashing for maximum accuracy +- Enhanced `Codeanalyzer` constructor signature to accept `file_name` parameter for single file analysis +- Refactored symbol table building from monolithic `build()` method to cache-aware file-level processing +- Enhanced `Codeanalyzer` constructor signature to accept `skip_tests` and `using_ray` parameters +- Improved error handling with proper context managers in core analyzer +- Updated CLI to use Pydantic v1 compatible JSON serialization methods +- Reorganized syntactic analysis module structure with proper exception handling and exports +- Enhanced virtual environment detection with better fallback mechanisms +- Symbol table builder now sets metadata fields (`last_modified`, `content_hash`) for all PyModule objects + +### Fixed +- Fixed critical symbol table bug for nested functions (closes [#15](https://github.com/codellm-devkit/codeanalyzer-python/issues/15)) + - Corrected `_callables()` method recursion logic to properly capture both outer and inner functions + - Previously, only inner/nested functions were being captured in the symbol table + - Now correctly processes module-level functions, class methods, and all nested function definitions +- Fixed nested method/function signature generation in symbol table builder + - Corrected `_callables()` method to properly build fully qualified signatures for nested structures + - Fixed issue where nested functions and methods were getting incorrect signatures (e.g., `main.__init__` instead of `main.outer_function.NestedClass.__init__`) + - Added `prefix` parameter to `_callables()` and `_add_class()` methods to maintain proper nesting context + - Signatures now correctly reflect the full nested hierarchy (e.g., `main.outer_function.NestedClass.nested_class_method.method_nested_function`) + - Updated class method processing to pass class signature as prefix to nested callable processing + - Improved path relativization to project directory for cleaner signature generation +- Fixed Pydantic v2 compatibility issues by reverting to v1 API (`json()` instead of `model_dump_json()`) +- Fixed missing import statements and type annotations throughout the codebase +- Fixed symbol table builder to support individual file processing for distributed execution +- Improved error handling in virtual environment detection and Python interpreter resolution +- Fixed schema type annotations to use proper string keys for better serialization +- Enhanced import ordering and removed unnecessary blank lines in CLI module +- Improved virtual environment setup reliability: + - Fixed unnecessary pip installs by adding conditional logic to only install when dependencies are available + - Only attempts to install requirements files if they actually exist in the project + - Only performs editable installation when package definition files are present + - Prevents errors and warnings from attempting to install non-existent dependencies + +### Technical Details +- Added Ray as a core dependency for distributed computing capabilities (addresses [#16](https://github.com/codellm-devkit/codeanalyzer-python/issues/16)) +- Implemented `@ray.remote` decorator for parallel file processing +- Comprehensive caching system implementation: + - `_load_pyapplication_from_cache()` and `_save_analysis_cache()` methods for PyApplication serialization + - `_file_unchanged()` method with SHA256 content hash validation + - Cache-aware symbol table building with selective file processing + - Automatic cache statistics and performance reporting +- Enhanced progress tracking for both serial and parallel execution modes with Rich progress bars +- Updated schema to use `Dict[str, PyModule]` instead of `dict[Path, PyModule]` for better serialization +- Extended PyModule schema with optional `last_modified` and `content_hash` fields for caching metadata +- Added comprehensive exception hierarchy for better error classification and handling +- Refactored symbol table building into modular, file-level processing suitable for distribution +- Enhanced Python interpreter detection with support for multiple version managers (pyenv, conda, asdf) +- Added `hashlib` integration for file content hashing throughout the codebase +- Enhanced virtual environment setup logic: + - Modified `_add_class()` method to accept `prefix` parameter and pass class signature to method processing + - Updated `_callables()` method signature to include `prefix` parameter for nested context tracking + - Enhanced signature building logic to use prefix when available, falling back to Jedi resolution for top-level definitions + - Fixed recursive calls to pass current signature as prefix for proper nesting hierarchy + - Implemented conditional dependency installation with existence checks for requirements files and package definition files + +### Notes +- This release significantly addresses the performance improvements requested in [#16](https://github.com/codellm-devkit/codeanalyzer-python/issues/16): + - ✅ Ray parallelization implemented + - ✅ Incremental caching with SHA256-based change detection implemented + - ✅ `--file-name` option for single-file analysis implemented + - ❌ `--nproc` options not yet included (still uses all available cores with Ray) +- ✅ Critical bug fix for nested function detection ([#15](https://github.com/codellm-devkit/codeanalyzer-python/issues/15)) is now included in this version +- Expected performance improvements: 2-10x faster on subsequent runs depending on code change frequency +- Enhanced symbol table accuracy ensures all function definitions are properly captured +- Virtual environment setup is now more robust and only installs dependencies when they are actually available + ## [0.1.9] - 2025-07-14 ### Fixed diff --git a/codeanalyzer/__main__.py b/codeanalyzer/__main__.py index ab400e3..5daf87b 100644 --- a/codeanalyzer/__main__.py +++ b/codeanalyzer/__main__.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Annotated, Optional +from typing import Optional, Annotated import typer @@ -7,7 +7,6 @@ from codeanalyzer.utils import _set_log_level, logger from codeanalyzer.config import OutputFormat - def main( input: Annotated[ Path, typer.Option("-i", "--input", help="Path to the project root directory.") @@ -32,6 +31,12 @@ def main( using_codeql: Annotated[ bool, typer.Option("--codeql/--no-codeql", help="Enable CodeQL-based analysis.") ] = False, + using_ray: Annotated[ + bool, + typer.Option( + "--ray/--no-ray", help="Enable Ray for distributed analysis." + ), + ] = False, rebuild_analysis: Annotated[ bool, typer.Option( @@ -39,18 +44,32 @@ def main( help="Enable eager or lazy analysis. Defaults to lazy.", ), ] = False, + skip_tests: Annotated[ + bool, + typer.Option( + "--skip-tests/--include-tests", + help="Skip test files in analysis.", + ), + ] = True, + file_name: Annotated[ + Optional[Path], + typer.Option( + "--file-name", + help="Analyze only the specified file (relative to input directory).", + ), + ] = None, cache_dir: Annotated[ Optional[Path], typer.Option( "-c", "--cache-dir", - help="Directory to store analysis cache.", + help="Directory to store analysis cache. Defaults to '.codeanalyzer' in the input directory.", ), ] = None, clear_cache: Annotated[ bool, - typer.Option("--clear-cache/--keep-cache", help="Clear cache after analysis."), - ] = True, + typer.Option("--clear-cache/--keep-cache", help="Clear cache after analysis. By default, cache is retained."), + ] = False, verbosity: Annotated[ int, typer.Option("-v", count=True, help="Increase verbosity: -v, -vv, -vvv") ] = 0, @@ -62,21 +81,28 @@ def main( logger.error(f"Input path '{input}' does not exist.") raise typer.Exit(code=1) + # Validate file_name if provided + if file_name is not None: + full_file_path = input / file_name + if not full_file_path.exists(): + logger.error(f"Specified file '{file_name}' does not exist in '{input}'.") + raise typer.Exit(code=1) + if not full_file_path.is_file(): + logger.error(f"Specified path '{file_name}' is not a file.") + raise typer.Exit(code=1) + if not str(file_name).endswith('.py'): + logger.error(f"Specified file '{file_name}' is not a Python file (.py).") + raise typer.Exit(code=1) + with Codeanalyzer( - input, analysis_level, using_codeql, rebuild_analysis, cache_dir, clear_cache + input, analysis_level, skip_tests, using_codeql, rebuild_analysis, cache_dir, clear_cache, using_ray, file_name ) as analyzer: artifacts = analyzer.analyze() # Handle output based on format if output is None: # Output to stdout (only for JSON) - if format == OutputFormat.JSON: - print(artifacts.model_dump_json(separators=(",", ":"))) - else: - logger.error( - f"Format '{format.value}' requires an output directory (use -o/--output)" - ) - raise typer.Exit(code=1) + print(artifacts.json(separators=(",", ":"))) else: # Output to file output.mkdir(parents=True, exist_ok=True) @@ -88,7 +114,7 @@ def _write_output(artifacts, output_dir: Path, format: OutputFormat): if format == OutputFormat.JSON: output_file = output_dir / "analysis.json" # Use Pydantic's json() with separators for compact output - json_str = artifacts.model_dump_json(indent=None) + json_str = artifacts.json(indent=None) with output_file.open("w") as f: f.write(json_str) logger.info(f"Analysis saved to {output_file}") diff --git a/codeanalyzer/core.py b/codeanalyzer/core.py index b135584..7d89126 100644 --- a/codeanalyzer/core.py +++ b/codeanalyzer/core.py @@ -4,13 +4,39 @@ import subprocess import sys from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, List -from codeanalyzer.schema.py_schema import PyApplication, PyModule +import ray +from codeanalyzer.utils import logger +from codeanalyzer.schema import PyApplication, PyModule from codeanalyzer.semantic_analysis.codeql import CodeQLLoader from codeanalyzer.semantic_analysis.codeql.codeql_exceptions import CodeQLExceptions +from codeanalyzer.syntactic_analysis.exceptions import SymbolTableBuilderRayError from codeanalyzer.syntactic_analysis.symbol_table_builder import SymbolTableBuilder -from codeanalyzer.utils import logger +from codeanalyzer.utils import ProgressBar + +@ray.remote +def _process_file_with_ray(py_file: Union[Path, str], project_dir: Union[Path, str], virtualenv: Union[Path, str, None]) -> Dict[str, PyModule]: + """Processes files in the project directory using Ray for distributed processing. + + Args: + py_file (Union[Path, str]): Path to the Python file to process. + project_dir (Union[Path, str]): Path to the project directory. + virtualenv (Union[Path, str, None]): Path to the virtual environment directory. + Returns: + Dict[str, PyModule]: A dictionary mapping file paths to PyModule objects. + """ + from rich.console import Console + console = Console() + module_map: Dict[str, PyModule] = {} + try: + py_file = Path(py_file) + symbol_table_builder = SymbolTableBuilder(project_dir, virtualenv) + module_map[str(py_file)] = symbol_table_builder.build_pymodule_from_file(py_file) + except Exception as e: + console.log(f"❌ Failed to process {py_file}: {e}") + raise SymbolTableBuilderRayError(f"Ray processing error for {py_file}: {e}") + return module_map class Codeanalyzer: @@ -28,14 +54,18 @@ class Codeanalyzer: def __init__( self, project_dir: Union[str, Path], - analysis_depth: int = 1, - using_codeql: bool = False, - rebuild_analysis: bool = False, - cache_dir: Optional[Path] = None, - clear_cache: bool = True, + analysis_depth: int, + skip_tests: bool, + using_codeql: bool, + rebuild_analysis: bool, + cache_dir: Optional[Path], + clear_cache: bool, + using_ray: bool, + file_name: Optional[Path] = None, ) -> None: self.analysis_depth = analysis_depth self.project_dir = Path(project_dir).resolve() + self.skip_tests = skip_tests self.using_codeql = using_codeql self.rebuild_analysis = rebuild_analysis self.cache_dir = ( @@ -45,10 +75,12 @@ def __init__( self.db_path: Optional[Path] = None self.codeql_bin: Optional[Path] = None self.virtualenv: Optional[Path] = None + self.using_ray: bool = using_ray + self.file_name: Optional[Path] = file_name @staticmethod def _cmd_exec_helper( - cmd: list[str], + cmd: List[str], cwd: Optional[Path] = None, capture_output: bool = True, check: bool = True, @@ -126,7 +158,8 @@ def _get_base_interpreter() -> Path: # We're inside a virtual environment; need to find the base interpreter # First, check if user explicitly set SYSTEM_PYTHON - if system_python := os.getenv("SYSTEM_PYTHON"): + system_python = os.getenv("SYSTEM_PYTHON") + if system_python: system_python_path = Path(system_python) if system_python_path.exists() and system_python_path.is_file(): return system_python_path @@ -142,14 +175,16 @@ def _get_base_interpreter() -> Path: # Use shutil.which to find python3 and python in PATH for python_name in ["python3", "python"]: - if python_path := shutil.which(python_name): + python_path = shutil.which(python_name) + if python_path: candidate = Path(python_path) # Skip if this is the current virtual environment's python if not str(candidate).startswith(sys.prefix): python_candidates.append(candidate) # Check pyenv installation - if pyenv_root := os.getenv("PYENV_ROOT"): + pyenv_root = os.getenv("PYENV_ROOT") + if pyenv_root: pyenv_python = Path(pyenv_root) / "shims" / "python" if pyenv_python.exists(): python_candidates.append(pyenv_python) @@ -160,15 +195,17 @@ def _get_base_interpreter() -> Path: python_candidates.append(home_pyenv) # Check conda base environment - if conda_prefix := os.getenv( - "CONDA_PREFIX_1" - ): # Original conda env before activation - conda_python = Path(conda_prefix) / "bin" / "python" + conda_base = os.getenv("CONDA_PREFIX") + if conda_base: + conda_python = Path(conda_base) / "bin" / "python" if conda_python.exists(): python_candidates.append(conda_python) # Check asdf - if asdf_dir := os.getenv("ASDF_DIR"): + asdf_dir = os.getenv("ASDF_DIR") + # If ASDF_DIR is set, use its shims directory + # Otherwise, check if asdf is installed in the default location + if asdf_dir: asdf_python = Path(asdf_dir) / "shims" / "python" if asdf_python.exists(): python_candidates.append(asdf_python) @@ -211,14 +248,61 @@ def __enter__(self) -> "Codeanalyzer": # Find python in the virtual environment venv_python = venv_path / "bin" / "python" - # Install the project itself (reads pyproject.toml) - self._cmd_exec_helper( - [str(venv_python), "-m", "pip", "install", "-U", f"{self.project_dir}"], - cwd=self.project_dir, - check=True, - ) - # Install the project dependencies - self.virtualenv = venv_path + # First, install dependencies from various dependency files + dependency_files = [ + ("requirements.txt", ["-r"]), + ("requirements-dev.txt", ["-r"]), + ("dev-requirements.txt", ["-r"]), + ("test-requirements.txt", ["-r"]), + ] + + for dep_file, pip_args in dependency_files: + if (self.project_dir / dep_file).exists(): + logger.info(f"Installing dependencies from {dep_file}") + self._cmd_exec_helper( + [str(venv_python), "-m", "pip", "install", "-U"] + pip_args + [str(self.project_dir / dep_file)], + cwd=self.project_dir, + check=True, + ) + + # Handle Pipenv files + if (self.project_dir / "Pipfile").exists(): + logger.info("Installing dependencies from Pipfile") + # Note: This would require pipenv to be installed + self._cmd_exec_helper( + [str(venv_python), "-m", "pip", "install", "pipenv"], + cwd=self.project_dir, + check=True, + ) + self._cmd_exec_helper( + ["pipenv", "install", "--dev"], + cwd=self.project_dir, + check=True, + ) + + # Handle conda environment files + conda_files = ["conda.yml", "environment.yml"] + for conda_file in conda_files: + if (self.project_dir / conda_file).exists(): + logger.info(f"Found {conda_file} - note that conda environments should be handled outside this tool") + break + + # Now install the project itself in editable mode (only if package definition exists) + package_definition_files = [ + "pyproject.toml", # Modern Python packaging (PEP 518/621) + "setup.py", # Traditional setuptools + "setup.cfg", # Setup configuration + ] + + if any((self.project_dir / file).exists() for file in package_definition_files): + logger.info("Installing project in editable mode") + self._cmd_exec_helper( + [str(venv_python), "-m", "pip", "install", "-e", str(self.project_dir)], + cwd=self.project_dir, + check=True, + ) + else: + logger.warning("No package definition files found, skipping editable installation") if self.using_codeql: logger.info(f"(Re-)initializing CodeQL analysis for {self.project_dir}") @@ -280,14 +364,95 @@ def is_cache_valid() -> bool: return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + def __exit__(self, *args, **kwargs) -> None: if self.clear_cache and self.cache_dir.exists(): logger.info(f"Clearing cache directory: {self.cache_dir}") shutil.rmtree(self.cache_dir) def analyze(self) -> PyApplication: - """Return the path to the CodeQL database.""" - return PyApplication.builder().symbol_table(self._build_symbol_table()).build() + """Analyze the project and return a PyApplication with symbol table. + + Uses caching to avoid re-analyzing unchanged files. + """ + cache_file = self.cache_dir / "analysis_cache.json" + + # Try to load existing cached analysis + cached_pyapplication = None + if not self.rebuild_analysis and cache_file.exists(): + try: + cached_pyapplication = self._load_pyapplication_from_cache(cache_file) + logger.info("Loaded cached analysis") + except Exception as e: + logger.warning(f"Failed to load cache: {e}. Rebuilding analysis.") + cached_pyapplication = None + + # Build symbol table from cached application if available (if no available, the build a new one) + symbol_table = self._build_symbol_table(cached_pyapplication.symbol_table if cached_pyapplication else {}) + + # Recreate pyapplication + app = PyApplication.builder().symbol_table(symbol_table).build() + + # Save to cache + self._save_analysis_cache(app, cache_file) + + return app + + def _load_pyapplication_from_cache(self, cache_file: Path) -> PyApplication: + """Load cached analysis from file. + + Args: + cache_file: Path to the cache file + + Returns: + PyApplication: The cached application data + """ + with cache_file.open('r') as f: + data = f.read() + return PyApplication.parse_raw(data) + + def _save_analysis_cache(self, app: PyApplication, cache_file: Path) -> None: + """Save analysis to cache file. + + Args: + app: The PyApplication to cache + cache_file: Path to save the cache file + """ + # Ensure cache directory exists + cache_file.parent.mkdir(parents=True, exist_ok=True) + + with cache_file.open('w') as f: + f.write(app.json(indent=2)) + + logger.info(f"Analysis cached to {cache_file}") + + def _file_unchanged(self, file_path: Path, cached_module: PyModule) -> bool: + """Check if a file has changed since it was cached. + + Args: + file_path: Path to the file to check + cached_module: The cached PyModule for this file + + Returns: + bool: True if file is unchanged, False otherwise + """ + try: + # Check last modified time and file size + if (cached_module.last_modified is not None and + cached_module.file_size is not None and + cached_module.last_modified == file_path.stat().st_mtime and + cached_module.file_size == file_path.stat().st_size): + return True + # Also check content hash for extra safety + if cached_module.content_hash is not None: + content_hash = hashlib.sha256(file_path.read_bytes()).hexdigest() + return content_hash == cached_module.content_hash + + # No cached metadata mismatch, assume file changed + return False + + except Exception as e: + logger.debug(f"Error checking file {file_path}: {e}") + return False def _compute_checksum(self, root: Path) -> str: """Compute SHA256 checksum of all Python source files in a project directory. If somethings changes, the @@ -304,11 +469,131 @@ def _compute_checksum(self, root: Path) -> str: sha256.update(py_file.read_bytes()) return sha256.hexdigest() - def _build_symbol_table(self) -> Dict[str, PyModule]: - """Retrieve a symbol table of the whole project.""" - return SymbolTableBuilder(self.project_dir, self.virtualenv).build() + def _build_symbol_table(self, cached_symbol_table: Optional[Dict[str, PyModule]] = None) -> Dict[str, PyModule]: + """Builds the symbol table for the project. + + This method scans the project directory, identifies Python files, + and constructs a symbol table containing information about classes, + functions, and variables defined in those files. + + Args: + cached_app: Previously cached PyApplication to reuse unchanged files + + Returns: + Dict[str, PyModule]: A dictionary mapping file paths to PyModule objects. + """ + symbol_table: Dict[str, PyModule] = {} + + # Handle single file analysis + if self.file_name is not None: + single_file = self.project_dir / self.file_name + logger.info(f"Analyzing single file: {single_file}") + + # Check if file is in cache and unchanged + file_key = str(single_file) + if file_key in cached_symbol_table and not self.rebuild_analysis: + # Compute file checksum to see if it changed + if self._file_unchanged(single_file, cached_symbol_table[file_key]): + logger.info(f"Using cached analysis for {single_file}") + symbol_table[file_key] = cached_symbol_table[file_key] + return symbol_table + + # File is new or changed, analyze it + try: + symbol_table_builder = SymbolTableBuilder(self.project_dir, self.virtualenv) + py_module = symbol_table_builder.build_pymodule_from_file(single_file) + symbol_table[file_key] = py_module + logger.info("✅ Single file analysis complete.") + return symbol_table + except Exception as e: + logger.error(f"Failed to process {single_file}: {e}") + return symbol_table + + # Get all Python files first to show accurate progress + py_files = [] + for py_file in self.project_dir.rglob("*.py"): + rel_path = py_file.relative_to(self.project_dir) + path_parts = rel_path.parts + filename = py_file.name + + # Skip directories we don't care about + if ( + "site-packages" in path_parts + or ".venv" in path_parts + or ".codeanalyzer" in path_parts + ): + continue + + # Skip test files if enabled + if self.skip_tests and ( + "test" in path_parts + or "tests" in path_parts + or filename.startswith("test_") + or filename.endswith("_test.py") + ): + continue + + py_files.append(py_file) + + if self.using_ray: + logger.info("Using Ray for distributed symbol table generation.") + # Separate files into cached and new/changed + files_to_process = [] + for py_file in py_files: + file_key = str(py_file) + if file_key in cached_symbol_table and not self.rebuild_analysis: + if self._file_unchanged(py_file, cached_symbol_table[file_key]): + # Use cached version + symbol_table[file_key] = cached_symbol_table[file_key] + continue + files_to_process.append(py_file) + + # Process only new/changed files with Ray + if files_to_process: + futures = [_process_file_with_ray.remote(py_file, self.project_dir, str(self.virtualenv) if self.virtualenv else None) for py_file in files_to_process] + + with ProgressBar(len(futures), "Building symbol table (parallel)") as progress: + pending = futures[:] + while pending: + done, pending = ray.wait(pending, num_returns=1) + result = ray.get(done[0]) + if result: + symbol_table.update(result) + progress.advance() + else: + logger.info("Building symbol table serially.") + symbol_table_builder = SymbolTableBuilder(self.project_dir, self.virtualenv) + files_processed = 0 + files_from_cache = 0 + + with ProgressBar(len(py_files), "Building symbol table") as progress: + for py_file in py_files: + file_key = str(py_file) + + # Check if file is cached and unchanged + if file_key in cached_symbol_table and not self.rebuild_analysis: + if self._file_unchanged(py_file, cached_symbol_table[file_key]): + symbol_table[file_key] = cached_symbol_table[file_key] + files_from_cache += 1 + progress.advance() + continue + + # File is new or changed, analyze it + try: + py_module = symbol_table_builder.build_pymodule_from_file(py_file) + symbol_table[file_key] = py_module + files_processed += 1 + except Exception as e: + logger.error(f"Failed to process {py_file}: {e}") + progress.advance() + + if files_from_cache > 0: + logger.info(f"Reused {files_from_cache} files from cache, processed {files_processed} new/changed files") + + logger.info("✅ Symbol table generation complete.") + return symbol_table def _get_call_graph(self) -> Dict[str, Any]: """Retrieve call graph from CodeQL database.""" logger.warning("Call graph extraction not yet implemented.") - return {} + return {} \ No newline at end of file diff --git a/codeanalyzer/schema/__init__.py b/codeanalyzer/schema/__init__.py index 9ef2090..3756e9e 100644 --- a/codeanalyzer/schema/__init__.py +++ b/codeanalyzer/schema/__init__.py @@ -21,3 +21,13 @@ "PyClassAttribute", "PyCallableParameter", ] + +# Resolve forward references +PyCallable.update_forward_refs(PyClass=PyClass) +PyClass.update_forward_refs(PyCallable=PyCallable) +PyModule.update_forward_refs(PyCallable=PyCallable, PyClass=PyClass) +PyApplication.update_forward_refs( + PyCallable=PyCallable, + PyClass=PyClass, + PyModule=PyModule +) \ No newline at end of file diff --git a/codeanalyzer/schema/py_schema.py b/codeanalyzer/schema/py_schema.py index 19be7d2..62f3a8d 100644 --- a/codeanalyzer/schema/py_schema.py +++ b/codeanalyzer/schema/py_schema.py @@ -19,7 +19,7 @@ This module defines the data models used to represent Python code structures for static analysis purposes. """ - +from __future__ import annotations import inspect from pathlib import Path from typing import Any, Dict, List, Optional @@ -148,7 +148,8 @@ def method(self, value): method.__name__ = f"{f}" method.__annotations__ = {"value": t, "return": builder_name} - method.__doc__ = f"Set {f} ({t.__name__})" + # Check if 't' has '__name__' attribute, otherwise use a fallback + method.__doc__ = f"Set {f} ({getattr(t, '__name__', str(t))})" return method namespace[f"{field}"] = make_method() @@ -275,12 +276,16 @@ class PyCallable(BaseModel): code_start_line: int = -1 accessed_symbols: List[PySymbol] = [] call_sites: List[PyCallsite] = [] + inner_callables: Dict[str, "PyCallable"] = {} + inner_classes: Dict[str, "PyClass"] = {} local_variables: List[PyVariableDeclaration] = [] cyclomatic_complexity: int = 0 def __hash__(self) -> int: """Generate a hash based on the callable's signature.""" return hash(self.signature) + + @builder @@ -328,6 +333,10 @@ class PyModule(BaseModel): classes: Dict[str, PyClass] = {} functions: Dict[str, PyCallable] = {} variables: List[PyVariableDeclaration] = [] + # Metadata for caching + content_hash: Optional[str] = None + last_modified: Optional[float] = None + file_size: Optional[int] = None @builder @@ -335,4 +344,4 @@ class PyModule(BaseModel): class PyApplication(BaseModel): """Represents a Python application.""" - symbol_table: dict[Path, PyModule] + symbol_table: Dict[str, PyModule] diff --git a/codeanalyzer/syntactic_analysis/__init__.py b/codeanalyzer/syntactic_analysis/__init__.py index e69de29..3ade2d9 100644 --- a/codeanalyzer/syntactic_analysis/__init__.py +++ b/codeanalyzer/syntactic_analysis/__init__.py @@ -0,0 +1,16 @@ +from codeanalyzer.syntactic_analysis.exceptions import ( + SymbolTableBuilderException, + SymbolTableBuilderFileNotFoundError, + SymbolTableBuilderParsingError, + SymbolTableBuilderRayError, +) + +from codeanalyzer.syntactic_analysis.symbol_table_builder import SymbolTableBuilder + +__all__ = [ + "SymbolTableBuilder", + "SymbolTableBuilderException", + "SymbolTableBuilderFileNotFoundError", + "SymbolTableBuilderParsingError", + "SymbolTableBuilderRayError", +] \ No newline at end of file diff --git a/codeanalyzer/syntactic_analysis/exceptions.py b/codeanalyzer/syntactic_analysis/exceptions.py new file mode 100644 index 0000000..d06d44f --- /dev/null +++ b/codeanalyzer/syntactic_analysis/exceptions.py @@ -0,0 +1,15 @@ +class SymbolTableBuilderException(Exception): + """Base exception for symbol table builder errors.""" + pass + +class SymbolTableBuilderFileNotFoundError(SymbolTableBuilderException): + """Exception raised when a source file is not found.""" + pass + +class SymbolTableBuilderParsingError(SymbolTableBuilderException): + """Exception raised when a source file cannot be parsed.""" + pass + +class SymbolTableBuilderRayError(SymbolTableBuilderException): + """Exception raised when there is an error in Ray processing.""" + pass \ No newline at end of file diff --git a/codeanalyzer/syntactic_analysis/symbol_table_builder.py b/codeanalyzer/syntactic_analysis/symbol_table_builder.py index 135c206..c7b83b4 100644 --- a/codeanalyzer/syntactic_analysis/symbol_table_builder.py +++ b/codeanalyzer/syntactic_analysis/symbol_table_builder.py @@ -1,9 +1,10 @@ import ast +import hashlib import tokenize from ast import AST, ClassDef from io import StringIO from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import jedi from jedi.api import Script @@ -21,14 +22,12 @@ PySymbol, PyVariableDeclaration, ) -from codeanalyzer.utils import logger -from codeanalyzer.utils.progress_bar import ProgressBar class SymbolTableBuilder: """A class for building a symbol table for a Python project.""" - def __init__(self, project_dir: Path | str, virtualenv: Path | str | None) -> None: + def __init__(self, project_dir: Union[Path, str], virtualenv: Union[Path, str, None]) -> None: self.project_dir = Path(project_dir) if virtualenv is None: # If no virtual environment is provided, create a jedi project without an environment. @@ -72,7 +71,7 @@ def _infer_qualified_name(script: Script, line: int, column: int) -> Optional[st pass return None - def _module(self, py_file: Path) -> PyModule: + def build_pymodule_from_file(self, py_file: Path) -> PyModule: """Builds a PyModule from a Python file. Args: @@ -83,18 +82,17 @@ def _module(self, py_file: Path) -> PyModule: """ # Get the raw source code from the file source = py_file.read_text(encoding="utf-8") + + # Get file metadata for caching + stat = py_file.stat() + file_size = stat.st_size + last_modified = stat.st_mtime + content_hash = hashlib.sha256(source.encode('utf-8')).hexdigest() + # Create a Jedi script for the file script: Script = Script(path=str(py_file), project=self.jedi_project) module = ast.parse(source, filename=str(py_file)) - - classes = {} - functions = {} - for node in ast.iter_child_nodes(module): - if isinstance(node, ClassDef): - classes.update(self._add_class(node, script)) - elif isinstance(node, ast.FunctionDef): - functions.update(self._callables(node, script)) - + return ( PyModule.builder() .file_path(str(py_file)) @@ -102,8 +100,11 @@ def _module(self, py_file: Path) -> PyModule: .comments(self._pycomments(module, source)) .imports(self._imports(module)) .variables(self._module_variables(module, script)) - .classes(classes) - .functions(functions) + .classes(self._add_class(module, script)) + .functions(self._callables(module, script)) + .content_hash(content_hash) + .last_modified(last_modified) + .file_size(file_size) .build() ) @@ -156,144 +157,112 @@ def _imports(self, module: ast.Module) -> List[PyImport]: return imports - def _add_class( - self, class_node: ast.ClassDef, script: Script - ) -> Dict[str, PyClass]: - """Builds a PyClass from a class definition node. - - Args: - class_node (ast.ClassDef): The AST node representing the class. - script (Script): The Jedi script object for the module. + def _add_class(self, node: AST, script: Script, prefix: str = "") -> Dict[str, PyClass]: + classes: Dict[str, PyClass] = {} - Returns: - Dict[str, PyClass]: Mapping of class signature to PyClass object. - """ - # Try resolving full signature with Jedi - try: - definitions = script.goto( - line=class_node.lineno, column=class_node.col_offset - ) - signature = next( - (d.full_name for d in definitions if d.type == "class"), - f"{script.path.__str__().replace('/', '.').replace('.py', '')}.{class_node.name}", - ) - except Exception: - signature = ( - f"{script.path.__str__().replace('/', '.').replace('.py', '')}.{class_node.name}", - ) + for child in ast.iter_child_nodes(node): + if not isinstance(child, ast.ClassDef): + continue - code: str = ast.unparse(class_node).strip() + class_name = child.name + start_line = child.lineno + end_line = getattr(child, "end_lineno", start_line + len(child.body)) + code = ast.unparse(child).strip() - py_class = ( - PyClass.builder() - .name(class_node.name) - .signature(signature) - .start_line(class_node.lineno) - .end_line( - getattr( - class_node, "end_lineno", class_node.lineno + len(class_node.body) - ) - ) - .comments(self._pycomments(class_node, code)) - .code(code) - .base_classes( - [ + # Try resolving full signature with Jedi + if prefix: + signature = f"{prefix}.{class_name}" + else: + try: + definitions = script.goto(line=start_line, column=child.col_offset) + signature = next( + (d.full_name for d in definitions if d.type == "class"), + f"{Path(script.path).relative_to(self.project_dir).__str__().replace('/', '.').replace('.py', '')}.{class_name}" + ) + except Exception: + signature = f"{Path(script.path).relative_to(self.project_dir).__str__().replace('/', '.').replace('.py', '')}.{class_name}" + py_class = ( + PyClass.builder() + .name(class_name) + .signature(signature) + .start_line(start_line) + .end_line(end_line) + .code(code) + .comments(self._pycomments(child, code)) + .base_classes([ ast.unparse(base) - for base in class_node.bases + for base in child.bases if isinstance(base, ast.expr) - ] - ) - .methods(self._callables(class_node, script)) - .attributes(self._class_attributes(class_node, script)) - .inner_classes( - { - k: v - for child in class_node.body - if isinstance(child, ast.ClassDef) - for k, v in self._add_class(child, script).items() - } + ]) + .methods(self._callables(child, script, prefix=signature)) # Pass class signature as prefix + .attributes(self._class_attributes(child, script)) + .inner_classes(self._add_class(child, script, prefix=signature)) # Pass class signature as prefix + .build() ) - .build() - ) - return {signature: py_class} + classes[signature] = py_class - def _callables(self, node: AST, script: Script) -> Dict[str, PyCallable]: - """ - Builds PyCallable objects from any AST node that may contain functions. + return classes - Args: - node (AST): The AST node to process (e.g., Module, ClassDef, FunctionDef). - script (Script): The Jedi script object for the module. - Returns: - Dict[str, PyCallable]: A dictionary mapping function/method names to PyCallable objects. - """ + def _callables(self, node: AST, script: Script, prefix: str = "") -> Dict[str, PyCallable]: callables: Dict[str, PyCallable] = {} - module_path: str = script.path or "" - module_name: str = Path(module_path).stem if module_path else "" - - def visit(n: AST, class_prefix: str = ""): - for child in ast.iter_child_nodes(n): - if isinstance(child, ast.FunctionDef): - method_name = child.name - start_line = child.lineno - end_line = getattr( - child, "end_lineno", start_line + len(child.body) - ) - code_start_line = child.body[0].lineno if child.body else start_line - code: str = ast.unparse(child).strip() - decorators = [ast.unparse(d) for d in child.decorator_list] + for child in ast.iter_child_nodes(node): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + method_name = child.name # Keep the actual method name unchanged + start_line = child.lineno + end_line = getattr(child, "end_lineno", start_line + len(child.body)) + code = ast.unparse(child).strip() + decorators = [ast.unparse(d) for d in child.decorator_list] + + if prefix: + # We're in a nested context - build signature with prefix + signature = f"{prefix}.{method_name}" + else: + # Top-level function - try Jedi first, fall back to relative path-based try: - definitions = script.goto( - line=start_line, column=child.col_offset + definitions = script.goto(line=start_line, column=child.col_offset) + signature = next( + (d.full_name for d in definitions if d.type == "function"), + None ) except Exception: - definitions = [] - - signature = next( - (d.full_name for d in definitions if d.type == "function"), - f"{module_name}.{class_prefix}{method_name}", + signature = None + + # If Jedi didn't provide a signature, build one relative to project_dir + if not signature: + relative_path = Path(script.path).relative_to(self.project_dir) + signature = f"{str(relative_path).replace('/', '.').replace('.py', '')}.{method_name}" + py_callable = ( + PyCallable.builder() + .name(method_name) # Use the actual method name, not the full signature + .path(str(script.path)) + .signature(signature) # Use the full signature here + .decorators(decorators) + .code(code) + .start_line(start_line) + .end_line(end_line) + .code_start_line(child.body[0].lineno if child.body else start_line) + .accessed_symbols(self._accessed_symbols(child, script)) + .call_sites(self._call_sites(child, script)) + .local_variables(self._local_variables(child, script)) + .cyclomatic_complexity(self._cyclomatic_complexity(child)) + .parameters(self._callable_parameters(child, script)) + .return_type( + ast.unparse(child.returns) + if child.returns else self._infer_type(script, child.lineno, child.col_offset) ) + .comments(self._pycomments(child, code)) + .inner_callables(self._callables(child, script, signature)) # Pass current signature as prefix + .inner_classes(self._add_class(child, script, signature)) # Pass current signature as prefix + .build() + ) - callables[method_name] = ( - PyCallable.builder() - .name(method_name) - .path(script.path.__str__()) - .signature(signature) - .decorators(decorators) - .code(code) - .start_line(start_line) - .end_line(end_line) - .code_start_line(code_start_line) - .accessed_symbols(self._accessed_symbols(child, script)) - .call_sites(self._call_sites(child, script)) - .local_variables(self._local_variables(child, script)) - .cyclomatic_complexity(self._cyclomatic_complexity(child)) - .parameters(self._callable_parameters(child, script)) - .return_type( - ast.unparse(child.returns) - if child.returns - else self._infer_type( - script, child.lineno, child.col_offset - ) - ) - .comments(self._pycomments(child, code)) - .build() - ) - - visit(child, class_prefix + method_name + ".") - - elif isinstance(child, ast.ClassDef): - visit(child, class_prefix + child.name + ".") - - elif hasattr(child, "body"): - visit(child, class_prefix) + callables[method_name] = py_callable # Key by method name, not full signature - visit(node) return callables - + def _pycomments(self, node: ast.AST, source: str) -> List[PyComment]: """ Extracts all PyComment instances (docstring and # comments) from within a specific AST node's body. @@ -868,35 +837,3 @@ def _symbol_from_name_node( .col_offset(col_offset) .build() ) - - def build(self) -> Dict[str, PyModule]: - """Builds the symbol table for the project. - - This method scans the project directory, identifies Python files, - and constructs a symbol table containing information about classes, - functions, and variables defined in those files. - """ - symbol_table: Dict[str, PyModule] = {} - # Get all Python files first to show accurate progress - py_files = [ - py_file - for py_file in self.project_dir.rglob("*.py") - if "site-packages" - not in py_file.resolve().__str__() # exclude site-packages - and ".venv" - not in py_file.resolve().__str__() # exclude virtual environments - and ".codeanalyzer" - not in py_file.resolve().__str__() # exclude internal cache directories - ] - - with ProgressBar(len(py_files), "Building symbol table") as progress: - for py_file in py_files: - try: - py_module = self._module(py_file) - symbol_table[str(py_file)] = py_module - except Exception as e: - logger.error(f"Failed to process {py_file}: {e}") - progress.advance() - progress.finish("✅ Symbol table generation complete.") - - return symbol_table diff --git a/pyproject.toml b/pyproject.toml index aacf61a..208d7f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,35 +1,37 @@ [project] name = "codeanalyzer-python" -version = "0.1.9" +version = "0.1.10" description = "Static Analysis on Python source code using Jedi, CodeQL and Treesitter." readme = "README.md" authors = [ { name = "Rahul Krishna", email = "i.m.ralk@gmail.com" } ] -requires-python = ">=3.10" +requires-python = ">=3.9" dependencies = [ - "jedi>=0.19.2", - "loguru>=0.7.3", - "msgpack>=1.1.1", - "networkx>=3.4.2", - "pandas>=2.3.1", - "pydantic>=2.11.7", - "requests>=2.32.4", - "rich>=14.0.0", - "typer>=0.16.0", + "jedi>=0.18.0,<0.20.0", + "msgpack>=1.0.0,<1.0.7", + "networkx>=2.6.0,<3.2.0", + "pandas>=1.3.0,<2.0.0", + "numpy>=1.21.0,<1.24.0", + "pydantic>=1.8.0,<2.0.0", + "requests>=2.20.0,<3.0.0", + "rich>=12.6.0,<14.0.0", + "typer>=0.9.0,<1.0.0", + "ray>=2.0.0,<3.0.0", + "typing-extensions>=4.0.0" ] [dependency-groups] test = [ - "pytest>=8.4.1", - "pytest-asyncio>=1.0.0", - "pytest-cov>=6.2.1", - "pytest-pspec>=0.0.4", + "pytest>=7.0.0,<8.0.0", + "pytest-asyncio>=0.14.0,<0.15.0", + "pytest-cov>=2.10.0,<3.0.0", + "pytest-pspec>=0.0.3" ] dev = [ - "ipdb>=0.13.13", - "pre-commit>=4.2.0", + "ipdb>=0.13.0,<0.14.0", + "pre-commit>=2.9.0,<3.0.0" ] [project.scripts] @@ -58,7 +60,8 @@ addopts = [ "--cov=codeanalyzer", "--cov-report=html", "--cov-report=term-missing", - "--cov-fail-under=40" + "--cov-fail-under=40", + "--ignore=test/fixtures" ] testpaths = ["test"] diff --git a/test/conftest.py b/test/conftest.py index db6a66b..9af14d4 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -26,6 +26,11 @@ def cli_runner() -> CliRunner: @pytest.fixture -def project_root() -> Path: - """Returns the grandparent directory of this conftest file — typically the project root.""" - return Path(__file__).resolve().parents[1] +def whole_applications__xarray() -> Path: + """The xarray application directory.""" + return Path(__file__).parent.resolve().joinpath("fixtures", "whole_applications", "xarray") + +@pytest.fixture +def single_functionalities__stuff_nested_in_functions() -> Path: + """Returns the path to the 'single_functionalities/stuff_nested_in_functions' directory.""" + return Path(__file__).parent.resolve().joinpath("fixtures", "single_functionalities", "stuff_nested_in_functions_test") diff --git a/test/fixtures/single_functionalities/stuff_nested_in_functions_test/main.py b/test/fixtures/single_functionalities/stuff_nested_in_functions_test/main.py new file mode 100644 index 0000000..dee45c9 --- /dev/null +++ b/test/fixtures/single_functionalities/stuff_nested_in_functions_test/main.py @@ -0,0 +1,74 @@ +""" +Test file for nested structures: functions, classes, and complex nesting patterns. +This file tests the symbol table builder's ability to correctly identify and catalog +all nested definitions as reported in issue #15. +""" + +# Module-level imports +import os +from typing import List, Dict, Optional + + +# Module-level variable +MODULE_CONSTANT = "test_value" + + +def outer_function(): + """An outer function containing nested structures.""" + + # Local variable in outer function + outer_var = "outer" + + def nested_function(): + """A function nested inside another function.""" + nested_var = "nested" + + def deeply_nested_function(): + """A deeply nested function (3 levels deep).""" + return f"deeply nested: {nested_var}, {outer_var}" + + return deeply_nested_function() + + class NestedClass: + """A class defined inside a function.""" + + def __init__(self, name: str): + self.name = name + + def nested_class_method(self): + """A method inside a nested class.""" + + def method_nested_function(): + """A function nested inside a class method.""" + return f"method nested: {self.name}" + + class MethodNestedClass: + """A class nested inside a method.""" + + def __init__(self, value: int): + self.value = value + + def method_in_nested_class(self): + """Method in a class that's nested in a method.""" + + def function_in_method_in_nested_class(): + """Function inside method inside nested class.""" + return f"deep nesting: {self.value}" + + return function_in_method_in_nested_class() + + return method_nested_function(), MethodNestedClass(42) + + @staticmethod + def static_method_in_nested_class(): + """Static method in nested class.""" + return "static in nested" + + @classmethod + def class_method_in_nested_class(cls): + """Class method in nested class.""" + return f"class method in {cls.__name__}" + + # Create an instance and call methods + nested_instance = NestedClass("test") + return nested_function(), nested_instance \ No newline at end of file diff --git a/test/fixtures/whole_applications/xarray/.binder/environment.yml b/test/fixtures/whole_applications/xarray/.binder/environment.yml new file mode 100644 index 0000000..fee5ed0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.binder/environment.yml @@ -0,0 +1,39 @@ +name: xarray-examples +channels: + - conda-forge +dependencies: + - python=3.10 + - boto3 + - bottleneck + - cartopy + - cfgrib + - cftime + - coveralls + - dask + - distributed + - dask_labextension + - h5netcdf + - h5py + - hdf5 + - iris + - lxml # Optional dep of pydap + - matplotlib + - nc-time-axis + - netcdf4 + - numba + - numpy + - packaging + - pandas + - pint>=0.22 + - pip + - pooch + - pydap + - rasterio + - scipy + - seaborn + - setuptools + - sparse + - toolz + - xarray + - zarr + - numbagg diff --git a/test/fixtures/whole_applications/xarray/.codecov.yml b/test/fixtures/whole_applications/xarray/.codecov.yml new file mode 100644 index 0000000..d0bec95 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.codecov.yml @@ -0,0 +1,38 @@ +codecov: + require_ci_to_pass: true + +coverage: + status: + project: + default: + # Require 1% coverage, i.e., always succeed + target: 1% + flags: + - unittests + paths: + - "!xarray/tests/" + unittests: + target: 90% + flags: + - unittests + paths: + - "!xarray/tests/" + mypy: + target: 20% + flags: + - mypy + patch: false + changes: false + +comment: false + +flags: + unittests: + paths: + - "xarray" + - "!xarray/tests" + carryforward: false + mypy: + paths: + - "xarray" + carryforward: false diff --git a/test/fixtures/whole_applications/xarray/.git-blame-ignore-revs b/test/fixtures/whole_applications/xarray/.git-blame-ignore-revs new file mode 100644 index 0000000..465572c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.git-blame-ignore-revs @@ -0,0 +1,5 @@ +# black PR 3142 +d089df385e737f71067309ff7abae15994d581ec + +# isort PR 1924 +0e73e240107caee3ffd1a1149f0150c390d43251 diff --git a/test/fixtures/whole_applications/xarray/.git_archival.txt b/test/fixtures/whole_applications/xarray/.git_archival.txt new file mode 100644 index 0000000..3c1479a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.git_archival.txt @@ -0,0 +1,4 @@ +node: bef04067dd87f9f0c1a3ae7840299e0bbdd595a8 +node-date: 2024-06-13T07:05:11-06:00 +describe-name: v2024.06.0 +ref-names: tag: v2024.06.0 diff --git a/test/fixtures/whole_applications/xarray/.gitattributes b/test/fixtures/whole_applications/xarray/.gitattributes new file mode 100644 index 0000000..7a79ddd --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.gitattributes @@ -0,0 +1,4 @@ +# reduce the number of merge conflicts +doc/whats-new.rst merge=union +# allow installing from git archives +.git_archival.txt export-subst diff --git a/test/fixtures/whole_applications/xarray/.gitignore b/test/fixtures/whole_applications/xarray/.gitignore new file mode 100644 index 0000000..21011f0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.gitignore @@ -0,0 +1,82 @@ +*.py[cod] +__pycache__ +.env +.venv + +# example caches from Hypothesis +.hypothesis/ + +# temp files from docs build +doc/*.nc +doc/auto_gallery +doc/rasm.zarr +doc/savefig + +# C extensions +*.so + +# Packages +*.egg +*.egg-info +.eggs +dist +build +eggs +parts +bin +var +sdist +develop-eggs +.installed.cfg +lib +lib64 + +# Installer logs +pip-log.txt + +# Unit test / coverage reports +.coverage +.coverage.* +.tox +nosetests.xml +.cache +.dmypy.json +.mypy_cache +.ropeproject/ +.tags* +.testmon* +.tmontmp/ +.pytest_cache +dask-worker-space/ + +# asv environments +asv_bench/.asv +asv_bench/pkgs + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# IDEs +.idea +*.swp +.DS_Store +.vscode/ + +# xarray specific +doc/_build +doc/generated/ +xarray/tests/data/*.grib.*.idx + +# Sync tools +Icon* + +.ipynb_checkpoints +doc/team-panel.txt +doc/external-examples-gallery.txt +doc/notebooks-examples-gallery.txt +doc/videos-gallery.txt diff --git a/test/fixtures/whole_applications/xarray/.pep8speaks.yml b/test/fixtures/whole_applications/xarray/.pep8speaks.yml new file mode 100644 index 0000000..8d87864 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.pep8speaks.yml @@ -0,0 +1,6 @@ +# https://github.com/OrkoHunter/pep8speaks for more info +# pep8speaks will use the flake8 configs in `setup.cfg` + +scanner: + diff_only: False + linter: flake8 diff --git a/test/fixtures/whole_applications/xarray/.pre-commit-config.yaml b/test/fixtures/whole_applications/xarray/.pre-commit-config.yaml new file mode 100644 index 0000000..01a2b4a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.pre-commit-config.yaml @@ -0,0 +1,53 @@ +# https://pre-commit.com/ +ci: + autoupdate_schedule: monthly +exclude: 'xarray/datatree_.*' +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: debug-statements + - id: mixed-line-ending + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: 'v0.4.7' + hooks: + - id: ruff + args: ["--fix", "--show-fixes"] + # https://github.com/python/black#version-control-integration + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.4.2 + hooks: + - id: black-jupyter + - repo: https://github.com/keewis/blackdoc + rev: v0.3.9 + hooks: + - id: blackdoc + exclude: "generate_aggregations.py" + additional_dependencies: ["black==24.4.2"] + - id: blackdoc-autoupdate-black + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.10.0 + hooks: + - id: mypy + # Copied from setup.cfg + exclude: "properties|asv_bench" + # This is slow and so we take it out of the fast-path; requires passing + # `--hook-stage manual` to pre-commit + stages: [manual] + additional_dependencies: [ + # Type stubs + types-python-dateutil, + types-pkg_resources, + types-PyYAML, + types-pytz, + typing-extensions>=4.1.0, + numpy, + ] + - repo: https://github.com/citation-file-format/cff-converter-python + rev: ebf0b5e44d67f8beaa1cd13a0d0393ea04c6058d + hooks: + - id: validate-cff diff --git a/test/fixtures/whole_applications/xarray/.readthedocs.yaml b/test/fixtures/whole_applications/xarray/.readthedocs.yaml new file mode 100644 index 0000000..55fea71 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/.readthedocs.yaml @@ -0,0 +1,20 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: mambaforge-4.10 + jobs: + post_checkout: + - (git --no-pager log --pretty="tformat:%s" -1 | grep -vqF "[skip-rtd]") || exit 183 + - git fetch --unshallow || true + pre_install: + - git update-index --assume-unchanged doc/conf.py ci/requirements/doc.yml + +conda: + environment: ci/requirements/doc.yml + +sphinx: + fail_on_warning: true + +formats: [] diff --git a/test/fixtures/whole_applications/xarray/CITATION.cff b/test/fixtures/whole_applications/xarray/CITATION.cff new file mode 100644 index 0000000..2eee84b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/CITATION.cff @@ -0,0 +1,113 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: +- family-names: "Hoyer" + given-names: "Stephan" + orcid: "https://orcid.org/0000-0002-5207-0380" +- family-names: "Roos" + given-names: "Maximilian" +- family-names: "Joseph" + given-names: "Hamman" + orcid: "https://orcid.org/0000-0001-7479-8439" +- family-names: "Magin" + given-names: "Justus" + orcid: "https://orcid.org/0000-0002-4254-8002" +- family-names: "Cherian" + given-names: "Deepak" + orcid: "https://orcid.org/0000-0002-6861-8734" +- family-names: "Fitzgerald" + given-names: "Clark" + orcid: "https://orcid.org/0000-0003-3446-6389" +- family-names: "Hauser" + given-names: "Mathias" + orcid: "https://orcid.org/0000-0002-0057-4878" +- family-names: "Fujii" + given-names: "Keisuke" + orcid: "https://orcid.org/0000-0003-0390-9984" +- family-names: "Maussion" + given-names: "Fabien" + orcid: "https://orcid.org/0000-0002-3211-506X" +- family-names: "Imperiale" + given-names: "Guido" +- family-names: "Clark" + given-names: "Spencer" + orcid: "https://orcid.org/0000-0001-5595-7895" +- family-names: "Kleeman" + given-names: "Alex" +- family-names: "Nicholas" + given-names: "Thomas" + orcid: "https://orcid.org/0000-0002-2176-0530" +- family-names: "Kluyver" + given-names: "Thomas" + orcid: "https://orcid.org/0000-0003-4020-6364" +- family-names: "Westling" + given-names: "Jimmy" +- family-names: "Munroe" + given-names: "James" + orcid: "https://orcid.org/0000-0001-9098-6309" +- family-names: "Amici" + given-names: "Alessandro" + orcid: "https://orcid.org/0000-0002-1778-4505" +- family-names: "Barghini" + given-names: "Aureliana" +- family-names: "Banihirwe" + given-names: "Anderson" + orcid: "https://orcid.org/0000-0001-6583-571X" +- family-names: "Bell" + given-names: "Ray" + orcid: "https://orcid.org/0000-0003-2623-0587" +- family-names: "Hatfield-Dodds" + given-names: "Zac" + orcid: "https://orcid.org/0000-0002-8646-8362" +- family-names: "Abernathey" + given-names: "Ryan" + orcid: "https://orcid.org/0000-0001-5999-4917" +- family-names: "Bovy" + given-names: "Benoît" +- family-names: "Omotani" + given-names: "John" + orcid: "https://orcid.org/0000-0002-3156-8227" +- family-names: "Mühlbauer" + given-names: "Kai" + orcid: "https://orcid.org/0000-0001-6599-1034" +- family-names: "Roszko" + given-names: "Maximilian K." + orcid: "https://orcid.org/0000-0001-9424-2526" +- family-names: "Wolfram" + given-names: "Phillip J." + orcid: "https://orcid.org/0000-0001-5971-4241" +- family-names: "Henderson" + given-names: "Scott" + orcid: "https://orcid.org/0000-0003-0624-4965" +- family-names: "Awowale" + given-names: "Eniola Olufunke" +- family-names: "Scheick" + given-names: "Jessica" + orcid: "https://orcid.org/0000-0002-3421-4459" +- family-names: "Savoie" + given-names: "Matthew" + orcid: "https://orcid.org/0000-0002-8881-2550" +- family-names: "Littlejohns" + given-names: "Owen" +title: "xarray" +abstract: "N-D labeled arrays and datasets in Python." +license: Apache-2.0 +doi: 10.5281/zenodo.598201 +url: "https://xarray.dev/" +repository-code: "https://github.com/pydata/xarray" +preferred-citation: + type: article + authors: + - family-names: "Hoyer" + given-names: "Stephan" + orcid: "https://orcid.org/0000-0002-5207-0380" + - family-names: "Joseph" + given-names: "Hamman" + orcid: "https://orcid.org/0000-0001-7479-8439" + doi: "10.5334/jors.148" + journal: "Journal of Open Research Software" + month: 4 + title: "xarray: N-D labeled Arrays and Datasets in Python" + volume: 5 + issue: 1 + year: 2017 diff --git a/test/fixtures/whole_applications/xarray/CODE_OF_CONDUCT.md b/test/fixtures/whole_applications/xarray/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..d457a9e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/CODE_OF_CONDUCT.md @@ -0,0 +1,46 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at xarray-core-team@googlegroups.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/test/fixtures/whole_applications/xarray/CONTRIBUTING.md b/test/fixtures/whole_applications/xarray/CONTRIBUTING.md new file mode 100644 index 0000000..dd9931f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/CONTRIBUTING.md @@ -0,0 +1 @@ +Xarray's contributor guidelines [can be found in our online documentation](http://docs.xarray.dev/en/stable/contributing.html) diff --git a/test/fixtures/whole_applications/xarray/CORE_TEAM_GUIDE.md b/test/fixtures/whole_applications/xarray/CORE_TEAM_GUIDE.md new file mode 100644 index 0000000..9eb91f4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/CORE_TEAM_GUIDE.md @@ -0,0 +1,322 @@ +> **_Note:_** This Core Team Member Guide was adapted from the [napari project's Core Developer Guide](https://napari.org/stable/developers/core_dev_guide.html) and the [Pandas maintainers guide](https://pandas.pydata.org/docs/development/maintaining.html). + +# Core Team Member Guide + +Welcome, new core team member! We appreciate the quality of your work, and enjoy working with you! +Thank you for your numerous contributions to the project so far. + +By accepting the invitation to become a core team member you are **not required to commit to doing any more work** - +xarray is a volunteer project, and we value the contributions you have made already. + +You can see a list of all the current core team members on our +[@pydata/xarray](https://github.com/orgs/pydata/teams/xarray) +GitHub team. Once accepted, you should now be on that list too. +This document offers guidelines for your new role. + +## Tasks + +Xarray values a wide range of contributions, only some of which involve writing code. +As such, we do not currently make a distinction between a "core team member", "core developer", "maintainer", +or "triage team member" as some projects do (e.g. [pandas](https://pandas.pydata.org/docs/development/maintaining.html)). +That said, if you prefer to refer to your role as one of the other titles above then that is fine by us! + +Xarray is mostly a volunteer project, so these tasks shouldn’t be read as “expectations”. +**There are no strict expectations**, other than to adhere to our [Code of Conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +Rather, the tasks that follow are general descriptions of what it might mean to be a core team member: + +- Facilitate a welcoming environment for those who file issues, make pull requests, and open discussion topics, +- Triage newly filed issues, +- Review newly opened pull requests, +- Respond to updates on existing issues and pull requests, +- Drive discussion and decisions on stalled issues and pull requests, +- Provide experience / wisdom on API design questions to ensure consistency and maintainability, +- Project organization (run developer meetings, coordinate with sponsors), +- Project evangelism (advertise xarray to new users), +- Community contact (represent xarray in user communities such as [Pangeo](https://pangeo.io/)), +- Key project contact (represent xarray's perspective within key related projects like NumPy, Zarr or Dask), +- Project fundraising (help write and administrate grants that will support xarray), +- Improve documentation or tutorials (especially on [`tutorial.xarray.dev`](https://tutorial.xarray.dev/)), +- Presenting or running tutorials (such as those we have given at the SciPy conference), +- Help maintain the [`xarray.dev`](https://xarray.dev/) landing page and website, the [code for which is here](https://github.com/xarray-contrib/xarray.dev), +- Write blog posts on the [xarray blog](https://xarray.dev/blog), +- Help maintain xarray's various Continuous Integration Workflows, +- Help maintain a regular release schedule (we aim for one or more releases per month), +- Attend the bi-weekly community meeting ([issue](https://github.com/pydata/xarray/issues/4001)), +- Contribute to the xarray codebase. + +(Matt Rocklin's post on [the role of a maintainer](https://matthewrocklin.com/blog/2019/05/18/maintainer) may be +interesting background reading, but should not be taken to strictly apply to the Xarray project.) + +Obviously you are not expected to contribute in all (or even more than one) of these ways! +They are listed so as to indicate the many types of work that go into maintaining xarray. + +It is natural that your available time and enthusiasm for the project will wax and wane - this is fine and expected! +It is also common for core team members to have a "niche" - a particular part of the codebase they have specific expertise +with, or certain types of task above which they primarily perform. + +If however you feel that is unlikely you will be able to be actively contribute in the foreseeable future +(or especially if you won't be available to answer questions about pieces of code that you wrote previously) +then you may want to consider letting us know you would rather be listed as an "Emeritus Core Team Member", +as this would help us in evaluating the overall health of the project. + +## Issue triage + +One of the main ways you might spend your contribution time is by responding to or triaging new issues. +Here’s a typical workflow for triaging a newly opened issue or discussion: + +1. **Thank the reporter for opening an issue.** + + The issue tracker is many people’s first interaction with the xarray project itself, beyond just using the library. + It may also be their first open-source contribution of any kind. As such, we want it to be a welcoming, pleasant experience. + +2. **Is the necessary information provided?** + + Ideally reporters would fill out the issue template, but many don’t. If crucial information (like the version of xarray they used), + is missing feel free to ask for that and label the issue with “needs info”. + The report should follow the [guidelines for xarray discussions](https://github.com/pydata/xarray/discussions/5404). + You may want to link to that if they didn’t follow the template. + + Make sure that the title accurately reflects the issue. Edit it yourself if it’s not clear. + Remember also that issues can be converted to discussions and vice versa if appropriate. + +3. **Is this a duplicate issue?** + + We have many open issues. If a new issue is clearly a duplicate, label the new issue as “duplicate”, and close the issue with a link to the original issue. + Make sure to still thank the reporter, and encourage them to chime in on the original issue, and perhaps try to fix it. + + If the new issue provides relevant information, such as a better or slightly different example, add it to the original issue as a comment or an edit to the original post. + +4. **Is the issue minimal and reproducible?** + + For bug reports, we ask that the reporter provide a minimal reproducible example. + See [minimal-bug-reports](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) for a good explanation. + If the example is not reproducible, or if it’s clearly not minimal, feel free to ask the reporter if they can provide and example or simplify the provided one. + Do acknowledge that writing minimal reproducible examples is hard work. If the reporter is struggling, you can try to write one yourself and we’ll edit the original post to include it. + + If a nice reproducible example has been provided, thank the reporter for that. + If a reproducible example can’t be provided, add the “needs mcve” label. + + If a reproducible example is provided, but you see a simplification, edit the original post with your simpler reproducible example. + +5. **Is this a clearly defined feature request?** + + Generally, xarray prefers to discuss and design new features in issues, before a pull request is made. + Encourage the submitter to include a proposed API for the new feature. Having them write a full docstring is a good way to pin down specifics. + + We may need a discussion from several xarray maintainers before deciding whether the proposal is in scope for xarray. + +6. **Is this a usage question?** + + We prefer that usage questions are asked on StackOverflow with the [`python-xarray` tag](https://stackoverflow.com/questions/tagged/python-xarray +) or as a [GitHub discussion topic](https://github.com/pydata/xarray/discussions). + + If it’s easy to answer, feel free to link to the relevant documentation section, let them know that in the future this kind of question should be on StackOverflow, and close the issue. + +7. **What labels and milestones should I add?** + + Apply the relevant labels. This is a bit of an art, and comes with experience. Look at similar issues to get a feel for how things are labeled. + Labels used for labelling issues that relate to particular features or parts of the codebase normally have the form `topic-`. + + If the issue is clearly defined and the fix seems relatively straightforward, label the issue as `contrib-good-first-issue`. + You can also remove the `needs triage` label that is automatically applied to all newly-opened issues. + +8. **Where should the poster look to fix the issue?** + + If you can, it is very helpful to point to the approximate location in the codebase where a contributor might begin to fix the issue. + This helps ease the way in for new contributors to the repository. + +## Code review and contributions + +As a core team member, you are a representative of the project, +and trusted to make decisions that will serve the long term interests +of all users. You also gain the responsibility of shepherding +other contributors through the review process; here are some +guidelines for how to do that. + +### All contributors are treated the same + +You should now have gained the ability to merge or approve +other contributors' pull requests. Merging contributions is a shared power: +only merge contributions you yourself have carefully reviewed, and that are +clear improvements for the project. When in doubt, and especially for more +complex changes, wait until at least one other core team member has approved. +(See [Reviewing](#reviewing) and especially +[Merge Only Changes You Understand](#merge-only-changes-you-understand) below.) + +It should also be considered best practice to leave a reasonable (24hr) time window +after approval before merge to ensure that other core team members have a reasonable +chance to weigh in. +Adding the `plan-to-merge` label notifies developers of the imminent merge. + +We are also an international community, with contributors from many different time zones, +some of whom will only contribute during their working hours, others who might only be able +to contribute during nights and weekends. It is important to be respectful of other peoples +schedules and working habits, even if it slows the project down slightly - we are in this +for the long run. In the same vein you also shouldn't feel pressured to be constantly +available or online, and users or contributors who are overly demanding and unreasonable +to the point of harassment will be directed to our [Code of Conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +We value sustainable development practices over mad rushes. + +When merging, we automatically use GitHub's +[Squash and Merge](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request#merging-a-pull-request) +to ensure a clean git history. + +You should also continue to make your own pull requests as before and in accordance +with the [general contributing guide](https://docs.xarray.dev/en/stable/contributing.html). These pull requests still +require the approval of another core team member before they can be merged. + +### How to conduct a good review + +*Always* be kind to contributors. Contributors are often doing +volunteer work, for which we are tremendously grateful. Provide +constructive criticism on ideas and implementations, and remind +yourself of how it felt when your own work was being evaluated as a +novice. + +``xarray`` strongly values mentorship in code review. New users +often need more handholding, having little to no git +experience. Repeat yourself liberally, and, if you don’t recognize a +contributor, point them to our development guide, or other GitHub +workflow tutorials around the web. Do not assume that they know how +GitHub works (many don't realize that adding a commit +automatically updates a pull request, for example). Gentle, polite, kind +encouragement can make the difference between a new core team member and +an abandoned pull request. + +When reviewing, focus on the following: + +1. **Usability and generality:** `xarray` is a user-facing package that strives to be accessible +to both novice and advanced users, and new features should ultimately be +accessible to everyone using the package. `xarray` targets the scientific user +community broadly, and core features should be domain-agnostic and general purpose. +Custom functionality is meant to be provided through our various types of interoperability. + +2. **Performance and benchmarks:** As `xarray` targets scientific applications that often involve +large multidimensional datasets, high performance is a key value of `xarray`. While +every new feature won't scale equally to all sizes of data, keeping in mind performance +and our [benchmarks](https://github.com/pydata/xarray/tree/main/asv_bench) during a review may be important, and you may +need to ask for benchmarks to be run and reported or new benchmarks to be added. +You can run the CI benchmarking suite on any PR by tagging it with the ``run-benchmark`` label. + +3. **APIs and stability:** Coding users and developers will make +extensive use of our APIs. The foundation of a healthy ecosystem will be +a fully capable and stable set of APIs, so as `xarray` matures it will +very important to ensure our APIs are stable. Spending the extra time to consider names of public facing +variables and methods, alongside function signatures, could save us considerable +trouble in the future. We do our best to provide [deprecation cycles](https://docs.xarray.dev/en/stable/contributing.html#backwards-compatibility) +when making backwards-incompatible changes. + +4. **Documentation and tutorials:** All new methods should have appropriate doc +strings following [PEP257](https://peps.python.org/pep-0257/) and the +[NumPy documentation guide](https://numpy.org/devdocs/dev/howto-docs.html#documentation-style). +For any major new features, accompanying changes should be made to our +[tutorials](https://tutorial.xarray.dev). These should not only +illustrates the new feature, but explains it. + +5. **Implementations and algorithms:** You should understand the code being modified +or added before approving it. (See [Merge Only Changes You Understand](#merge-only-changes-you-understand) +below.) Implementations should do what they claim and be simple, readable, and efficient +in that order. + +6. **Tests:** All contributions *must* be tested, and each added line of code +should be covered by at least one test. Good tests not only execute the code, +but explore corner cases. It can be tempting not to review tests, but please +do so. + +Other changes may be *nitpicky*: spelling mistakes, formatting, +etc. Do not insist contributors make these changes, but instead you should offer +to make these changes by [pushing to their branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/committing-changes-to-a-pull-request-branch-created-from-a-fork), +or using GitHub’s [suggestion](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/commenting-on-a-pull-request) +[feature](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/incorporating-feedback-in-your-pull-request), and +be prepared to make them yourself if needed. Using the suggestion feature is preferred because +it gives the contributor a choice in whether to accept the changes. + +Unless you know that a contributor is experienced with git, don’t +ask for a rebase when merge conflicts arise. Instead, rebase the +branch yourself, force-push to their branch, and advise the contributor to force-pull. If the contributor is +no longer active, you may take over their branch by submitting a new pull +request and closing the original, including a reference to the original pull +request. In doing so, ensure you communicate that you are not throwing the +contributor's work away! If appropriate it is a good idea to acknowledge other contributions +to the pull request using the `Co-authored-by` +[syntax](https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors) in the commit message. + +### Merge only changes you understand + +*Long-term maintainability* is an important concern. Code doesn't +merely have to *work*, but should be *understood* by multiple core +developers. Changes will have to be made in the future, and the +original contributor may have moved on. + +Therefore, *do not merge a code change unless you understand it*. Ask +for help freely: we can consult community members, or even external developers, +for added insight where needed, and see this as a great learning opportunity. + +While we collectively "own" any patches (and bugs!) that become part +of the code base, you are vouching for changes you merge. Please take +that responsibility seriously. + +Feel free to ping other active maintainers with any questions you may have. + +## Further resources + +As a core member, you should be familiar with community and developer +resources such as: + +- Our [contributor guide](https://docs.xarray.dev/en/stable/contributing.html). +- Our [code of conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +- Our [philosophy and development roadmap](https://docs.xarray.dev/en/stable/roadmap.html). +- [PEP8](https://peps.python.org/pep-0008/) for Python style. +- [PEP257](https://peps.python.org/pep-0257/) and the + [NumPy documentation guide](https://numpy.org/devdocs/dev/howto-docs.html#documentation-style) + for docstring conventions. +- [`pre-commit`](https://pre-commit.com) hooks for autoformatting. +- [`black`](https://github.com/psf/black) autoformatting. +- [`flake8`](https://github.com/PyCQA/flake8) linting. +- [python-xarray](https://stackoverflow.com/questions/tagged/python-xarray) on Stack Overflow. +- [@xarray_dev](https://twitter.com/xarray_dev) on Twitter. +- [xarray-dev](https://discord.gg/bsSGdwBn) discord community (normally only used for remote synchronous chat during sprints). + +You are not required to monitor any of the social resources. + +Where possible we prefer to point people towards asynchronous forms of communication +like github issues instead of realtime chat options as they are far easier +for a global community to consume and refer back to. + +We hold a [bi-weekly developers meeting](https://docs.xarray.dev/en/stable/developers-meeting.html) via video call. +This is a great place to bring up any questions you have, raise visibility of an issue and/or gather more perspectives. +Attendance is absolutely optional, and we keep the meeting to 30 minutes in respect of your valuable time. +This meeting is public, so we occasionally have non-core team members join us. + +We also have a private mailing list for core team members +`xarray-core-team@googlegroups.com` which is sparingly used for discussions +that are required to be private, such as nominating new core members and discussing financial issues. + +## Inviting new core members + +Any core member may nominate other contributors to join the core team. +While there is no hard-and-fast rule about who can be nominated, ideally, +they should have: been part of the project for at least two months, contributed +significant changes of their own, contributed to the discussion and +review of others' work, and collaborated in a way befitting our +community values. **We strongly encourage nominating anyone who has made significant non-code contributions +to the Xarray community in any way**. After nomination voting will happen on a private mailing list. +While it is expected that most votes will be unanimous, a two-thirds majority of +the cast votes is enough. + +Core team members can choose to become emeritus core team members and suspend +their approval and voting rights until they become active again. + +## Contribute to this guide (!) + +This guide reflects the experience of the current core team members. We +may well have missed things that, by now, have become second +nature—things that you, as a new team member, will spot more easily. +Please ask the other core team members if you have any questions, and +submit a pull request with insights gained. + +## Conclusion + +We are excited to have you on board! We look forward to your +contributions to the code base and the community. Thank you in +advance! diff --git a/test/fixtures/whole_applications/xarray/HOW_TO_RELEASE.md b/test/fixtures/whole_applications/xarray/HOW_TO_RELEASE.md new file mode 100644 index 0000000..9d11645 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/HOW_TO_RELEASE.md @@ -0,0 +1,127 @@ +# How to issue an xarray release in 16 easy steps + +Time required: about an hour. + +These instructions assume that `upstream` refers to the main repository: + +```sh +$ git remote -v +{...} +upstream https://github.com/pydata/xarray (fetch) +upstream https://github.com/pydata/xarray (push) +``` + + + + 1. Ensure your main branch is synced to upstream: + ```sh + git switch main + git pull upstream main + ``` + 2. Add a list of contributors. + First fetch all previous release tags so we can see the version number of the last release was: + ```sh + git fetch upstream --tags + ``` + This will return a list of all the contributors since the last release: + ```sh + git log "$(git tag --sort=v:refname | tail -1).." --format=%aN | sort -u | perl -pe 's/\n/$1, /' + ``` + This will return the total number of contributors: + ```sh + git log "$(git tag --sort=v:refname | tail -1).." --format=%aN | sort -u | wc -l + ``` + 3. Write a release summary: ~50 words describing the high level features. This + will be used in the release emails, tweets, GitHub release notes, etc. + 4. Look over whats-new.rst and the docs. Make sure "What's New" is complete + (check the date!) and add the release summary at the top. + Things to watch out for: + - Important new features should be highlighted towards the top. + - Function/method references should include links to the API docs. + - Sometimes notes get added in the wrong section of whats-new, typically + due to a bad merge. Check for these before a release by using git diff, + e.g., `git diff v{YYYY.MM.X-1} whats-new.rst` where {YYYY.MM.X-1} is the previous + release. + 5. Open a PR with the release summary and whatsnew changes; in particular the + release headline should get feedback from the team on what's important to include. + 6. After merging, again ensure your main branch is synced to upstream: + ```sh + git pull upstream main + ``` + 7. If you have any doubts, run the full test suite one final time! + ```sh + pytest + ``` + 8. Check that the [ReadTheDocs build](https://readthedocs.org/projects/xray/) is passing on the `latest` build version (which is built from the `main` branch). + 9. Issue the release on GitHub. Click on "Draft a new release" at + . Type in the version number (with a "v") + and paste the release summary in the notes. + 10. This should automatically trigger an upload of the new build to PyPI via GitHub Actions. + Check this has run [here](https://github.com/pydata/xarray/actions/workflows/pypi-release.yaml), + and that the version number you expect is displayed [on PyPI](https://pypi.org/project/xarray/) +11. Add a section for the next release {YYYY.MM.X+1} to doc/whats-new.rst (we avoid doing this earlier so that it doesn't show up in the RTD build): + ```rst + .. _whats-new.YYYY.MM.X+1: + + vYYYY.MM.X+1 (unreleased) + ----------------------- + + New Features + ~~~~~~~~~~~~ + + + Breaking changes + ~~~~~~~~~~~~~~~~ + + + Deprecations + ~~~~~~~~~~~~ + + + Bug fixes + ~~~~~~~~~ + + + Documentation + ~~~~~~~~~~~~~ + + + Internal Changes + ~~~~~~~~~~~~~~~~ + + ``` +12. Commit your changes and push to main again: + ```sh + git commit -am 'New whatsnew section' + git push upstream main + ``` + You're done pushing to main! + +13. Update the version available on pyodide: + - Open the PyPI page for [Xarray downloads](https://pypi.org/project/xarray/#files) + - Edit [`pyodide/packages/xarray/meta.yaml`](https://github.com/pyodide/pyodide/blob/main/packages/xarray/meta.yaml) to update the + - version number + - link to the wheel (under "Built Distribution" on the PyPI page) + - SHA256 hash (Click "Show Hashes" next to the link to the wheel) + - Open a pull request to pyodide + +14. Issue the release announcement to mailing lists & Twitter. For bug fix releases, I + usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader + list (no more than once every 3-6 months): + - pydata@googlegroups.com + - xarray@googlegroups.com + - numpy-discussion@scipy.org + - scipy-user@scipy.org + - pyaos@lists.johnny-lin.com + + Google search will turn up examples of prior release announcements (look for + "ANN xarray"). + Some of these groups require you to be subscribed in order to email them. + + + +## Note on version numbering + +As of 2022.03.0, we utilize the [CALVER](https://calver.org/) version system. +Specifically, we have adopted the pattern `YYYY.MM.X`, where `YYYY` is a 4-digit +year (e.g. `2022`), `0M` is a 2-digit zero-padded month (e.g. `01` for January), and `X` is the release number (starting at zero at the start of each month and incremented once for each additional release). diff --git a/test/fixtures/whole_applications/xarray/LICENSE b/test/fixtures/whole_applications/xarray/LICENSE new file mode 100644 index 0000000..37ec93a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/LICENSE @@ -0,0 +1,191 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/test/fixtures/whole_applications/xarray/MANIFEST.in b/test/fixtures/whole_applications/xarray/MANIFEST.in new file mode 100644 index 0000000..a119e7d --- /dev/null +++ b/test/fixtures/whole_applications/xarray/MANIFEST.in @@ -0,0 +1,2 @@ +prune xarray/datatree_* +recursive-include xarray/datatree_/datatree *.py diff --git a/test/fixtures/whole_applications/xarray/README.md b/test/fixtures/whole_applications/xarray/README.md new file mode 100644 index 0000000..432d535 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/README.md @@ -0,0 +1,138 @@ +# xarray: N-D labeled arrays and datasets + +[![CI](https://github.com/pydata/xarray/workflows/CI/badge.svg?branch=main)](https://github.com/pydata/xarray/actions?query=workflow%3ACI) +[![Code coverage](https://codecov.io/gh/pydata/xarray/branch/main/graph/badge.svg?flag=unittests)](https://codecov.io/gh/pydata/xarray) +[![Docs](https://readthedocs.org/projects/xray/badge/?version=latest)](https://docs.xarray.dev/) +[![Benchmarked with asv](https://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat)](https://pandas.pydata.org/speed/xarray/) +[![Available on pypi](https://img.shields.io/pypi/v/xarray.svg)](https://pypi.python.org/pypi/xarray/) +[![Formatted with black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black) +[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) +[![Mirror on zendoo](https://zenodo.org/badge/DOI/10.5281/zenodo.598201.svg)](https://doi.org/10.5281/zenodo.598201) +[![Examples on binder](https://img.shields.io/badge/launch-binder-579ACA.svg?logo=)](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/weather-data.ipynb) +[![Twitter](https://img.shields.io/twitter/follow/xarray_dev?style=social)](https://twitter.com/xarray_dev) + +**xarray** (pronounced "ex-array", formerly known as **xray**) is an open source project and Python +package that makes working with labelled multi-dimensional arrays +simple, efficient, and fun! + +Xarray introduces labels in the form of dimensions, coordinates and +attributes on top of raw [NumPy](https://www.numpy.org)-like arrays, +which allows for a more intuitive, more concise, and less error-prone +developer experience. The package includes a large and growing library +of domain-agnostic functions for advanced analytics and visualization +with these data structures. + +Xarray was inspired by and borrows heavily from +[pandas](https://pandas.pydata.org), the popular data analysis package +focused on labelled tabular data. It is particularly tailored to working +with [netCDF](https://www.unidata.ucar.edu/software/netcdf) files, which +were the source of xarray\'s data model, and integrates tightly with +[dask](https://dask.org) for parallel computing. + +## Why xarray? + +Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called +"tensors") are an essential part of computational science. They are +encountered in a wide range of fields, including physics, astronomy, +geoscience, bioinformatics, engineering, finance, and deep learning. In +Python, [NumPy](https://www.numpy.org) provides the fundamental data +structure and API for working with raw ND arrays. However, real-world +datasets are usually more than just raw numbers; they have labels which +encode information about how the array values map to locations in space, +time, etc. + +Xarray doesn\'t just keep track of labels on arrays \-- it uses them to +provide a powerful and concise interface. For example: + +- Apply operations over dimensions by name: `x.sum('time')`. +- Select values by label instead of integer location: + `x.loc['2014-01-01']` or `x.sel(time='2014-01-01')`. +- Mathematical operations (e.g., `x - y`) vectorize across multiple + dimensions (array broadcasting) based on dimension names, not shape. +- Flexible split-apply-combine operations with groupby: + `x.groupby('time.dayofyear').mean()`. +- Database like alignment based on coordinate labels that smoothly + handles missing values: `x, y = xr.align(x, y, join='outer')`. +- Keep track of arbitrary metadata in the form of a Python dictionary: + `x.attrs`. + +## Documentation + +Learn more about xarray in its official documentation at +. + +Try out an [interactive Jupyter +notebook](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/weather-data.ipynb). + +## Contributing + +You can find information about contributing to xarray at our +[Contributing +page](https://docs.xarray.dev/en/stable/contributing.html). + +## Get in touch + +- Ask usage questions ("How do I?") on + [GitHub Discussions](https://github.com/pydata/xarray/discussions). +- Report bugs, suggest features or view the source code [on + GitHub](https://github.com/pydata/xarray). +- For less well defined questions or ideas, or to announce other + projects of interest to xarray users, use the [mailing + list](https://groups.google.com/forum/#!forum/xarray). + +## NumFOCUS + + + +Xarray is a fiscally sponsored project of +[NumFOCUS](https://numfocus.org), a nonprofit dedicated to supporting +the open source scientific computing community. If you like Xarray and +want to support our mission, please consider making a +[donation](https://numfocus.salsalabs.org/donate-to-xarray/) to support +our efforts. + +## History + +Xarray is an evolution of an internal tool developed at [The Climate +Corporation](http://climate.com/). It was originally written by Climate +Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was +released as open source in May 2014. The project was renamed from +"xray" in January 2016. Xarray became a fiscally sponsored project of +[NumFOCUS](https://numfocus.org) in August 2018. + +## Contributors + +Thanks to our many contributors! + +[![Contributors](https://contrib.rocks/image?repo=pydata/xarray)](https://github.com/pydata/xarray/graphs/contributors) + +## License + +Copyright 2014-2023, xarray Developers + +Licensed under the Apache License, Version 2.0 (the "License"); you +may not use this file except in compliance with the License. You may +obtain a copy of the License at + + + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Xarray bundles portions of pandas, NumPy and Seaborn, all of which are +available under a "3-clause BSD" license: + +- pandas: `setup.py`, `xarray/util/print_versions.py` +- NumPy: `xarray/core/npcompat.py` +- Seaborn: `_determine_cmap_params` in `xarray/core/plot/utils.py` + +Xarray also bundles portions of CPython, which is available under the +"Python Software Foundation License" in `xarray/core/pycompat.py`. + +Xarray uses icons from the icomoon package (free version), which is +available under the "CC BY 4.0" license. + +The full text of these licenses are included in the licenses directory. diff --git a/test/fixtures/whole_applications/xarray/asv_bench/asv.conf.json b/test/fixtures/whole_applications/xarray/asv_bench/asv.conf.json new file mode 100644 index 0000000..9dc86df --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/asv.conf.json @@ -0,0 +1,161 @@ +{ + // The version of the config file format. Do not change, unless + // you know what you are doing. + "version": 1, + + // The name of the project being benchmarked + "project": "xarray", + + // The project's homepage + "project_url": "http://docs.xarray.dev/", + + // The URL or local path of the source code repository for the + // project being benchmarked + "repo": "..", + + // List of branches to benchmark. If not provided, defaults to "master" + // (for git) or "default" (for mercurial). + "branches": ["main"], // for git + // "branches": ["default"], // for mercurial + + // The DVCS being used. If not set, it will be automatically + // determined from "repo" by looking at the protocol in the URL + // (if remote), or by looking for special directories, such as + // ".git" (if local). + "dvcs": "git", + + // The tool to use to create environments. May be "conda", + // "virtualenv" or other value depending on the plugins in use. + // If missing or the empty string, the tool will be automatically + // determined by looking for tools on the PATH environment + // variable. + "environment_type": "mamba", + "conda_channels": ["conda-forge"], + + // timeout in seconds for installing any dependencies in environment + // defaults to 10 min + "install_timeout": 600, + + // the base URL to show a commit for the project. + "show_commit_url": "https://github.com/pydata/xarray/commit/", + + // The Pythons you'd like to test against. If not provided, defaults + // to the current version of Python used to run `asv`. + "pythons": ["3.11"], + + // The matrix of dependencies to test. Each key is the name of a + // package (in PyPI) and the values are version numbers. An empty + // list or empty string indicates to just test against the default + // (latest) version. null indicates that the package is to not be + // installed. If the package to be tested is only available from + // PyPi, and the 'environment_type' is conda, then you can preface + // the package name by 'pip+', and the package will be installed via + // pip (with all the conda available packages installed first, + // followed by the pip installed packages). + // + // "matrix": { + // "numpy": ["1.6", "1.7"], + // "six": ["", null], // test with and without six installed + // "pip+emcee": [""], // emcee is only available for install with pip. + // }, + "matrix": { + "setuptools_scm": [""], // GH6609 + "numpy": [""], + "pandas": [""], + "netcdf4": [""], + "scipy": [""], + "bottleneck": [""], + "dask": [""], + "distributed": [""], + "flox": [""], + "numpy_groupies": [""], + "sparse": [""], + "cftime": [""] + }, + // fix for bad builds + // https://github.com/airspeed-velocity/asv/issues/1389#issuecomment-2076131185 + "build_command": [ + "python -m build", + "python -mpip wheel --no-deps --no-build-isolation --no-index -w {build_cache_dir} {build_dir}" + ], + // Combinations of libraries/python versions can be excluded/included + // from the set to test. Each entry is a dictionary containing additional + // key-value pairs to include/exclude. + // + // An exclude entry excludes entries where all values match. The + // values are regexps that should match the whole string. + // + // An include entry adds an environment. Only the packages listed + // are installed. The 'python' key is required. The exclude rules + // do not apply to includes. + // + // In addition to package names, the following keys are available: + // + // - python + // Python version, as in the *pythons* variable above. + // - environment_type + // Environment type, as above. + // - sys_platform + // Platform, as in sys.platform. Possible values for the common + // cases: 'linux2', 'win32', 'cygwin', 'darwin'. + // + // "exclude": [ + // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows + // {"environment_type": "conda", "six": null}, // don't run without six on conda + // ], + // + // "include": [ + // // additional env for python2.7 + // {"python": "2.7", "numpy": "1.8"}, + // // additional env if run on windows+conda + // {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""}, + // ], + + // The directory (relative to the current directory) that benchmarks are + // stored in. If not provided, defaults to "benchmarks" + "benchmark_dir": "benchmarks", + + // The directory (relative to the current directory) to cache the Python + // environments in. If not provided, defaults to "env" + "env_dir": ".asv/env", + + // The directory (relative to the current directory) that raw benchmark + // results are stored in. If not provided, defaults to "results". + "results_dir": ".asv/results", + + // The directory (relative to the current directory) that the html tree + // should be written to. If not provided, defaults to "html". + "html_dir": ".asv/html", + + // The number of characters to retain in the commit hashes. + // "hash_length": 8, + + // `asv` will cache wheels of the recent builds in each + // environment, making them faster to install next time. This is + // number of builds to keep, per environment. + // "wheel_cache_size": 0 + + // The commits after which the regression search in `asv publish` + // should start looking for regressions. Dictionary whose keys are + // regexps matching to benchmark names, and values corresponding to + // the commit (exclusive) after which to start looking for + // regressions. The default is to start from the first commit + // with results. If the commit is `null`, regression detection is + // skipped for the matching benchmark. + // + // "regressions_first_commits": { + // "some_benchmark": "352cdf", // Consider regressions only after this commit + // "another_benchmark": null, // Skip regression detection altogether + // } + + // The thresholds for relative change in results, after which `asv + // publish` starts reporting regressions. Dictionary of the same + // form as in ``regressions_first_commits``, with values + // indicating the thresholds. If multiple entries match, the + // maximum is taken. If no entry matches, the default is 5%. + // + // "regressions_thresholds": { + // "some_benchmark": 0.01, // Threshold of 1% + // "another_benchmark": 0.5, // Threshold of 50% + // } +} diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/README_CI.md b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/README_CI.md new file mode 100644 index 0000000..9d86cc2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/README_CI.md @@ -0,0 +1,122 @@ +# Benchmark CI + + + + + +## How it works + +The `asv` suite can be run for any PR on GitHub Actions (check workflow `.github/workflows/benchmarks.yml`) by adding a `run-benchmark` label to said PR. This will trigger a job that will run the benchmarking suite for the current PR head (merged commit) against the PR base (usually `main`). + +We use `asv continuous` to run the job, which runs a relative performance measurement. This means that there's no state to be saved and that regressions are only caught in terms of performance ratio (absolute numbers are available but they are not useful since we do not use stable hardware over time). `asv continuous` will: + +* Compile `scikit-image` for _both_ commits. We use `ccache` to speed up the process, and `mamba` is used to create the build environments. +* Run the benchmark suite for both commits, _twice_ (since `processes=2` by default). +* Generate a report table with performance ratios: + * `ratio=1.0` -> performance didn't change. + * `ratio<1.0` -> PR made it slower. + * `ratio>1.0` -> PR made it faster. + +Due to the sensitivity of the test, we cannot guarantee that false positives are not produced. In practice, values between `(0.7, 1.5)` are to be considered part of the measurement noise. When in doubt, running the benchmark suite one more time will provide more information about the test being a false positive or not. + +## Running the benchmarks on GitHub Actions + +1. On a PR, add the label `run-benchmark`. +2. The CI job will be started. Checks will appear in the usual dashboard panel above the comment box. +3. If more commits are added, the label checks will be grouped with the last commit checks _before_ you added the label. +4. Alternatively, you can always go to the `Actions` tab in the repo and [filter for `workflow:Benchmark`](https://github.com/scikit-image/scikit-image/actions?query=workflow%3ABenchmark). Your username will be assigned to the `actor` field, so you can also filter the results with that if you need it. + +## The artifacts + +The CI job will also generate an artifact. This is the `.asv/results` directory compressed in a zip file. Its contents include: + +* `fv-xxxxx-xx/`. A directory for the machine that ran the suite. It contains three files: + * `.json`, `.json`: the benchmark results for each commit, with stats. + * `machine.json`: details about the hardware. +* `benchmarks.json`: metadata about the current benchmark suite. +* `benchmarks.log`: the CI logs for this run. +* This README. + +## Re-running the analysis + +Although the CI logs should be enough to get an idea of what happened (check the table at the end), one can use `asv` to run the analysis routines again. + +1. Uncompress the artifact contents in the repo, under `.asv/results`. This is, you should see `.asv/results/benchmarks.log`, not `.asv/results/something_else/benchmarks.log`. Write down the machine directory name for later. +2. Run `asv show` to see your available results. You will see something like this: + +``` +$> asv show + +Commits with results: + +Machine : Jaimes-MBP +Environment: conda-py3.9-cython-numpy1.20-scipy + + 00875e67 + +Machine : fv-az95-499 +Environment: conda-py3.7-cython-numpy1.17-pooch-scipy + + 8db28f02 + 3a305096 +``` + +3. We are interested in the commits for `fv-az95-499` (the CI machine for this run). We can compare them with `asv compare` and some extra options. `--sort ratio` will show largest ratios first, instead of alphabetical order. `--split` will produce three tables: improved, worsened, no changes. `--factor 1.5` tells `asv` to only complain if deviations are above a 1.5 ratio. `-m` is used to indicate the machine ID (use the one you wrote down in step 1). Finally, specify your commit hashes: baseline first, then contender! + +``` +$> asv compare --sort ratio --split --factor 1.5 -m fv-az95-499 8db28f02 3a305096 + +Benchmarks that have stayed the same: + + before after ratio + [8db28f02] [3a305096] + + n/a n/a n/a benchmark_restoration.RollingBall.time_rollingball_ndim + 1.23±0.04ms 1.37±0.1ms 1.12 benchmark_transform_warp.WarpSuite.time_to_float64(, 128, 3) + 5.07±0.1μs 5.59±0.4μs 1.10 benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(, (192, 192, 192), (192, 192, 192)) + 1.23±0.02ms 1.33±0.1ms 1.08 benchmark_transform_warp.WarpSuite.time_same_type(, 128, 3) + 9.45±0.2ms 10.1±0.5ms 1.07 benchmark_rank.Rank3DSuite.time_3d_filters('majority', (32, 32, 32)) + 23.0±0.9ms 24.6±1ms 1.07 benchmark_interpolation.InterpolationResize.time_resize((80, 80, 80), 0, 'symmetric', , True) + 38.7±1ms 41.1±1ms 1.06 benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(, (2048, 2048), (192, 192, 192)) + 4.97±0.2μs 5.24±0.2μs 1.05 benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(, (2048, 2048), (2048, 2048)) + 4.21±0.2ms 4.42±0.3ms 1.05 benchmark_rank.Rank3DSuite.time_3d_filters('gradient', (32, 32, 32)) + +... +``` + +If you want more details on a specific test, you can use `asv show`. Use `-b pattern` to filter which tests to show, and then specify a commit hash to inspect: + +``` +$> asv show -b time_to_float64 8db28f02 + +Commit: 8db28f02 + +benchmark_transform_warp.WarpSuite.time_to_float64 [fv-az95-499/conda-py3.7-cython-numpy1.17-pooch-scipy] + ok + =============== ============= ========== ============= ========== ============ ========== ============ ========== ============ + -- N / order + --------------- -------------------------------------------------------------------------------------------------------------- + dtype_in 128 / 0 128 / 1 128 / 3 1024 / 0 1024 / 1 1024 / 3 4096 / 0 4096 / 1 4096 / 3 + =============== ============= ========== ============= ========== ============ ========== ============ ========== ============ + numpy.uint8 2.56±0.09ms 523±30μs 1.28±0.05ms 130±3ms 28.7±2ms 81.9±3ms 2.42±0.01s 659±5ms 1.48±0.01s + numpy.uint16 2.48±0.03ms 530±10μs 1.28±0.02ms 130±1ms 30.4±0.7ms 81.1±2ms 2.44±0s 653±3ms 1.47±0.02s + numpy.float32 2.59±0.1ms 518±20μs 1.27±0.01ms 127±3ms 26.6±1ms 74.8±2ms 2.50±0.01s 546±10ms 1.33±0.02s + numpy.float64 2.48±0.04ms 513±50μs 1.23±0.04ms 134±3ms 30.7±2ms 85.4±2ms 2.55±0.01s 632±4ms 1.45±0.01s + =============== ============= ========== ============= ========== ============ ========== ============ ========== ============ + started: 2021-07-06 06:14:36, duration: 1.99m +``` + +## Other details + +### Skipping slow or demanding tests + +To minimize the time required to run the full suite, we trimmed the parameter matrix in some cases and, in others, directly skipped tests that ran for too long or require too much memory. Unlike `pytest`, `asv` does not have a notion of marks. However, you can `raise NotImplementedError` in the setup step to skip a test. In that vein, a new private function is defined at `benchmarks.__init__`: `_skip_slow`. This will check if the `ASV_SKIP_SLOW` environment variable has been defined. If set to `1`, it will raise `NotImplementedError` and skip the test. To implement this behavior in other tests, you can add the following attribute: + +```python +from . import _skip_slow # this function is defined in benchmarks.__init__ + +def time_something_slow(): + pass + +time_something.setup = _skip_slow +``` diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/__init__.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/__init__.py new file mode 100644 index 0000000..aa600c8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/__init__.py @@ -0,0 +1,74 @@ +import itertools +import os + +import numpy as np + +_counter = itertools.count() + + +def parameterized(names, params): + def decorator(func): + func.param_names = names + func.params = params + return func + + return decorator + + +def requires_dask(): + try: + import dask # noqa: F401 + except ImportError: + raise NotImplementedError() + + +def requires_sparse(): + try: + import sparse # noqa: F401 + except ImportError: + raise NotImplementedError() + + +def randn(shape, frac_nan=None, chunks=None, seed=0): + rng = np.random.RandomState(seed) + if chunks is None: + x = rng.standard_normal(shape) + else: + import dask.array as da + + rng = da.random.RandomState(seed) + x = rng.standard_normal(shape, chunks=chunks) + + if frac_nan is not None: + inds = rng.choice(range(x.size), int(x.size * frac_nan)) + x.flat[inds] = np.nan + + return x + + +def randint(low, high=None, size=None, frac_minus=None, seed=0): + rng = np.random.RandomState(seed) + x = rng.randint(low, high, size) + if frac_minus is not None: + inds = rng.choice(range(x.size), int(x.size * frac_minus)) + x.flat[inds] = -1 + + return x + + +def _skip_slow(): + """ + Use this function to skip slow or highly demanding tests. + + Use it as a `Class.setup` method or a `function.setup` attribute. + + Examples + -------- + >>> from . import _skip_slow + >>> def time_something_slow(): + ... pass + ... + >>> time_something.setup = _skip_slow + """ + if os.environ.get("ASV_SKIP_SLOW", "0") == "1": + raise NotImplementedError("Skipping this test...") diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/accessors.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/accessors.py new file mode 100644 index 0000000..f9eb958 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/accessors.py @@ -0,0 +1,25 @@ +import numpy as np + +import xarray as xr + +from . import parameterized + +NTIME = 365 * 30 + + +@parameterized(["calendar"], [("standard", "noleap")]) +class DateTimeAccessor: + def setup(self, calendar): + np.random.randn(NTIME) + time = xr.date_range("2000", periods=30 * 365, calendar=calendar) + data = np.ones((NTIME,)) + self.da = xr.DataArray(data, dims="time", coords={"time": time}) + + def time_dayofyear(self, calendar): + self.da.time.dt.dayofyear + + def time_year(self, calendar): + self.da.time.dt.year + + def time_floor(self, calendar): + self.da.time.dt.floor("D") diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/alignment.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/alignment.py new file mode 100644 index 0000000..5a6ee3f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/alignment.py @@ -0,0 +1,54 @@ +import numpy as np + +import xarray as xr + +from . import parameterized, requires_dask + +ntime = 365 * 30 +nx = 50 +ny = 50 + +rng = np.random.default_rng(0) + + +class Align: + def setup(self, *args, **kwargs): + data = rng.standard_normal((ntime, nx, ny)) + self.ds = xr.Dataset( + {"temperature": (("time", "x", "y"), data)}, + coords={ + "time": xr.date_range("2000", periods=ntime), + "x": np.arange(nx), + "y": np.arange(ny), + }, + ) + self.year = self.ds.time.dt.year + self.idx = np.unique(rng.integers(low=0, high=ntime, size=ntime // 2)) + self.year_subset = self.year.isel(time=self.idx) + + @parameterized(["join"], [("outer", "inner", "left", "right", "exact", "override")]) + def time_already_aligned(self, join): + xr.align(self.ds, self.year, join=join) + + @parameterized(["join"], [("outer", "inner", "left", "right")]) + def time_not_aligned(self, join): + xr.align(self.ds, self.year[-100:], join=join) + + @parameterized(["join"], [("outer", "inner", "left", "right")]) + def time_not_aligned_random_integers(self, join): + xr.align(self.ds, self.year_subset, join=join) + + +class AlignCFTime(Align): + def setup(self, *args, **kwargs): + super().setup() + self.ds["time"] = xr.date_range("2000", periods=ntime, calendar="noleap") + self.year = self.ds.time.dt.year + self.year_subset = self.year.isel(time=self.idx) + + +class AlignDask(Align): + def setup(self, *args, **kwargs): + requires_dask() + super().setup() + self.ds = self.ds.chunk({"time": 100}) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/combine.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/combine.py new file mode 100644 index 0000000..772d888 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/combine.py @@ -0,0 +1,79 @@ +import numpy as np + +import xarray as xr + +from . import requires_dask + + +class Combine1d: + """Benchmark concatenating and merging large datasets""" + + def setup(self) -> None: + """Create 2 datasets with two different variables""" + + t_size = 8000 + t = np.arange(t_size) + data = np.random.randn(t_size) + + self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))}) + self.dsA1 = xr.Dataset( + {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))} + ) + + def time_combine_by_coords(self) -> None: + """Also has to load and arrange t coordinate""" + datasets = [self.dsA0, self.dsA1] + + xr.combine_by_coords(datasets) + + +class Combine1dDask(Combine1d): + """Benchmark concatenating and merging large datasets""" + + def setup(self) -> None: + """Create 2 datasets with two different variables""" + requires_dask() + + t_size = 8000 + t = np.arange(t_size) + var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk() + + data_vars = {f"long_name_{v}": ("T", var) for v in range(500)} + + self.dsA0 = xr.Dataset(data_vars, coords={"T": t}) + self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size}) + + +class Combine3d: + """Benchmark concatenating and merging large datasets""" + + def setup(self): + """Create 4 datasets with two different variables""" + + t_size, x_size, y_size = 50, 450, 400 + t = np.arange(t_size) + data = np.random.randn(t_size, x_size, y_size) + + self.dsA0 = xr.Dataset( + {"A": xr.DataArray(data, coords={"T": t}, dims=("T", "X", "Y"))} + ) + self.dsA1 = xr.Dataset( + {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T", "X", "Y"))} + ) + self.dsB0 = xr.Dataset( + {"B": xr.DataArray(data, coords={"T": t}, dims=("T", "X", "Y"))} + ) + self.dsB1 = xr.Dataset( + {"B": xr.DataArray(data, coords={"T": t + t_size}, dims=("T", "X", "Y"))} + ) + + def time_combine_nested(self): + datasets = [[self.dsA0, self.dsA1], [self.dsB0, self.dsB1]] + + xr.combine_nested(datasets, concat_dim=[None, "T"]) + + def time_combine_by_coords(self): + """Also has to load and arrange t coordinate""" + datasets = [self.dsA0, self.dsA1, self.dsB0, self.dsB1] + + xr.combine_by_coords(datasets) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataarray_missing.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataarray_missing.py new file mode 100644 index 0000000..83de65b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataarray_missing.py @@ -0,0 +1,72 @@ +import pandas as pd + +import xarray as xr + +from . import parameterized, randn, requires_dask + + +def make_bench_data(shape, frac_nan, chunks): + vals = randn(shape, frac_nan) + coords = {"time": pd.date_range("2000-01-01", freq="D", periods=shape[0])} + da = xr.DataArray(vals, dims=("time", "x", "y"), coords=coords) + + if chunks is not None: + da = da.chunk(chunks) + + return da + + +class DataArrayMissingInterpolateNA: + def setup(self, shape, chunks, limit): + if chunks is not None: + requires_dask() + self.da = make_bench_data(shape, 0.1, chunks) + + @parameterized( + ["shape", "chunks", "limit"], + ( + [(365, 75, 75)], + [None, {"x": 25, "y": 25}], + [None, 3], + ), + ) + def time_interpolate_na(self, shape, chunks, limit): + actual = self.da.interpolate_na(dim="time", method="linear", limit=limit) + + if chunks is not None: + actual = actual.compute() + + +class DataArrayMissingBottleneck: + def setup(self, shape, chunks, limit): + if chunks is not None: + requires_dask() + self.da = make_bench_data(shape, 0.1, chunks) + + @parameterized( + ["shape", "chunks", "limit"], + ( + [(365, 75, 75)], + [None, {"x": 25, "y": 25}], + [None, 3], + ), + ) + def time_ffill(self, shape, chunks, limit): + actual = self.da.ffill(dim="time", limit=limit) + + if chunks is not None: + actual = actual.compute() + + @parameterized( + ["shape", "chunks", "limit"], + ( + [(365, 75, 75)], + [None, {"x": 25, "y": 25}], + [None, 3], + ), + ) + def time_bfill(self, shape, chunks, limit): + actual = self.da.bfill(dim="time", limit=limit) + + if chunks is not None: + actual = actual.compute() diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataset.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataset.py new file mode 100644 index 0000000..d8a6d6d --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataset.py @@ -0,0 +1,32 @@ +import numpy as np + +from xarray import Dataset + +from . import requires_dask + + +class DatasetBinaryOp: + def setup(self): + self.ds = Dataset( + { + "a": (("x", "y"), np.ones((300, 400))), + "b": (("x", "y"), np.ones((300, 400))), + } + ) + self.mean = self.ds.mean() + self.std = self.ds.std() + + def time_normalize(self): + (self.ds - self.mean) / self.std + + +class DatasetChunk: + def setup(self): + requires_dask() + self.ds = Dataset() + array = np.ones(1000) + for i in range(250): + self.ds[f"var{i}"] = ("x", array) + + def time_chunk(self): + self.ds.chunk(x=(1,) * 1000) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataset_io.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataset_io.py new file mode 100644 index 0000000..dcc2de0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/dataset_io.py @@ -0,0 +1,652 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass + +import numpy as np +import pandas as pd + +import xarray as xr + +from . import _skip_slow, parameterized, randint, randn, requires_dask + +try: + import dask + import dask.multiprocessing +except ImportError: + pass + + +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + +_ENGINES = tuple(xr.backends.list_engines().keys() - {"store"}) + + +class IOSingleNetCDF: + """ + A few examples that benchmark reading/writing a single netCDF file with + xarray + """ + + timeout = 300.0 + repeat = 1 + number = 5 + + def make_ds(self): + # single Dataset + self.ds = xr.Dataset() + self.nt = 1000 + self.nx = 90 + self.ny = 45 + + self.block_chunks = { + "time": self.nt / 4, + "lon": self.nx / 3, + "lat": self.ny / 3, + } + + self.time_chunks = {"time": int(self.nt / 36)} + + times = pd.date_range("1970-01-01", periods=self.nt, freq="D") + lons = xr.DataArray( + np.linspace(0, 360, self.nx), + dims=("lon",), + attrs={"units": "degrees east", "long_name": "longitude"}, + ) + lats = xr.DataArray( + np.linspace(-90, 90, self.ny), + dims=("lat",), + attrs={"units": "degrees north", "long_name": "latitude"}, + ) + self.ds["foo"] = xr.DataArray( + randn((self.nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="foo", + attrs={"units": "foo units", "description": "a description"}, + ) + self.ds["bar"] = xr.DataArray( + randn((self.nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="bar", + attrs={"units": "bar units", "description": "a description"}, + ) + self.ds["baz"] = xr.DataArray( + randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32), + coords={"lon": lons, "lat": lats}, + dims=("lon", "lat"), + name="baz", + attrs={"units": "baz units", "description": "a description"}, + ) + + self.ds.attrs = {"history": "created for xarray benchmarking"} + + self.oinds = { + "time": randint(0, self.nt, 120), + "lon": randint(0, self.nx, 20), + "lat": randint(0, self.ny, 10), + } + self.vinds = { + "time": xr.DataArray(randint(0, self.nt, 120), dims="x"), + "lon": xr.DataArray(randint(0, self.nx, 120), dims="x"), + "lat": slice(3, 20), + } + + +class IOWriteSingleNetCDF3(IOSingleNetCDF): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + self.format = "NETCDF3_64BIT" + self.make_ds() + + def time_write_dataset_netcdf4(self): + self.ds.to_netcdf("test_netcdf4_write.nc", engine="netcdf4", format=self.format) + + def time_write_dataset_scipy(self): + self.ds.to_netcdf("test_scipy_write.nc", engine="scipy", format=self.format) + + +class IOReadSingleNetCDF4(IOSingleNetCDF): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + self.make_ds() + + self.filepath = "test_single_file.nc4.nc" + self.format = "NETCDF4" + self.ds.to_netcdf(self.filepath, format=self.format) + + def time_load_dataset_netcdf4(self): + xr.open_dataset(self.filepath, engine="netcdf4").load() + + def time_orthogonal_indexing(self): + ds = xr.open_dataset(self.filepath, engine="netcdf4") + ds = ds.isel(**self.oinds).load() + + def time_vectorized_indexing(self): + ds = xr.open_dataset(self.filepath, engine="netcdf4") + ds = ds.isel(**self.vinds).load() + + +class IOReadSingleNetCDF3(IOReadSingleNetCDF4): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + self.make_ds() + + self.filepath = "test_single_file.nc3.nc" + self.format = "NETCDF3_64BIT" + self.ds.to_netcdf(self.filepath, format=self.format) + + def time_load_dataset_scipy(self): + xr.open_dataset(self.filepath, engine="scipy").load() + + def time_orthogonal_indexing(self): + ds = xr.open_dataset(self.filepath, engine="scipy") + ds = ds.isel(**self.oinds).load() + + def time_vectorized_indexing(self): + ds = xr.open_dataset(self.filepath, engine="scipy") + ds = ds.isel(**self.vinds).load() + + +class IOReadSingleNetCDF4Dask(IOSingleNetCDF): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_ds() + + self.filepath = "test_single_file.nc4.nc" + self.format = "NETCDF4" + self.ds.to_netcdf(self.filepath, format=self.format) + + def time_load_dataset_netcdf4_with_block_chunks(self): + xr.open_dataset( + self.filepath, engine="netcdf4", chunks=self.block_chunks + ).load() + + def time_load_dataset_netcdf4_with_block_chunks_oindexing(self): + ds = xr.open_dataset(self.filepath, engine="netcdf4", chunks=self.block_chunks) + ds = ds.isel(**self.oinds).load() + + def time_load_dataset_netcdf4_with_block_chunks_vindexing(self): + ds = xr.open_dataset(self.filepath, engine="netcdf4", chunks=self.block_chunks) + ds = ds.isel(**self.vinds).load() + + def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_dataset( + self.filepath, engine="netcdf4", chunks=self.block_chunks + ).load() + + def time_load_dataset_netcdf4_with_time_chunks(self): + xr.open_dataset(self.filepath, engine="netcdf4", chunks=self.time_chunks).load() + + def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_dataset( + self.filepath, engine="netcdf4", chunks=self.time_chunks + ).load() + + +class IOReadSingleNetCDF3Dask(IOReadSingleNetCDF4Dask): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_ds() + + self.filepath = "test_single_file.nc3.nc" + self.format = "NETCDF3_64BIT" + self.ds.to_netcdf(self.filepath, format=self.format) + + def time_load_dataset_scipy_with_block_chunks(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_dataset( + self.filepath, engine="scipy", chunks=self.block_chunks + ).load() + + def time_load_dataset_scipy_with_block_chunks_oindexing(self): + ds = xr.open_dataset(self.filepath, engine="scipy", chunks=self.block_chunks) + ds = ds.isel(**self.oinds).load() + + def time_load_dataset_scipy_with_block_chunks_vindexing(self): + ds = xr.open_dataset(self.filepath, engine="scipy", chunks=self.block_chunks) + ds = ds.isel(**self.vinds).load() + + def time_load_dataset_scipy_with_time_chunks(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_dataset( + self.filepath, engine="scipy", chunks=self.time_chunks + ).load() + + +class IOMultipleNetCDF: + """ + A few examples that benchmark reading/writing multiple netCDF files with + xarray + """ + + timeout = 300.0 + repeat = 1 + number = 5 + + def make_ds(self, nfiles=10): + # multiple Dataset + self.ds = xr.Dataset() + self.nt = 1000 + self.nx = 90 + self.ny = 45 + self.nfiles = nfiles + + self.block_chunks = { + "time": self.nt / 4, + "lon": self.nx / 3, + "lat": self.ny / 3, + } + + self.time_chunks = {"time": int(self.nt / 36)} + + self.time_vars = np.split( + pd.date_range("1970-01-01", periods=self.nt, freq="D"), self.nfiles + ) + + self.ds_list = [] + self.filenames_list = [] + for i, times in enumerate(self.time_vars): + ds = xr.Dataset() + nt = len(times) + lons = xr.DataArray( + np.linspace(0, 360, self.nx), + dims=("lon",), + attrs={"units": "degrees east", "long_name": "longitude"}, + ) + lats = xr.DataArray( + np.linspace(-90, 90, self.ny), + dims=("lat",), + attrs={"units": "degrees north", "long_name": "latitude"}, + ) + ds["foo"] = xr.DataArray( + randn((nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="foo", + attrs={"units": "foo units", "description": "a description"}, + ) + ds["bar"] = xr.DataArray( + randn((nt, self.nx, self.ny), frac_nan=0.2), + coords={"lon": lons, "lat": lats, "time": times}, + dims=("time", "lon", "lat"), + name="bar", + attrs={"units": "bar units", "description": "a description"}, + ) + ds["baz"] = xr.DataArray( + randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32), + coords={"lon": lons, "lat": lats}, + dims=("lon", "lat"), + name="baz", + attrs={"units": "baz units", "description": "a description"}, + ) + + ds.attrs = {"history": "created for xarray benchmarking"} + + self.ds_list.append(ds) + self.filenames_list.append("test_netcdf_%i.nc" % i) + + +class IOWriteMultipleNetCDF3(IOMultipleNetCDF): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + self.make_ds() + self.format = "NETCDF3_64BIT" + + def time_write_dataset_netcdf4(self): + xr.save_mfdataset( + self.ds_list, self.filenames_list, engine="netcdf4", format=self.format + ) + + def time_write_dataset_scipy(self): + xr.save_mfdataset( + self.ds_list, self.filenames_list, engine="scipy", format=self.format + ) + + +class IOReadMultipleNetCDF4(IOMultipleNetCDF): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_ds() + self.format = "NETCDF4" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) + + def time_load_dataset_netcdf4(self): + xr.open_mfdataset(self.filenames_list, engine="netcdf4").load() + + def time_open_dataset_netcdf4(self): + xr.open_mfdataset(self.filenames_list, engine="netcdf4") + + +class IOReadMultipleNetCDF3(IOReadMultipleNetCDF4): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_ds() + self.format = "NETCDF3_64BIT" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) + + def time_load_dataset_scipy(self): + xr.open_mfdataset(self.filenames_list, engine="scipy").load() + + def time_open_dataset_scipy(self): + xr.open_mfdataset(self.filenames_list, engine="scipy") + + +class IOReadMultipleNetCDF4Dask(IOMultipleNetCDF): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_ds() + self.format = "NETCDF4" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) + + def time_load_dataset_netcdf4_with_block_chunks(self): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ).load() + + def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ).load() + + def time_load_dataset_netcdf4_with_time_chunks(self): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ).load() + + def time_load_dataset_netcdf4_with_time_chunks_multiprocessing(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ).load() + + def time_open_dataset_netcdf4_with_block_chunks(self): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ) + + def time_open_dataset_netcdf4_with_block_chunks_multiprocessing(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.block_chunks + ) + + def time_open_dataset_netcdf4_with_time_chunks(self): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ) + + def time_open_dataset_netcdf4_with_time_chunks_multiprocessing(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="netcdf4", chunks=self.time_chunks + ) + + +class IOReadMultipleNetCDF3Dask(IOReadMultipleNetCDF4Dask): + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.make_ds() + self.format = "NETCDF3_64BIT" + xr.save_mfdataset(self.ds_list, self.filenames_list, format=self.format) + + def time_load_dataset_scipy_with_block_chunks(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.block_chunks + ).load() + + def time_load_dataset_scipy_with_time_chunks(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.time_chunks + ).load() + + def time_open_dataset_scipy_with_block_chunks(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.block_chunks + ) + + def time_open_dataset_scipy_with_time_chunks(self): + with dask.config.set(scheduler="multiprocessing"): + xr.open_mfdataset( + self.filenames_list, engine="scipy", chunks=self.time_chunks + ) + + +def create_delayed_write(): + import dask.array as da + + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({"vals": (["a"], vals)}) + return ds.to_netcdf("file.nc", engine="netcdf4", compute=False) + + +class IOWriteNetCDFDask: + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed: + def setup(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + + requires_dask() + + try: + import distributed + except ImportError: + raise NotImplementedError() + + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() + + +class IOReadSingleFile(IOSingleNetCDF): + def setup(self, *args, **kwargs): + self.make_ds() + + self.filepaths = {} + for engine in _ENGINES: + self.filepaths[engine] = f"test_single_file_with_{engine}.nc" + self.ds.to_netcdf(self.filepaths[engine], engine=engine) + + @parameterized(["engine", "chunks"], (_ENGINES, [None, {}])) + def time_read_dataset(self, engine, chunks): + xr.open_dataset(self.filepaths[engine], engine=engine, chunks=chunks) + + +class IOReadCustomEngine: + def setup(self, *args, **kwargs): + """ + The custom backend does the bare minimum to be considered a lazy backend. But + the data in it is still in memory so slow file reading shouldn't affect the + results. + """ + requires_dask() + + @dataclass + class PerformanceBackendArray(xr.backends.BackendArray): + filename_or_obj: str | os.PathLike | None + shape: tuple[int, ...] + dtype: np.dtype + lock: xr.backends.locks.SerializableLock + + def __getitem__(self, key: tuple): + return xr.core.indexing.explicit_indexing_adapter( + key, + self.shape, + xr.core.indexing.IndexingSupport.BASIC, + self._raw_indexing_method, + ) + + def _raw_indexing_method(self, key: tuple): + raise NotImplementedError + + @dataclass + class PerformanceStore(xr.backends.common.AbstractWritableDataStore): + manager: xr.backends.CachingFileManager + mode: str | None = None + lock: xr.backends.locks.SerializableLock | None = None + autoclose: bool = False + + def __post_init__(self): + self.filename = self.manager._args[0] + + @classmethod + def open( + cls, + filename: str | os.PathLike | None, + mode: str = "r", + lock: xr.backends.locks.SerializableLock | None = None, + autoclose: bool = False, + ): + if lock is None: + if mode == "r": + locker = xr.backends.locks.SerializableLock() + else: + locker = xr.backends.locks.SerializableLock() + else: + locker = lock + + manager = xr.backends.CachingFileManager( + xr.backends.DummyFileManager, + filename, + mode=mode, + ) + return cls(manager, mode=mode, lock=locker, autoclose=autoclose) + + def load(self) -> tuple: + """ + Load a bunch of test data quickly. + + Normally this method would've opened a file and parsed it. + """ + n_variables = 2000 + + # Important to have a shape and dtype for lazy loading. + shape = (1000,) + dtype = np.dtype(int) + variables = { + f"long_variable_name_{v}": xr.Variable( + data=PerformanceBackendArray( + self.filename, shape, dtype, self.lock + ), + dims=("time",), + fastpath=True, + ) + for v in range(0, n_variables) + } + attributes = {} + + return variables, attributes + + class PerformanceBackend(xr.backends.BackendEntrypoint): + def open_dataset( + self, + filename_or_obj: str | os.PathLike | None, + drop_variables: tuple[str] = None, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + use_cftime=None, + decode_timedelta=None, + lock=None, + **kwargs, + ) -> xr.Dataset: + filename_or_obj = xr.backends.common._normalize_path(filename_or_obj) + store = PerformanceStore.open(filename_or_obj, lock=lock) + + store_entrypoint = xr.backends.store.StoreBackendEntrypoint() + + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + self.engine = PerformanceBackend + + @parameterized(["chunks"], ([None, {}, {"time": 10}])) + def time_open_dataset(self, chunks): + """ + Time how fast xr.open_dataset is without the slow data reading part. + Test with and without dask. + """ + xr.open_dataset(None, engine=self.engine, chunks=chunks) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/groupby.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/groupby.py new file mode 100644 index 0000000..065c1b3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/groupby.py @@ -0,0 +1,176 @@ +# import flox to avoid the cost of first import +import flox.xarray # noqa +import numpy as np +import pandas as pd + +import xarray as xr + +from . import _skip_slow, parameterized, requires_dask + + +class GroupBy: + def setup(self, *args, **kwargs): + self.n = 100 + self.ds1d = xr.Dataset( + { + "a": xr.DataArray(np.r_[np.repeat(1, self.n), np.repeat(2, self.n)]), + "b": xr.DataArray(np.arange(2 * self.n)), + "c": xr.DataArray(np.arange(2 * self.n)), + } + ) + self.ds2d = self.ds1d.expand_dims(z=10).copy() + self.ds1d_mean = self.ds1d.groupby("b").mean() + self.ds2d_mean = self.ds2d.groupby("b").mean() + + @parameterized(["ndim"], [(1, 2)]) + def time_init(self, ndim): + getattr(self, f"ds{ndim}d").groupby("b") + + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_small_num_groups(self, method, ndim, use_flox): + ds = getattr(self, f"ds{ndim}d") + with xr.set_options(use_flox=use_flox): + getattr(ds.groupby("a"), method)().compute() + + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_large_num_groups(self, method, ndim, use_flox): + ds = getattr(self, f"ds{ndim}d") + with xr.set_options(use_flox=use_flox): + getattr(ds.groupby("b"), method)().compute() + + def time_binary_op_1d(self): + (self.ds1d.groupby("b") - self.ds1d_mean).compute() + + def time_binary_op_2d(self): + (self.ds2d.groupby("b") - self.ds2d_mean).compute() + + def peakmem_binary_op_1d(self): + (self.ds1d.groupby("b") - self.ds1d_mean).compute() + + def peakmem_binary_op_2d(self): + (self.ds2d.groupby("b") - self.ds2d_mean).compute() + + +class GroupByDask(GroupBy): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(**kwargs) + + self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)) + self.ds1d["c"] = self.ds1d["c"].chunk({"dim_0": 50}) + self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)) + self.ds2d["c"] = self.ds2d["c"].chunk({"dim_0": 50, "z": 5}) + self.ds1d_mean = self.ds1d.groupby("b").mean().compute() + self.ds2d_mean = self.ds2d.groupby("b").mean().compute() + + +# TODO: These don't work now because we are calling `.compute` explicitly. +class GroupByPandasDataFrame(GroupBy): + """Run groupby tests using pandas DataFrame.""" + + def setup(self, *args, **kwargs): + # Skip testing in CI as it won't ever change in a commit: + _skip_slow() + + super().setup(**kwargs) + self.ds1d = self.ds1d.to_dataframe() + self.ds1d_mean = self.ds1d.groupby("b").mean() + + def time_binary_op_2d(self): + raise NotImplementedError + + def peakmem_binary_op_2d(self): + raise NotImplementedError + + +class GroupByDaskDataFrame(GroupBy): + """Run groupby tests using dask DataFrame.""" + + def setup(self, *args, **kwargs): + # Skip testing in CI as it won't ever change in a commit: + _skip_slow() + + requires_dask() + super().setup(**kwargs) + self.ds1d = self.ds1d.chunk({"dim_0": 50}).to_dataframe() + self.ds1d_mean = self.ds1d.groupby("b").mean().compute() + + def time_binary_op_2d(self): + raise NotImplementedError + + def peakmem_binary_op_2d(self): + raise NotImplementedError + + +class Resample: + def setup(self, *args, **kwargs): + self.ds1d = xr.Dataset( + { + "b": ("time", np.arange(365.0 * 24)), + }, + coords={"time": pd.date_range("2001-01-01", freq="h", periods=365 * 24)}, + ) + self.ds2d = self.ds1d.expand_dims(z=10) + self.ds1d_mean = self.ds1d.resample(time="48h").mean() + self.ds2d_mean = self.ds2d.resample(time="48h").mean() + + @parameterized(["ndim"], [(1, 2)]) + def time_init(self, ndim): + getattr(self, f"ds{ndim}d").resample(time="D") + + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_small_num_groups(self, method, ndim, use_flox): + ds = getattr(self, f"ds{ndim}d") + with xr.set_options(use_flox=use_flox): + getattr(ds.resample(time="3ME"), method)().compute() + + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_large_num_groups(self, method, ndim, use_flox): + ds = getattr(self, f"ds{ndim}d") + with xr.set_options(use_flox=use_flox): + getattr(ds.resample(time="48h"), method)().compute() + + +class ResampleDask(Resample): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(**kwargs) + self.ds1d = self.ds1d.chunk({"time": 50}) + self.ds2d = self.ds2d.chunk({"time": 50, "z": 4}) + + +class ResampleCFTime(Resample): + def setup(self, *args, **kwargs): + self.ds1d = xr.Dataset( + { + "b": ("time", np.arange(365.0 * 24)), + }, + coords={ + "time": xr.date_range( + "2001-01-01", freq="h", periods=365 * 24, calendar="noleap" + ) + }, + ) + self.ds2d = self.ds1d.expand_dims(z=10) + self.ds1d_mean = self.ds1d.resample(time="48h").mean() + self.ds2d_mean = self.ds2d.resample(time="48h").mean() + + +@parameterized(["use_cftime", "use_flox"], [[True, False], [True, False]]) +class GroupByLongTime: + def setup(self, use_cftime, use_flox): + arr = np.random.randn(10, 10, 365 * 30) + time = xr.date_range("2000", periods=30 * 365, use_cftime=use_cftime) + self.da = xr.DataArray(arr, dims=("y", "x", "time"), coords={"time": time}) + + def time_mean(self, use_cftime, use_flox): + with xr.set_options(use_flox=use_flox): + self.da.groupby("time.year").mean() diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/import.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/import.py new file mode 100644 index 0000000..f9d0bcc --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/import.py @@ -0,0 +1,18 @@ +class Import: + """Benchmark importing xarray""" + + def timeraw_import_xarray(self): + return "import xarray" + + def timeraw_import_xarray_plot(self): + return "import xarray.plot" + + def timeraw_import_xarray_backends(self): + return """ + from xarray.backends import list_engines + list_engines() + """ + + def timeraw_import_xarray_only(self): + # import numpy and pandas in the setup stage + return "import xarray", "import numpy, pandas" diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/indexing.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/indexing.py new file mode 100644 index 0000000..529d023 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/indexing.py @@ -0,0 +1,175 @@ +import os + +import numpy as np +import pandas as pd + +import xarray as xr + +from . import parameterized, randint, randn, requires_dask + +nx = 2000 +ny = 1000 +nt = 500 + +basic_indexes = { + "1scalar": {"x": 0}, + "1slice": {"x": slice(0, 3)}, + "1slice-1scalar": {"x": 0, "y": slice(None, None, 3)}, + "2slicess-1scalar": {"x": slice(3, -3, 3), "y": 1, "t": slice(None, -3, 3)}, +} + +basic_assignment_values = { + "1scalar": 0, + "1slice": xr.DataArray(randn((3, ny), frac_nan=0.1), dims=["x", "y"]), + "1slice-1scalar": xr.DataArray(randn(int(ny / 3) + 1, frac_nan=0.1), dims=["y"]), + "2slicess-1scalar": xr.DataArray( + randn(np.empty(nx)[slice(3, -3, 3)].size, frac_nan=0.1), dims=["x"] + ), +} + +outer_indexes = { + "1d": {"x": randint(0, nx, 400)}, + "2d": {"x": randint(0, nx, 500), "y": randint(0, ny, 400)}, + "2d-1scalar": {"x": randint(0, nx, 100), "y": 1, "t": randint(0, nt, 400)}, +} + +outer_assignment_values = { + "1d": xr.DataArray(randn((400, ny), frac_nan=0.1), dims=["x", "y"]), + "2d": xr.DataArray(randn((500, 400), frac_nan=0.1), dims=["x", "y"]), + "2d-1scalar": xr.DataArray(randn(100, frac_nan=0.1), dims=["x"]), +} + +vectorized_indexes = { + "1-1d": {"x": xr.DataArray(randint(0, nx, 400), dims="a")}, + "2-1d": { + "x": xr.DataArray(randint(0, nx, 400), dims="a"), + "y": xr.DataArray(randint(0, ny, 400), dims="a"), + }, + "3-2d": { + "x": xr.DataArray(randint(0, nx, 400).reshape(4, 100), dims=["a", "b"]), + "y": xr.DataArray(randint(0, ny, 400).reshape(4, 100), dims=["a", "b"]), + "t": xr.DataArray(randint(0, nt, 400).reshape(4, 100), dims=["a", "b"]), + }, +} + +vectorized_assignment_values = { + "1-1d": xr.DataArray(randn((400, ny)), dims=["a", "y"], coords={"a": randn(400)}), + "2-1d": xr.DataArray(randn(400), dims=["a"], coords={"a": randn(400)}), + "3-2d": xr.DataArray( + randn((4, 100)), dims=["a", "b"], coords={"a": randn(4), "b": randn(100)} + ), +} + + +class Base: + def setup(self, key): + self.ds = xr.Dataset( + { + "var1": (("x", "y"), randn((nx, ny), frac_nan=0.1)), + "var2": (("x", "t"), randn((nx, nt))), + "var3": (("t",), randn(nt)), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + # Benchmark how indexing is slowed down by adding many scalar variable + # to the dataset + # https://github.com/pydata/xarray/pull/9003 + self.ds_large = self.ds.merge({f"extra_var{i}": i for i in range(400)}) + + +class Indexing(Base): + @parameterized(["key"], [list(basic_indexes.keys())]) + def time_indexing_basic(self, key): + self.ds.isel(**basic_indexes[key]).load() + + @parameterized(["key"], [list(outer_indexes.keys())]) + def time_indexing_outer(self, key): + self.ds.isel(**outer_indexes[key]).load() + + @parameterized(["key"], [list(vectorized_indexes.keys())]) + def time_indexing_vectorized(self, key): + self.ds.isel(**vectorized_indexes[key]).load() + + @parameterized(["key"], [list(basic_indexes.keys())]) + def time_indexing_basic_ds_large(self, key): + # https://github.com/pydata/xarray/pull/9003 + self.ds_large.isel(**basic_indexes[key]).load() + + +class Assignment(Base): + @parameterized(["key"], [list(basic_indexes.keys())]) + def time_assignment_basic(self, key): + ind = basic_indexes[key] + val = basic_assignment_values[key] + self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val + + @parameterized(["key"], [list(outer_indexes.keys())]) + def time_assignment_outer(self, key): + ind = outer_indexes[key] + val = outer_assignment_values[key] + self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val + + @parameterized(["key"], [list(vectorized_indexes.keys())]) + def time_assignment_vectorized(self, key): + ind = vectorized_indexes[key] + val = vectorized_assignment_values[key] + self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val + + +class IndexingDask(Indexing): + def setup(self, key): + requires_dask() + super().setup(key) + self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50}) + + +class BooleanIndexing: + # https://github.com/pydata/xarray/issues/2227 + def setup(self): + self.ds = xr.Dataset( + {"a": ("time", np.arange(10_000_000))}, + coords={"time": np.arange(10_000_000)}, + ) + self.time_filter = self.ds.time > 50_000 + + def time_indexing(self): + self.ds.isel(time=self.time_filter) + + +class HugeAxisSmallSliceIndexing: + # https://github.com/pydata/xarray/pull/4560 + def setup(self): + self.filepath = "test_indexing_huge_axis_small_slice.nc" + if not os.path.isfile(self.filepath): + xr.Dataset( + {"a": ("x", np.arange(10_000_000))}, + coords={"x": np.arange(10_000_000)}, + ).to_netcdf(self.filepath, format="NETCDF4") + + self.ds = xr.open_dataset(self.filepath) + + def time_indexing(self): + self.ds.isel(x=slice(100)) + + def cleanup(self): + self.ds.close() + + +class AssignmentOptimized: + # https://github.com/pydata/xarray/pull/7382 + def setup(self): + self.ds = xr.Dataset(coords={"x": np.arange(500_000)}) + self.da = xr.DataArray(np.arange(500_000), dims="x") + + def time_assign_no_reindex(self): + # assign with non-indexed DataArray of same dimension size + self.ds.assign(foo=self.da) + + def time_assign_identical_indexes(self): + # fastpath index comparison (same index object) + self.ds.assign(foo=self.ds.x) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/interp.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/interp.py new file mode 100644 index 0000000..4b6691b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/interp.py @@ -0,0 +1,51 @@ +import numpy as np +import pandas as pd + +import xarray as xr + +from . import parameterized, randn, requires_dask + +nx = 1500 +ny = 1000 +nt = 500 + +randn_xy = randn((nx, ny), frac_nan=0.1) +randn_xt = randn((nx, nt)) +randn_t = randn((nt,)) + +new_x_short = np.linspace(0.3 * nx, 0.7 * nx, 100) +new_x_long = np.linspace(0.3 * nx, 0.7 * nx, 500) +new_y_long = np.linspace(0.1, 0.9, 500) + + +class Interpolation: + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + { + "var1": (("x", "y"), randn_xy), + "var2": (("x", "t"), randn_xt), + "var3": (("t",), randn_t), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + + @parameterized(["method", "is_short"], (["linear", "cubic"], [True, False])) + def time_interpolation(self, method, is_short): + new_x = new_x_short if is_short else new_x_long + self.ds.interp(x=new_x, method=method).load() + + @parameterized(["method"], (["linear", "nearest"])) + def time_interpolation_2d(self, method): + self.ds.interp(x=new_x_long, y=new_y_long, method=method).load() + + +class InterpolationDask(Interpolation): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(**kwargs) + self.ds = self.ds.chunk({"t": 50}) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/merge.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/merge.py new file mode 100644 index 0000000..6c8c1e9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/merge.py @@ -0,0 +1,77 @@ +import numpy as np + +import xarray as xr + + +class DatasetAddVariable: + param_names = ["existing_elements"] + params = [[0, 10, 100, 1000]] + + def setup(self, existing_elements): + self.datasets = {} + # Dictionary insertion is fast(er) than xarray.Dataset insertion + d = {} + for i in range(existing_elements): + d[f"var{i}"] = i + self.dataset = xr.merge([d]) + + d = {f"set_2_{i}": i for i in range(existing_elements)} + self.dataset2 = xr.merge([d]) + + def time_variable_insertion(self, existing_elements): + dataset = self.dataset + dataset["new_var"] = 0 + + def time_merge_two_datasets(self, existing_elements): + xr.merge([self.dataset, self.dataset2]) + + +class DatasetCreation: + # The idea here is to time how long it takes to go from numpy + # and python data types, to a full dataset + # See discussion + # https://github.com/pydata/xarray/issues/7224#issuecomment-1292216344 + param_names = ["strategy", "count"] + params = [ + ["dict_of_DataArrays", "dict_of_Variables", "dict_of_Tuples"], + [0, 1, 10, 100, 1000], + ] + + def setup(self, strategy, count): + data = np.array(["0", "b"], dtype=str) + self.dataset_coords = dict(time=np.array([0, 1])) + self.dataset_attrs = dict(description="Test data") + attrs = dict(units="Celsius") + if strategy == "dict_of_DataArrays": + + def create_data_vars(): + return { + f"long_variable_name_{i}": xr.DataArray( + data=data, dims=("time"), attrs=attrs + ) + for i in range(count) + } + + elif strategy == "dict_of_Variables": + + def create_data_vars(): + return { + f"long_variable_name_{i}": xr.Variable("time", data, attrs=attrs) + for i in range(count) + } + + elif strategy == "dict_of_Tuples": + + def create_data_vars(): + return { + f"long_variable_name_{i}": ("time", data, attrs) + for i in range(count) + } + + self.create_data_vars = create_data_vars + + def time_dataset_creation(self, strategy, count): + data_vars = self.create_data_vars() + xr.Dataset( + data_vars=data_vars, coords=self.dataset_coords, attrs=self.dataset_attrs + ) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/pandas.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/pandas.py new file mode 100644 index 0000000..ebe6108 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/pandas.py @@ -0,0 +1,64 @@ +import numpy as np +import pandas as pd + +import xarray as xr + +from . import parameterized, requires_dask + + +class MultiIndexSeries: + def setup(self, dtype, subset): + data = np.random.rand(100000).astype(dtype) + index = pd.MultiIndex.from_product( + [ + list("abcdefhijk"), + list("abcdefhijk"), + pd.date_range(start="2000-01-01", periods=1000, freq="D"), + ] + ) + series = pd.Series(data, index) + if subset: + series = series[::3] + self.series = series + + @parameterized(["dtype", "subset"], ([int, float], [True, False])) + def time_from_series(self, dtype, subset): + xr.DataArray.from_series(self.series) + + +class ToDataFrame: + def setup(self, *args, **kwargs): + xp = kwargs.get("xp", np) + nvars = kwargs.get("nvars", 1) + random_kws = kwargs.get("random_kws", {}) + method = kwargs.get("method", "to_dataframe") + + dim1 = 10_000 + dim2 = 10_000 + + var = xr.Variable( + dims=("dim1", "dim2"), data=xp.random.random((dim1, dim2), **random_kws) + ) + data_vars = {f"long_name_{v}": (("dim1", "dim2"), var) for v in range(nvars)} + + ds = xr.Dataset( + data_vars, coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)} + ) + self.to_frame = getattr(ds, method) + + def time_to_dataframe(self): + self.to_frame() + + def peakmem_to_dataframe(self): + self.to_frame() + + +class ToDataFrameDask(ToDataFrame): + def setup(self, *args, **kwargs): + requires_dask() + + import dask.array as da + + super().setup( + xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe", nvars=500 + ) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/polyfit.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/polyfit.py new file mode 100644 index 0000000..429ffa1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/polyfit.py @@ -0,0 +1,38 @@ +import numpy as np + +import xarray as xr + +from . import parameterized, randn, requires_dask + +NDEGS = (2, 5, 20) +NX = (10**2, 10**6) + + +class Polyval: + def setup(self, *args, **kwargs): + self.xs = {nx: xr.DataArray(randn((nx,)), dims="x", name="x") for nx in NX} + self.coeffs = { + ndeg: xr.DataArray( + randn((ndeg,)), dims="degree", coords={"degree": np.arange(ndeg)} + ) + for ndeg in NDEGS + } + + @parameterized(["nx", "ndeg"], [NX, NDEGS]) + def time_polyval(self, nx, ndeg): + x = self.xs[nx] + c = self.coeffs[ndeg] + xr.polyval(x, c).compute() + + @parameterized(["nx", "ndeg"], [NX, NDEGS]) + def peakmem_polyval(self, nx, ndeg): + x = self.xs[nx] + c = self.coeffs[ndeg] + xr.polyval(x, c).compute() + + +class PolyvalDask(Polyval): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(*args, **kwargs) + self.xs = {k: v.chunk({"x": 10000}) for k, v in self.xs.items()} diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/reindexing.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/reindexing.py new file mode 100644 index 0000000..9d0767f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/reindexing.py @@ -0,0 +1,52 @@ +import numpy as np + +import xarray as xr + +from . import requires_dask + +ntime = 500 +nx = 50 +ny = 50 + + +class Reindex: + def setup(self): + data = np.random.RandomState(0).randn(ntime, nx, ny) + self.ds = xr.Dataset( + {"temperature": (("time", "x", "y"), data)}, + coords={"time": np.arange(ntime), "x": np.arange(nx), "y": np.arange(ny)}, + ) + + def time_1d_coarse(self): + self.ds.reindex(time=np.arange(0, ntime, 5)).load() + + def time_1d_fine_all_found(self): + self.ds.reindex(time=np.arange(0, ntime, 0.5), method="nearest").load() + + def time_1d_fine_some_missing(self): + self.ds.reindex( + time=np.arange(0, ntime, 0.5), method="nearest", tolerance=0.1 + ).load() + + def time_2d_coarse(self): + self.ds.reindex(x=np.arange(0, nx, 2), y=np.arange(0, ny, 2)).load() + + def time_2d_fine_all_found(self): + self.ds.reindex( + x=np.arange(0, nx, 0.5), y=np.arange(0, ny, 0.5), method="nearest" + ).load() + + def time_2d_fine_some_missing(self): + self.ds.reindex( + x=np.arange(0, nx, 0.5), + y=np.arange(0, ny, 0.5), + method="nearest", + tolerance=0.1, + ).load() + + +class ReindexDask(Reindex): + def setup(self): + requires_dask() + super().setup() + self.ds = self.ds.chunk({"time": 100}) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/renaming.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/renaming.py new file mode 100644 index 0000000..3ade5d8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/renaming.py @@ -0,0 +1,27 @@ +import numpy as np + +import xarray as xr + + +class SwapDims: + param_names = ["size"] + params = [[int(1e3), int(1e5), int(1e7)]] + + def setup(self, size: int) -> None: + self.ds = xr.Dataset( + {"a": (("x", "t"), np.ones((size, 2)))}, + coords={ + "x": np.arange(size), + "y": np.arange(size), + "z": np.arange(size), + "x2": ("x", np.arange(size)), + "y2": ("y", np.arange(size)), + "z2": ("z", np.arange(size)), + }, + ) + + def time_swap_dims(self, size: int) -> None: + self.ds.swap_dims({"x": "xn", "y": "yn", "z": "zn"}) + + def time_swap_dims_newindex(self, size: int) -> None: + self.ds.swap_dims({"x": "x2", "y": "y2", "z": "z2"}) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/repr.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/repr.py new file mode 100644 index 0000000..4bf2ace --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/repr.py @@ -0,0 +1,40 @@ +import numpy as np +import pandas as pd + +import xarray as xr + + +class Repr: + def setup(self): + a = np.arange(0, 100) + data_vars = dict() + for i in a: + data_vars[f"long_variable_name_{i}"] = xr.DataArray( + name=f"long_variable_name_{i}", + data=np.arange(0, 20), + dims=[f"long_coord_name_{i}_x"], + coords={f"long_coord_name_{i}_x": np.arange(0, 20) * 2}, + ) + self.ds = xr.Dataset(data_vars) + self.ds.attrs = {f"attr_{k}": 2 for k in a} + + def time_repr(self): + repr(self.ds) + + def time_repr_html(self): + self.ds._repr_html_() + + +class ReprMultiIndex: + def setup(self): + index = pd.MultiIndex.from_product( + [range(1000), range(1000)], names=("level_0", "level_1") + ) + series = pd.Series(range(1000 * 1000), index=index) + self.da = xr.DataArray(series) + + def time_repr(self): + repr(self.da) + + def time_repr_html(self): + self.da._repr_html_() diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/rolling.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/rolling.py new file mode 100644 index 0000000..579f4f0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/rolling.py @@ -0,0 +1,139 @@ +import numpy as np +import pandas as pd + +import xarray as xr + +from . import parameterized, randn, requires_dask + +nx = 3000 +long_nx = 30000 +ny = 200 +nt = 1000 +window = 20 + +randn_xy = randn((nx, ny), frac_nan=0.1) +randn_xt = randn((nx, nt)) +randn_t = randn((nt,)) +randn_long = randn((long_nx,), frac_nan=0.1) + + +class Rolling: + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + { + "var1": (("x", "y"), randn_xy), + "var2": (("x", "t"), randn_xt), + "var3": (("t",), randn_t), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + self.da_long = xr.DataArray( + randn_long, dims="x", coords={"x": np.arange(long_nx) * 0.1} + ) + + @parameterized( + ["func", "center", "use_bottleneck"], + (["mean", "count"], [True, False], [True, False]), + ) + def time_rolling(self, func, center, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + getattr(self.ds.rolling(x=window, center=center), func)().load() + + @parameterized( + ["func", "pandas", "use_bottleneck"], + (["mean", "count"], [True, False], [True, False]), + ) + def time_rolling_long(self, func, pandas, use_bottleneck): + if pandas: + se = self.da_long.to_series() + getattr(se.rolling(window=window, min_periods=window), func)() + else: + with xr.set_options(use_bottleneck=use_bottleneck): + getattr( + self.da_long.rolling(x=window, min_periods=window), func + )().load() + + @parameterized( + ["window_", "min_periods", "use_bottleneck"], ([20, 40], [5, 5], [True, False]) + ) + def time_rolling_np(self, window_, min_periods, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + self.ds.rolling(x=window_, center=False, min_periods=min_periods).reduce( + getattr(np, "nansum") + ).load() + + @parameterized( + ["center", "stride", "use_bottleneck"], ([True, False], [1, 1], [True, False]) + ) + def time_rolling_construct(self, center, stride, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + self.ds.rolling(x=window, center=center).construct( + "window_dim", stride=stride + ).sum(dim="window_dim").load() + + +class RollingDask(Rolling): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(**kwargs) + self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50}) + self.da_long = self.da_long.chunk({"x": 10000}) + + +class RollingMemory: + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + { + "var1": (("x", "y"), randn_xy), + "var2": (("x", "t"), randn_xt), + "var3": (("t",), randn_t), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + + +class DataArrayRollingMemory(RollingMemory): + @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) + def peakmem_ndrolling_reduce(self, func, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + roll = self.ds.var1.rolling(x=10, y=4) + getattr(roll, func)() + + @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) + def peakmem_1drolling_reduce(self, func, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + roll = self.ds.var3.rolling(t=100) + getattr(roll, func)() + + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.var2.rolling(t=100).construct("w", stride=stride) + self.ds.var3.rolling(t=100).construct("w", stride=stride) + + +class DatasetRollingMemory(RollingMemory): + @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) + def peakmem_ndrolling_reduce(self, func, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + roll = self.ds.rolling(x=10, y=4) + getattr(roll, func)() + + @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) + def peakmem_1drolling_reduce(self, func, use_bottleneck): + with xr.set_options(use_bottleneck=use_bottleneck): + roll = self.ds.rolling(t=100) + getattr(roll, func)() + + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.rolling(t=100).construct("w", stride=stride) diff --git a/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/unstacking.py b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/unstacking.py new file mode 100644 index 0000000..dc8bc33 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/asv_bench/benchmarks/unstacking.py @@ -0,0 +1,64 @@ +import numpy as np +import pandas as pd + +import xarray as xr + +from . import requires_dask, requires_sparse + + +class Unstacking: + def setup(self): + data = np.random.RandomState(0).randn(250, 500) + self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...]) + self.da_missing = self.da_full[:-1] + self.df_missing = self.da_missing.to_pandas() + + def time_unstack_fast(self): + self.da_full.unstack("flat_dim") + + def time_unstack_slow(self): + self.da_missing.unstack("flat_dim") + + def time_unstack_pandas_slow(self): + self.df_missing.unstack() + + +class UnstackingDask(Unstacking): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(**kwargs) + self.da_full = self.da_full.chunk({"flat_dim": 25}) + + +class UnstackingSparse(Unstacking): + def setup(self, *args, **kwargs): + requires_sparse() + + import sparse + + data = sparse.random((500, 1000), random_state=0, fill_value=0) + self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...]) + self.da_missing = self.da_full[:-1] + + mindex = pd.MultiIndex.from_arrays([np.arange(100), np.arange(100)]) + self.da_eye_2d = xr.DataArray(np.ones((100,)), dims="z", coords={"z": mindex}) + self.da_eye_3d = xr.DataArray( + np.ones((100, 50)), + dims=("z", "foo"), + coords={"z": mindex, "foo": np.arange(50)}, + ) + + def time_unstack_to_sparse_2d(self): + self.da_eye_2d.unstack(sparse=True) + + def time_unstack_to_sparse_3d(self): + self.da_eye_3d.unstack(sparse=True) + + def peakmem_unstack_to_sparse_2d(self): + self.da_eye_2d.unstack(sparse=True) + + def peakmem_unstack_to_sparse_3d(self): + self.da_eye_3d.unstack(sparse=True) + + def time_unstack_pandas_slow(self): + pass diff --git a/test/fixtures/whole_applications/xarray/ci/install-upstream-wheels.sh b/test/fixtures/whole_applications/xarray/ci/install-upstream-wheels.sh new file mode 100755 index 0000000..d728768 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/install-upstream-wheels.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +if which micromamba >/dev/null; then + conda=micromamba +elif which mamba >/dev/null; then + conda=mamba +else + conda=conda +fi + +# temporarily (?) remove numbagg and numba +$conda remove -y numba numbagg sparse +# temporarily remove numexpr +$conda remove -y numexpr +# temporarily remove backends +$conda remove -y cf_units hdf5 h5py netcdf4 pydap +# forcibly remove packages to avoid artifacts +$conda remove -y --force \ + numpy \ + scipy \ + pandas \ + distributed \ + fsspec \ + zarr \ + cftime \ + packaging \ + bottleneck \ + flox + # pint + +# to limit the runtime of Upstream CI +python -m pip install \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \ + --no-deps \ + --pre \ + --upgrade \ + numpy \ + scipy \ + matplotlib \ + pandas \ + h5py +# for some reason pandas depends on pyarrow already. +# Remove once a `pyarrow` version compiled with `numpy>=2.0` is on `conda-forge` +python -m pip install \ + -i https://pypi.fury.io/arrow-nightlies/ \ + --prefer-binary \ + --no-deps \ + --pre \ + --upgrade \ + pyarrow +# manually install `pint` to pull in new dependencies +python -m pip install --upgrade pint +python -m pip install \ + --no-deps \ + --upgrade \ + git+https://github.com/dask/dask \ + git+https://github.com/dask/dask-expr \ + git+https://github.com/dask/distributed \ + git+https://github.com/zarr-developers/zarr.git@main \ + git+https://github.com/Unidata/cftime \ + git+https://github.com/pypa/packaging \ + git+https://github.com/hgrecco/pint \ + git+https://github.com/pydata/bottleneck \ + git+https://github.com/intake/filesystem_spec \ + git+https://github.com/SciTools/nc-time-axis \ + git+https://github.com/xarray-contrib/flox \ + git+https://github.com/h5netcdf/h5netcdf \ + git+https://github.com/dgasmith/opt_einsum + # git+https://github.com/pydata/sparse diff --git a/test/fixtures/whole_applications/xarray/ci/min_deps_check.py b/test/fixtures/whole_applications/xarray/ci/min_deps_check.py new file mode 100755 index 0000000..6981a69 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/min_deps_check.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +"""Fetch from conda database all available versions of the xarray dependencies and their +publication date. Compare it against requirements/min-all-deps.yml to verify the +policy on obsolete dependencies is being followed. Print a pretty report :) +""" +from __future__ import annotations + +import itertools +import sys +from collections.abc import Iterator +from datetime import datetime + +import conda.api # type: ignore[import] +import yaml +from dateutil.relativedelta import relativedelta + +CHANNELS = ["conda-forge", "defaults"] +IGNORE_DEPS = { + "black", + "coveralls", + "flake8", + "hypothesis", + "isort", + "mypy", + "pip", + "pytest", + "pytest-cov", + "pytest-env", + "pytest-timeout", + "pytest-xdist", + "setuptools", +} + +POLICY_MONTHS = {"python": 30, "numpy": 18} +POLICY_MONTHS_DEFAULT = 12 +POLICY_OVERRIDE: dict[str, tuple[int, int]] = {} +errors = [] + + +def error(msg: str) -> None: + global errors + errors.append(msg) + print("ERROR:", msg) + + +def warning(msg: str) -> None: + print("WARNING:", msg) + + +def parse_requirements(fname) -> Iterator[tuple[str, int, int, int | None]]: + """Load requirements/min-all-deps.yml + + Yield (package name, major version, minor version, [patch version]) + """ + global errors + + with open(fname) as fh: + contents = yaml.safe_load(fh) + for row in contents["dependencies"]: + if isinstance(row, dict) and list(row) == ["pip"]: + continue + pkg, eq, version = row.partition("=") + if pkg.rstrip("<>") in IGNORE_DEPS: + continue + if pkg.endswith("<") or pkg.endswith(">") or eq != "=": + error("package should be pinned with exact version: " + row) + continue + + try: + version_tup = tuple(int(x) for x in version.split(".")) + except ValueError: + raise ValueError("non-numerical version: " + row) + + if len(version_tup) == 2: + yield (pkg, *version_tup, None) # type: ignore[misc] + elif len(version_tup) == 3: + yield (pkg, *version_tup) # type: ignore[misc] + else: + raise ValueError("expected major.minor or major.minor.patch: " + row) + + +def query_conda(pkg: str) -> dict[tuple[int, int], datetime]: + """Query the conda repository for a specific package + + Return map of {(major version, minor version): publication date} + """ + + def metadata(entry): + version = entry.version + + time = datetime.fromtimestamp(entry.timestamp) + major, minor = map(int, version.split(".")[:2]) + + return (major, minor), time + + raw_data = conda.api.SubdirData.query_all(pkg, channels=CHANNELS) + data = sorted(metadata(entry) for entry in raw_data if entry.timestamp != 0) + + release_dates = { + version: [time for _, time in group if time is not None] + for version, group in itertools.groupby(data, key=lambda x: x[0]) + } + out = {version: min(dates) for version, dates in release_dates.items() if dates} + + # Hardcoded fix to work around incorrect dates in conda + if pkg == "python": + out.update( + { + (2, 7): datetime(2010, 6, 3), + (3, 5): datetime(2015, 9, 13), + (3, 6): datetime(2016, 12, 23), + (3, 7): datetime(2018, 6, 27), + (3, 8): datetime(2019, 10, 14), + (3, 9): datetime(2020, 10, 5), + (3, 10): datetime(2021, 10, 4), + (3, 11): datetime(2022, 10, 24), + } + ) + + return out + + +def process_pkg( + pkg: str, req_major: int, req_minor: int, req_patch: int | None +) -> tuple[str, str, str, str, str, str]: + """Compare package version from requirements file to available versions in conda. + Return row to build pandas dataframe: + + - package name + - major.minor.[patch] version in requirements file + - publication date of version in requirements file (YYYY-MM-DD) + - major.minor version suggested by policy + - publication date of version suggested by policy (YYYY-MM-DD) + - status ("<", "=", "> (!)") + """ + print(f"Analyzing {pkg}...") + versions = query_conda(pkg) + + try: + req_published = versions[req_major, req_minor] + except KeyError: + error("not found in conda: " + pkg) + return pkg, fmt_version(req_major, req_minor, req_patch), "-", "-", "-", "(!)" + + policy_months = POLICY_MONTHS.get(pkg, POLICY_MONTHS_DEFAULT) + policy_published = datetime.now() - relativedelta(months=policy_months) + + filtered_versions = [ + version + for version, published in versions.items() + if published < policy_published + ] + policy_major, policy_minor = max(filtered_versions, default=(req_major, req_minor)) + + try: + policy_major, policy_minor = POLICY_OVERRIDE[pkg] + except KeyError: + pass + policy_published_actual = versions[policy_major, policy_minor] + + if (req_major, req_minor) < (policy_major, policy_minor): + status = "<" + elif (req_major, req_minor) > (policy_major, policy_minor): + status = "> (!)" + delta = relativedelta(datetime.now(), req_published).normalized() + n_months = delta.years * 12 + delta.months + warning( + f"Package is too new: {pkg}={req_major}.{req_minor} was " + f"published on {req_published:%Y-%m-%d} " + f"which was {n_months} months ago (policy is {policy_months} months)" + ) + else: + status = "=" + + if req_patch is not None: + warning("patch version should not appear in requirements file: " + pkg) + status += " (w)" + + return ( + pkg, + fmt_version(req_major, req_minor, req_patch), + req_published.strftime("%Y-%m-%d"), + fmt_version(policy_major, policy_minor), + policy_published_actual.strftime("%Y-%m-%d"), + status, + ) + + +def fmt_version(major: int, minor: int, patch: int = None) -> str: + if patch is None: + return f"{major}.{minor}" + else: + return f"{major}.{minor}.{patch}" + + +def main() -> None: + fname = sys.argv[1] + rows = [ + process_pkg(pkg, major, minor, patch) + for pkg, major, minor, patch in parse_requirements(fname) + ] + + print("\nPackage Required Policy Status") + print("----------------- -------------------- -------------------- ------") + fmt = "{:17} {:7} ({:10}) {:7} ({:10}) {}" + for row in rows: + print(fmt.format(*row)) + + if errors: + print("\nErrors:") + print("-------") + for i, e in enumerate(errors): + print(f"{i+1}. {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/all-but-dask.yml b/test/fixtures/whole_applications/xarray/ci/requirements/all-but-dask.yml new file mode 100644 index 0000000..2f47643 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/all-but-dask.yml @@ -0,0 +1,42 @@ +name: xarray-tests +channels: + - conda-forge + - nodefaults +dependencies: + - black + - aiobotocore + - array-api-strict + - boto3 + - bottleneck + - cartopy + - cftime + - coveralls + - flox + - h5netcdf + - h5py + - hdf5 + - hypothesis + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + - numba + - numbagg + - numpy + - packaging + - pandas + - pint>=0.22 + - pip + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - rasterio + - scipy + - seaborn + - sparse + - toolz + - typing_extensions + - zarr diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/bare-minimum.yml b/test/fixtures/whole_applications/xarray/ci/requirements/bare-minimum.yml new file mode 100644 index 0000000..105e90c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/bare-minimum.yml @@ -0,0 +1,16 @@ +name: xarray-tests +channels: + - conda-forge + - nodefaults +dependencies: + - python=3.9 + - coveralls + - pip + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - numpy=1.23 + - packaging=23.1 + - pandas=2.0 diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/doc.yml b/test/fixtures/whole_applications/xarray/ci/requirements/doc.yml new file mode 100644 index 0000000..066d085 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/doc.yml @@ -0,0 +1,46 @@ +name: xarray-docs +channels: + # Don't change to pkgs/main, as it causes random timeouts in readthedocs + - conda-forge + - nodefaults +dependencies: + - python=3.10 + - bottleneck + - cartopy + - cfgrib + - dask-core>=2022.1 + - dask-expr + - hypothesis>=6.75.8 + - h5netcdf>=0.13 + - ipykernel + - ipywidgets # silence nbsphinx warning + - ipython + - iris>=2.3 + - jupyter_client + - matplotlib-base + - nbsphinx + - netcdf4>=1.5 + - numba + - numpy>=1.21 + - packaging>=21.3 + - pandas>=1.4,!=2.1.0 + - pooch + - pip + - pre-commit + - pyproj + - scipy!=1.10.0 + - seaborn + - setuptools + - sparse + - sphinx-autosummary-accessors + - sphinx-book-theme<=1.0.1 + - sphinx-copybutton + - sphinx-design + - sphinx-inline-tabs + - sphinx>=5.0 + - sphinxext-opengraph + - sphinxext-rediraffe + - zarr>=2.10 + - pip: + # relative to this file. Needs to be editable to be accepted. + - -e ../.. diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/environment-3.13.yml b/test/fixtures/whole_applications/xarray/ci/requirements/environment-3.13.yml new file mode 100644 index 0000000..dbb446f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/environment-3.13.yml @@ -0,0 +1,49 @@ +name: xarray-tests +channels: + - conda-forge + - nodefaults +dependencies: + - aiobotocore + - array-api-strict + - boto3 + - bottleneck + - cartopy + - cftime + - dask-core + - dask-expr + - distributed + - flox + - fsspec + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + # - numba + # - numbagg + - numexpr + - numpy + - opt_einsum + - packaging + - pandas + # - pint>=0.22 + - pip + - pooch + - pre-commit + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - rasterio + - scipy + - seaborn + # - sparse + - toolz + - typing_extensions + - zarr diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/environment-windows-3.13.yml b/test/fixtures/whole_applications/xarray/ci/requirements/environment-windows-3.13.yml new file mode 100644 index 0000000..448e3f7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/environment-windows-3.13.yml @@ -0,0 +1,44 @@ +name: xarray-tests +channels: + - conda-forge +dependencies: + - array-api-strict + - boto3 + - bottleneck + - cartopy + - cftime + - dask-core + - dask-expr + - distributed + - flox + - fsspec + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + # - numba + # - numbagg + - numpy + - packaging + - pandas + # - pint>=0.22 + - pip + - pre-commit + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - rasterio + - scipy + - seaborn + # - sparse + - toolz + - typing_extensions + - zarr diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/environment-windows.yml b/test/fixtures/whole_applications/xarray/ci/requirements/environment-windows.yml new file mode 100644 index 0000000..3b2e6dc --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/environment-windows.yml @@ -0,0 +1,44 @@ +name: xarray-tests +channels: + - conda-forge +dependencies: + - array-api-strict + - boto3 + - bottleneck + - cartopy + - cftime + - dask-core + - dask-expr + - distributed + - flox + - fsspec + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + - numba + - numbagg + - numpy + - packaging + - pandas + # - pint>=0.22 + - pip + - pre-commit + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - rasterio + - scipy + - seaborn + - sparse + - toolz + - typing_extensions + - zarr diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/environment.yml b/test/fixtures/whole_applications/xarray/ci/requirements/environment.yml new file mode 100644 index 0000000..01521e9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/environment.yml @@ -0,0 +1,50 @@ +name: xarray-tests +channels: + - conda-forge + - nodefaults +dependencies: + - aiobotocore + - array-api-strict + - boto3 + - bottleneck + - cartopy + - cftime + - dask-core + - dask-expr # dask raises a deprecation warning without this, breaking doctests + - distributed + - flox + - fsspec + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + - numba + - numbagg + - numexpr + - numpy + - opt_einsum + - packaging + - pandas + # - pint>=0.22 + - pip + - pooch + - pre-commit + - pyarrow # pandas raises a deprecation warning without this, breaking doctests + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - rasterio + - scipy + - seaborn + - sparse + - toolz + - typing_extensions + - zarr diff --git a/test/fixtures/whole_applications/xarray/ci/requirements/min-all-deps.yml b/test/fixtures/whole_applications/xarray/ci/requirements/min-all-deps.yml new file mode 100644 index 0000000..64f4327 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/ci/requirements/min-all-deps.yml @@ -0,0 +1,57 @@ +name: xarray-tests +channels: + - conda-forge + - nodefaults +dependencies: + # MINIMUM VERSIONS POLICY: see doc/user-guide/installing.rst + # Run ci/min_deps_check.py to verify that this file respects the policy. + # When upgrading python, numpy, or pandas, must also change + # doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py. + - python=3.9 + - array-api-strict=1.0 # dependency for testing the array api compat + - boto3=1.26 + - bottleneck=1.3 + - cartopy=0.21 + - cftime=1.6 + - coveralls + - dask-core=2023.4 + - distributed=2023.4 + # Flox > 0.8 has a bug with numbagg versions + # It will require numbagg > 0.6 + # so we should just skip that series eventually + # or keep flox pinned for longer than necessary + - flox=0.7 + - h5netcdf=1.1 + # h5py and hdf5 tend to cause conflicts + # for e.g. hdf5 1.12 conflicts with h5py=3.1 + # prioritize bumping other packages instead + - h5py=3.8 + - hdf5=1.12 + - hypothesis + - iris=3.4 + - lxml=4.9 # Optional dep of pydap + - matplotlib-base=3.7 + - nc-time-axis=1.4 + # netcdf follows a 1.major.minor[.patch] convention + # (see https://github.com/Unidata/netcdf4-python/issues/1090) + - netcdf4=1.6.0 + - numba=0.56 + - numbagg=0.2.1 + - numpy=1.23 + - packaging=23.1 + - pandas=2.0 + - pint=0.22 + - pip + - pydap=3.4 + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - pytest-timeout + - rasterio=1.3 + - scipy=1.10 + - seaborn=0.12 + - sparse=0.14 + - toolz=0.12 + - typing_extensions=4.5 + - zarr=2.14 diff --git a/test/fixtures/whole_applications/xarray/conftest.py b/test/fixtures/whole_applications/xarray/conftest.py new file mode 100644 index 0000000..24b7530 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/conftest.py @@ -0,0 +1,49 @@ +"""Configuration for pytest.""" + +import pytest + + +def pytest_addoption(parser): + """Add command-line flags for pytest.""" + parser.addoption("--run-flaky", action="store_true", help="runs flaky tests") + parser.addoption( + "--run-network-tests", + action="store_true", + help="runs tests requiring a network connection", + ) + + +def pytest_runtest_setup(item): + # based on https://stackoverflow.com/questions/47559524 + if "flaky" in item.keywords and not item.config.getoption("--run-flaky"): + pytest.skip("set --run-flaky option to run flaky tests") + if "network" in item.keywords and not item.config.getoption("--run-network-tests"): + pytest.skip( + "set --run-network-tests to run test requiring an internet connection" + ) + + +@pytest.fixture(autouse=True) +def add_standard_imports(doctest_namespace, tmpdir): + import numpy as np + import pandas as pd + + import xarray as xr + + doctest_namespace["np"] = np + doctest_namespace["pd"] = pd + doctest_namespace["xr"] = xr + + # always seed numpy.random to make the examples deterministic + np.random.seed(0) + + # always switch to the temporary directory, so files get written there + tmpdir.chdir() + + # Avoid the dask deprecation warning, can remove if CI passes without this. + try: + import dask + except ImportError: + pass + else: + dask.config.set({"dataframe.query-planning": True}) diff --git a/test/fixtures/whole_applications/xarray/design_notes/flexible_indexes_notes.md b/test/fixtures/whole_applications/xarray/design_notes/flexible_indexes_notes.md new file mode 100644 index 0000000..b36ce3e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/design_notes/flexible_indexes_notes.md @@ -0,0 +1,398 @@ +# Proposal: Xarray flexible indexes refactoring + +Current status: https://github.com/pydata/xarray/projects/1 + +## 1. Data Model + +Indexes are used in Xarray to extract data from Xarray objects using coordinate labels instead of using integer array indices. Although the indexes used in an Xarray object can be accessed (or built on-the-fly) via public methods like `to_index()` or properties like `indexes`, those are mainly used internally. + +The goal of this project is to make those indexes 1st-class citizens of Xarray's data model. As such, indexes should clearly be separated from Xarray coordinates with the following relationships: + +- Index -> Coordinate: one-to-many +- Coordinate -> Index: one-to-zero-or-one + +An index may be built from one or more coordinates. However, each coordinate must relate to one index at most. Additionally, a coordinate may not be tied to any index. + +The order in which multiple coordinates relate to an index may matter. For example, Scikit-Learn's [`BallTree`](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html#sklearn.neighbors.BallTree) index with the Haversine metric requires providing latitude and longitude values in that specific order. As another example, the order in which levels are defined in a `pandas.MultiIndex` may affect its lexsort depth (see [MultiIndex sorting](https://pandas.pydata.org/pandas-docs/stable/user_guide/advanced.html#sorting-a-multiindex)). + +Xarray's current data model has the same index-coordinate relationships than stated above, although this assumes that multi-index "virtual" coordinates are counted as coordinates (we can consider them as such, with some constraints). More importantly, This refactoring would turn the current one-to-one relationship between a dimension and an index into a many-to-many relationship, which would overcome some current limitations. + +For example, we might want to select data along a dimension which has several coordinates: + +```python +>>> da + +array([...]) +Coordinates: + * drainage_area (river_profile) float64 ... + * chi (river_profile) float64 ... +``` + +In this example, `chi` is a transformation of the `drainage_area` variable that is often used in geomorphology. We'd like to select data along the river profile using either `da.sel(drainage_area=...)` or `da.sel(chi=...)` but that's not currently possible. We could rename the `river_profile` dimension to one of the coordinates, then use `sel` with that coordinate, then call `swap_dims` if we want to use `sel` with the other coordinate, but that's not ideal. We could also build a `pandas.MultiIndex` from `drainage_area` and `chi`, but that's not optimal (there's no hierarchical relationship between these two coordinates). + +Let's take another example: + +```python +>>> da + +array([[...], [...]]) +Coordinates: + * lon (x, y) float64 ... + * lat (x, y) float64 ... + * x (x) float64 ... + * y (y) float64 ... +``` + +This refactoring would allow creating a geographic index for `lat` and `lon` and two simple indexes for `x` and `y` such that we could select data with either `da.sel(lon=..., lat=...)` or `da.sel(x=..., y=...)`. + +Refactoring the dimension -> index one-to-one relationship into many-to-many would also introduce some issues that we'll need to address, e.g., ambiguous cases like `da.sel(chi=..., drainage_area=...)` where multiple indexes may potentially return inconsistent positional indexers along a dimension. + +## 2. Proposed API changes + +### 2.1 Index wrapper classes + +Every index that is used to select data from Xarray objects should inherit from a base class, e.g., `XarrayIndex`, that provides some common API. `XarrayIndex` subclasses would generally consist of thin wrappers around existing index classes such as `pandas.Index`, `pandas.MultiIndex`, `scipy.spatial.KDTree`, etc. + +There is a variety of features that an xarray index wrapper may or may not support: + +- 1-dimensional vs. 2-dimensional vs. n-dimensional coordinate (e.g., `pandas.Index` only supports 1-dimensional coordinates while a geographic index could be built from n-dimensional coordinates) +- built from a single vs multiple coordinate(s) (e.g., `pandas.Index` is built from one coordinate, `pandas.MultiIndex` may be built from an arbitrary number of coordinates and a geographic index would typically require two latitude/longitude coordinates) +- in-memory vs. out-of-core (dask) index data/coordinates (vs. other array backends) +- range-based vs. point-wise selection +- exact vs. inexact lookups + +Whether or not a `XarrayIndex` subclass supports each of the features listed above should be either declared explicitly via a common API or left to the implementation. An `XarrayIndex` subclass may encapsulate more than one underlying object used to perform the actual indexing. Such "meta" index would typically support a range of features among those mentioned above and would automatically select the optimal index object for a given indexing operation. + +An `XarrayIndex` subclass must/should/may implement the following properties/methods: + +- a `from_coords` class method that creates a new index wrapper instance from one or more Dataset/DataArray coordinates (+ some options) +- a `query` method that takes label-based indexers as argument (+ some options) and that returns the corresponding position-based indexers +- an `indexes` property to access the underlying index object(s) wrapped by the `XarrayIndex` subclass +- a `data` property to access index's data and map it to coordinate data (see [Section 4](#4-indexvariable)) +- a `__getitem__()` implementation to propagate the index through DataArray/Dataset indexing operations +- `equals()`, `union()` and `intersection()` methods for data alignment (see [Section 2.6](#26-using-indexes-for-data-alignment)) +- Xarray coordinate getters (see [Section 2.2.4](#224-implicit-coodinates)) +- a method that may return a new index and that will be called when one of the corresponding coordinates is dropped from the Dataset/DataArray (multi-coordinate indexes) +- `encode()`/`decode()` methods that would allow storage-agnostic serialization and fast-path reconstruction of the underlying index object(s) (see [Section 2.8](#28-index-encoding)) +- one or more "non-standard" methods or properties that could be leveraged in Xarray 3rd-party extensions like Dataset/DataArray accessors (see [Section 2.7](#27-using-indexes-for-other-purposes)) + +The `XarrayIndex` API has still to be defined in detail. + +Xarray should provide a minimal set of built-in index wrappers (this could be reduced to the indexes currently supported in Xarray, i.e., `pandas.Index` and `pandas.MultiIndex`). Other index wrappers may be implemented in 3rd-party libraries (recommended). The `XarrayIndex` base class should be part of Xarray's public API. + +#### 2.1.1 Index discoverability + +For better discoverability of Xarray-compatible indexes, Xarray could provide some mechanism to register new index wrappers, e.g., something like [xoak's `IndexRegistry`](https://xoak.readthedocs.io/en/latest/_api_generated/xoak.IndexRegistry.html#xoak.IndexRegistry) or [numcodec's registry](https://numcodecs.readthedocs.io/en/stable/registry.html). + +Additionally (or alternatively), new index wrappers may be registered via entry points as is already the case for storage backends and maybe other backends (plotting) in the future. + +Registering new indexes either via a custom registry or via entry points should be optional. Xarray should also allow providing `XarrayIndex` subclasses in its API (Dataset/DataArray constructors, `set_index()`, etc.). + +### 2.2 Explicit vs. implicit index creation + +#### 2.2.1 Dataset/DataArray's `indexes` constructor argument + +The new `indexes` argument of Dataset/DataArray constructors may be used to specify which kind of index to bind to which coordinate(s). It would consist of a mapping where, for each item, the key is one coordinate name (or a sequence of coordinate names) that must be given in `coords` and the value is the type of the index to build from this (these) coordinate(s): + +```python +>>> da = xr.DataArray( +... data=[[275.2, 273.5], [270.8, 278.6]], +... dims=('x', 'y'), +... coords={ +... 'lat': (('x', 'y'), [[45.6, 46.5], [50.2, 51.6]]), +... 'lon': (('x', 'y'), [[5.7, 10.5], [6.2, 12.8]]), +... }, +... indexes={('lat', 'lon'): SpatialIndex}, +... ) + +array([[275.2, 273.5], + [270.8, 278.6]]) +Coordinates: + * lat (x, y) float64 45.6 46.5 50.2 51.6 + * lon (x, y) float64 5.7 10.5 6.2 12.8 +``` + +More formally, `indexes` would accept `Mapping[CoordinateNames, IndexSpec]` where: + +- `CoordinateNames = Union[CoordinateName, Tuple[CoordinateName, ...]]` and `CoordinateName = Hashable` +- `IndexSpec = Union[Type[XarrayIndex], Tuple[Type[XarrayIndex], Dict[str, Any]], XarrayIndex]`, so that index instances or index classes + build options could be also passed + +Currently index objects like `pandas.MultiIndex` can be passed directly to `coords`, which in this specific case results in the implicit creation of virtual coordinates. With the new `indexes` argument this behavior may become even more confusing than it currently is. For the sake of clarity, it would be appropriate to eventually drop support for this specific behavior and treat any given mapping value given in `coords` as an array that can be wrapped into an Xarray variable, i.e., in the case of a multi-index: + +```python +>>> xr.DataArray([1.0, 2.0], dims='x', coords={'x': midx}) + +array([1., 2.]) +Coordinates: + x (x) object ('a', 0) ('b', 1) +``` + +A possible, more explicit solution to reuse a `pandas.MultiIndex` in a DataArray/Dataset with levels exposed as coordinates is proposed in [Section 2.2.4](#224-implicit-coordinates). + +#### 2.2.2 Dataset/DataArray's `set_index` method + +New indexes may also be built from existing sets of coordinates or variables in a Dataset/DataArray using the `.set_index()` method. + +The [current signature](http://docs.xarray.dev/en/stable/generated/xarray.DataArray.set_index.html#xarray.DataArray.set_index) of `.set_index()` is tailored to `pandas.MultiIndex` and tied to the concept of a dimension-index. It is therefore hardly reusable as-is in the context of flexible indexes proposed here. + +The new signature may look like one of these: + +- A. `.set_index(coords: CoordinateNames, index: Union[XarrayIndex, Type[XarrayIndex]], **index_kwargs)`: one index is set at a time, index construction options may be passed as keyword arguments +- B. `.set_index(indexes: Mapping[CoordinateNames, Union[Type[XarrayIndex], Tuple[Type[XarrayIndex], Dict[str, Any]]]])`: multiple indexes may be set at a time from a mapping of coordinate or variable name(s) as keys and `XarrayIndex` subclasses (maybe with a dict of build options) as values. If variable names are given as keys of they will be promoted as coordinates + +Option A looks simple and elegant but significantly departs from the current signature. Option B is more consistent with the Dataset/DataArray constructor signature proposed in the previous section and would be easier to adopt in parallel with the current signature that we could still support through some depreciation cycle. + +The `append` parameter of the current `.set_index()` is specific to `pandas.MultiIndex`. With option B we could still support it, although we might want to either drop it or move it to the index construction options in the future. + +#### 2.2.3 Implicit default indexes + +In general explicit index creation should be preferred over implicit index creation. However, there is a majority of cases where basic `pandas.Index` objects could be built and used as indexes for 1-dimensional coordinates. For convenience, Xarray should automatically build such indexes for the coordinates where no index has been explicitly assigned in the Dataset/DataArray constructor or when indexes have been reset / dropped. + +For which coordinates? + +- A. only 1D coordinates with a name matching their dimension name +- B. all 1D coordinates + +When to create it? + +- A. each time when a new Dataset/DataArray is created +- B. only when we need it (i.e., when calling `.sel()` or `indexes`) + +Options A and A are what Xarray currently does and may be the best choice considering that indexes could possibly be invalidated by coordinate mutation. + +Besides `pandas.Index`, other indexes currently supported in Xarray like `CFTimeIndex` could be built depending on the coordinate data type. + +#### 2.2.4 Implicit coordinates + +Like for the indexes, explicit coordinate creation should be preferred over implicit coordinate creation. However, there may be some situations where we would like to keep creating coordinates implicitly for backwards compatibility. + +For example, it is currently possible to pass a `pandas.MulitIndex` object as a coordinate to the Dataset/DataArray constructor: + +```python +>>> midx = pd.MultiIndex.from_arrays([['a', 'b'], [0, 1]], names=['lvl1', 'lvl2']) +>>> da = xr.DataArray([1.0, 2.0], dims='x', coords={'x': midx}) +>>> da + +array([1., 2.]) +Coordinates: + * x (x) MultiIndex + - lvl1 (x) object 'a' 'b' + - lvl2 (x) int64 0 1 +``` + +In that case, virtual coordinates are created for each level of the multi-index. After the index refactoring, these coordinates would become real coordinates bound to the multi-index. + +In the example above a coordinate is also created for the `x` dimension: + +```python +>>> da.x + +array([('a', 0), ('b', 1)], dtype=object) +Coordinates: + * x (x) MultiIndex + - lvl1 (x) object 'a' 'b' + - lvl2 (x) int64 0 1 +``` + +With the new proposed data model, this wouldn't be a requirement anymore: there is no concept of a dimension-index. However, some users might still rely on the `x` coordinate so we could still (temporarily) support it for backwards compatibility. + +Besides `pandas.MultiIndex`, there may be other situations where we would like to reuse an existing index in a new Dataset/DataArray (e.g., when the index is very expensive to build), and which might require implicit creation of one or more coordinates. + +The example given here is quite confusing, though: this is not an easily predictable behavior. We could entirely avoid the implicit creation of coordinates, e.g., using a helper function that generates coordinate + index dictionaries that we could then pass directly to the DataArray/Dataset constructor: + +```python +>>> coords_dict, index_dict = create_coords_from_index(midx, dims='x', include_dim_coord=True) +>>> coords_dict +{'x': + array([('a', 0), ('b', 1)], dtype=object), + 'lvl1': + array(['a', 'b'], dtype=object), + 'lvl2': + array([0, 1])} +>>> index_dict +{('lvl1', 'lvl2'): midx} +>>> xr.DataArray([1.0, 2.0], dims='x', coords=coords_dict, indexes=index_dict) + +array([1., 2.]) +Coordinates: + x (x) object ('a', 0) ('b', 1) + * lvl1 (x) object 'a' 'b' + * lvl2 (x) int64 0 1 +``` + +### 2.2.5 Immutable indexes + +Some underlying indexes might be mutable (e.g., a tree-based index structure that allows dynamic addition of data points) while other indexes like `pandas.Index` aren't. To keep things simple, it is probably better to continue considering all indexes in Xarray as immutable (as well as their corresponding coordinates, see [Section 2.4.1](#241-mutable-coordinates)). + +### 2.3 Index access + +#### 2.3.1 Dataset/DataArray's `indexes` property + +The `indexes` property would allow easy access to all the indexes used in a Dataset/DataArray. It would return a `Dict[CoordinateName, XarrayIndex]` for easy index lookup from coordinate name. + +#### 2.3.2 Additional Dataset/DataArray properties or methods + +In some cases the format returned by the `indexes` property would not be the best (e.g, it may return duplicate index instances as values). For convenience, we could add one more property / method to get the indexes in the desired format if needed. + +### 2.4 Propagate indexes through operations + +#### 2.4.1 Mutable coordinates + +Dataset/DataArray coordinates may be replaced (`__setitem__`) or dropped (`__delitem__`) in-place, which may invalidate some of the indexes. A drastic though probably reasonable solution in this case would be to simply drop all indexes bound to those replaced/dropped coordinates. For the case where a 1D basic coordinate that corresponds to a dimension is added/replaced, we could automatically generate a new index (see [Section 2.2.4](#224-implicit-indexes)). + +We must also ensure that coordinates having a bound index are immutable, e.g., still wrap them into `IndexVariable` objects (even though the `IndexVariable` class might change substantially after this refactoring). + +#### 2.4.2 New Dataset/DataArray with updated coordinates + +Xarray provides a variety of Dataset/DataArray operations affecting the coordinates and where simply dropping the index(es) is not desirable. For example: + +- multi-coordinate indexes could be reduced to single coordinate indexes + - like in `.reset_index()` or `.sel()` applied on a subset of the levels of a `pandas.MultiIndex` and that internally call `MultiIndex.droplevel` and `MultiIndex.get_loc_level`, respectively +- indexes may be indexed themselves + - like `pandas.Index` implements `__getitem__()` + - when indexing their corresponding coordinate(s), e.g., via `.sel()` or `.isel()`, those indexes should be indexed too + - this might not be supported by all Xarray indexes, though +- some indexes that can't be indexed could still be automatically (re)built in the new Dataset/DataArray + - like for example building a new `KDTree` index from the selection of a subset of an initial collection of data points + - this is not always desirable, though, as indexes may be expensive to build + - a more reasonable option would be to explicitly re-build the index, e.g., using `.set_index()` +- Dataset/DataArray operations involving alignment (see [Section 2.6](#26-using-indexes-for-data-alignment)) + +### 2.5 Using indexes for data selection + +One main use of indexes is label-based data selection using the DataArray/Dataset `.sel()` method. This refactoring would introduce a number of API changes that could go through some depreciation cycles: + +- the keys of the mapping given to `indexers` (or the names of `indexer_kwargs`) would not correspond to only dimension names but could be the name of any coordinate that has an index +- for a `pandas.MultiIndex`, if no dimension-coordinate is created by default (see [Section 2.2.4](#224-implicit-coordinates)), providing dict-like objects as indexers should be depreciated +- there should be the possibility to provide additional options to the indexes that support specific selection features (e.g., Scikit-learn's `BallTree`'s `dualtree` query option to boost performance). + - the best API is not trivial here, since `.sel()` may accept indexers passed to several indexes (which should still be supported for convenience and compatibility), and indexes may have similar options with different semantics + - we could introduce a new parameter like `index_options: Dict[XarrayIndex, Dict[str, Any]]` to pass options grouped by index +- the `method` and `tolerance` parameters are specific to `pandas.Index` and would not be supported by all indexes: probably best is to eventually pass those arguments as `index_options` +- the list valid indexer types might be extended in order to support new ways of indexing data, e.g., unordered selection of all points within a given range + - alternatively, we could reuse existing indexer types with different semantics depending on the index, e.g., using `slice(min, max, None)` for unordered range selection + +With the new data model proposed here, an ambiguous situation may occur when indexers are given for several coordinates that share the same dimension but not the same index, e.g., from the example in [Section 1](#1-data-model): + +```python +da.sel(x=..., y=..., lat=..., lon=...) +``` + +The easiest solution for this situation would be to raise an error. Alternatively, we could introduce a new parameter to specify how to combine the resulting integer indexers (i.e., union vs intersection), although this could already be achieved by chaining `.sel()` calls or combining `.sel()` with `.merge()` (it may or may not be straightforward). + +### 2.6 Using indexes for data alignment + +Another main use if indexes is data alignment in various operations. Some considerations regarding alignment and flexible indexes: + +- support for alignment should probably be optional for an `XarrayIndex` subclass. + - like `pandas.Index`, the index wrapper classes that support it should implement `.equals()`, `.union()` and/or `.intersection()` + - support might be partial if that makes sense (outer, inner, left, right, exact...). + - index equality might involve more than just the labels: for example a spatial index might be used to check if the coordinate system (CRS) is identical for two sets of coordinates + - some indexes might implement inexact alignment, like in [#4489](https://github.com/pydata/xarray/pull/4489) or a `KDTree` index that selects nearest-neighbors within a given tolerance + - alignment may be "multi-dimensional", i.e., the `KDTree` example above vs. dimensions aligned independently of each other +- we need to decide what to do when one dimension has more than one index that supports alignment + - we should probably raise unless the user explicitly specify which index to use for the alignment +- we need to decide what to do when one dimension has one or more index(es) but none support alignment + - either we raise or we fail back (silently) to alignment based on dimension size +- for inexact alignment, the tolerance threshold might be given when building the index and/or when performing the alignment +- are there cases where we want a specific index to perform alignment and another index to perform selection? + - it would be tricky to support that unless we allow multiple indexes per coordinate + - alternatively, underlying indexes could be picked internally in a "meta" index for one operation or another, although the risk is to eventually have to deal with an explosion of index wrapper classes with different meta indexes for each combination that we'd like to use. + +### 2.7 Using indexes for other purposes + +Xarray also provides a number of Dataset/DataArray methods where indexes are used in various ways, e.g., + +- `resample` (`CFTimeIndex` and a `DatetimeIntervalIndex`) +- `DatetimeAccessor` & `TimedeltaAccessor` properties (`CFTimeIndex` and a `DatetimeIntervalIndex`) +- `interp` & `interpolate_na`, + - with `IntervalIndex`, these become regridding operations. Should we support hooks for these operations? +- `differentiate`, `integrate`, `polyfit` + - raise an error if not a "simple" 1D index? +- `pad` +- `coarsen` has to make choices about output index labels. +- `sortby` +- `stack`/`unstack` +- plotting + - `plot.pcolormesh` "infers" interval breaks along axes, which are really inferred `bounds` for the appropriate indexes. + - `plot.step` again uses `bounds`. In fact, we may even want `step` to be the default 1D plotting function if the axis has `bounds` attached. + +It would be reasonable to first restrict those methods to the indexes that are currently available in Xarray, and maybe extend the `XarrayIndex` API later upon request when the opportunity arises. + +Conversely, nothing should prevent implementing "non-standard" API in 3rd-party `XarrayIndex` subclasses that could be used in DataArray/Dataset extensions (accessors). For example, we might want to reuse a `KDTree` index to compute k-nearest neighbors (returning a DataArray/Dataset with a new dimension) and/or the distances to the nearest neighbors (returning a DataArray/Dataset with a new data variable). + +### 2.8 Index encoding + +Indexes don't need to be directly serializable since we could (re)build them from their corresponding coordinate(s). However, it would be useful if some indexes could be encoded/decoded to/from a set of arrays that would allow optimized reconstruction and/or storage, e.g., + +- `pandas.MultiIndex` -> `index.levels` and `index.codes` +- Scikit-learn's `KDTree` and `BallTree` that use an array-based representation of an immutable tree structure + +## 3. Index representation in DataArray/Dataset's `repr` + +Since indexes would become 1st class citizen of Xarray's data model, they deserve their own section in Dataset/DataArray `repr` that could look like: + +``` + +array([[5.4, 7.8], + [6.2, 4.7]]) +Coordinates: + * lon (x, y) float64 10.2 15.2 12.6 17.6 + * lat (x, y) float64 40.2 45.6 42.2 47.6 + * x (x) float64 200.0 400.0 + * y (y) float64 800.0 1e+03 +Indexes: + lat, lon + x + y +``` + +To keep the `repr` compact, we could: + +- consolidate entries that map to the same index object, and have an short inline repr for `XarrayIndex` object +- collapse the index section by default in the HTML `repr` +- maybe omit all trivial indexes for 1D coordinates that match the dimension name + +## 4. `IndexVariable` + +`IndexVariable` is currently used to wrap a `pandas.Index` as a variable, which would not be relevant after this refactoring since it is aimed at decoupling indexes and variables. + +We'll probably need to move elsewhere some of the features implemented in `IndexVariable` to: + +- ensure that all coordinates with an index are immutable (see [Section 2.4.1](#241-mutable-coordinates)) + - do not set values directly, do not (re)chunk (even though it may be already chunked), do not load, do not convert to sparse/dense, etc. +- directly reuse index's data when that's possible + - in the case of a `pandas.Index`, it makes little sense to have duplicate data (e.g., as a NumPy array) for its corresponding coordinate +- convert a variable into a `pandas.Index` using `.to_index()` (for backwards compatibility). + +Other `IndexVariable` API like `level_names` and `get_level_variable()` would not useful anymore: it is specific to how we currently deal with `pandas.MultiIndex` and virtual "level" coordinates in Xarray. + +## 5. Chunked coordinates and/or indexers + +We could take opportunity of this refactoring to better leverage chunked coordinates (and/or chunked indexers for data selection). There's two ways to enable it: + +A. support for chunked coordinates is left to the index +B. support for chunked coordinates is index agnostic and is implemented in Xarray + +As an example for B, [xoak](https://github.com/ESM-VFC/xoak) supports building an index for each chunk, which is coupled with a two-step data selection process (cross-index queries + brute force "reduction" look-up). There is an example [here](https://xoak.readthedocs.io/en/latest/examples/dask_support.html). This may be tedious to generalize this to other kinds of operations, though. Xoak's Dask support is rather experimental, not super stable (it's quite hard to control index replication and data transfer between Dask workers with the default settings), and depends on whether indexes are thread-safe and/or serializable. + +Option A may be more reasonable for now. + +## 6. Coordinate duck arrays + +Another opportunity of this refactoring is support for duck arrays as index coordinates. Decoupling coordinates and indexes would *de-facto* enable it. + +However, support for duck arrays in index-based operations such as data selection or alignment would probably require some protocol extension, e.g., + +```python +class MyDuckArray: + ... + + def _sel_(self, indexer): + """Prepare the label-based indexer to conform to this coordinate array.""" + ... + return new_indexer + + ... +``` + +For example, a `pint` array would implement `_sel_` to perform indexer unit conversion or raise, warn, or just pass the indexer through if it has no units. diff --git a/test/fixtures/whole_applications/xarray/design_notes/grouper_objects.md b/test/fixtures/whole_applications/xarray/design_notes/grouper_objects.md new file mode 100644 index 0000000..af42ef2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/design_notes/grouper_objects.md @@ -0,0 +1,240 @@ +# Grouper Objects +**Author**: Deepak Cherian +**Created**: Nov 21, 2023 + +## Abstract + +I propose the addition of Grouper objects to Xarray's public API so that +```python +Dataset.groupby(x=BinGrouper(bins=np.arange(10, 2)))) +``` +is identical to today's syntax: +```python +Dataset.groupby_bins("x", bins=np.arange(10, 2)) +``` + +## Motivation and scope + +Xarray's GroupBy API implements the split-apply-combine pattern (Wickham, 2011)[^1], which applies to a very large number of problems: histogramming, compositing, climatological averaging, resampling to a different time frequency, etc. +The pattern abstracts the following pseudocode: +```python +results = [] +for element in unique_labels: + subset = ds.sel(x=(ds.x == element)) # split + # subset = ds.where(ds.x == element, drop=True) # alternative + result = subset.mean() # apply + results.append(result) + +xr.concat(results) # combine +``` + +to +```python +ds.groupby('x').mean() # splits, applies, and combines +``` + +Efficient vectorized implementations of this pattern are implemented in numpy's [`ufunc.at`](https://numpy.org/doc/stable/reference/generated/numpy.ufunc.at.html), [`ufunc.reduceat`](https://numpy.org/doc/stable/reference/generated/numpy.ufunc.reduceat.html), [`numbagg.grouped`](https://github.com/numbagg/numbagg/blob/main/numbagg/grouped.py), [`numpy_groupies`](https://github.com/ml31415/numpy-groupies), and probably more. +These vectorized implementations *all* require, as input, an array of integer codes or labels that identify unique elements in the array being grouped over (`'x'` in the example above). +```python +import numpy as np + +# array to reduce +a = np.array([1, 1, 1, 1, 2]) + +# initial value for result +out = np.zeros((3,), dtype=int) + +# integer codes +labels = np.array([0, 0, 1, 2, 1]) + +# groupby-reduction +np.add.at(out, labels, a) +out # array([2, 3, 1]) +``` + +One can 'factorize' or construct such an array of integer codes using `pandas.factorize` or `numpy.unique(..., return_inverse=True)` for categorical arrays; `pandas.cut`, `pandas.qcut`, or `np.digitize` for discretizing continuous variables. +In practice, since `GroupBy` objects exist, much of complexity in applying the groupby paradigm stems from appropriately factorizing or generating labels for the operation. +Consider these two examples: +1. [Bins that vary in a dimension](https://flox.readthedocs.io/en/latest/user-stories/nD-bins.html) +2. [Overlapping groups](https://flox.readthedocs.io/en/latest/user-stories/overlaps.html) +3. [Rolling resampling](https://github.com/pydata/xarray/discussions/8361) + +Anecdotally, less experienced users commonly resort to the for-loopy implementation illustrated by the pseudocode above when the analysis at hand is not easily expressed using the API presented by Xarray's GroupBy object. +Xarray's GroupBy API today abstracts away the split, apply, and combine stages but not the "factorize" stage. +Grouper objects will close the gap. + +## Usage and impact + + +Grouper objects +1. Will abstract useful factorization algorithms, and +2. Present a natural way to extend GroupBy to grouping by multiple variables: `ds.groupby(x=BinGrouper(...), t=Resampler(freq="M", ...)).mean()`. + +In addition, Grouper objects provide a nice interface to add often-requested grouping functionality +1. A new `SpaceResampler` would allow specifying resampling spatial dimensions. ([issue](https://github.com/pydata/xarray/issues/4008)) +2. `RollingTimeResampler` would allow rolling-like functionality that understands timestamps ([issue](https://github.com/pydata/xarray/issues/3216)) +3. A `QuantileBinGrouper` to abstract away `pd.cut` ([issue](https://github.com/pydata/xarray/discussions/7110)) +4. A `SeasonGrouper` and `SeasonResampler` would abstract away common annoyances with such calculations today + 1. Support seasons that span a year-end. + 2. Only include seasons with complete data coverage. + 3. Allow grouping over seasons of unequal length + 4. See [this xcdat discussion](https://github.com/xCDAT/xcdat/issues/416) for a `SeasonGrouper` like functionality: + 5. Return results with seasons in a sensible order +5. Weighted grouping ([issue](https://github.com/pydata/xarray/issues/3937)) + 1. Once `IntervalIndex` like objects are supported, `Resampler` groupers can account for interval lengths when resampling. + +## Backward Compatibility + +Xarray's existing grouping functionality will be exposed using two new Groupers: +1. `UniqueGrouper` which uses `pandas.factorize` +2. `BinGrouper` which uses `pandas.cut` +3. `TimeResampler` which mimics pandas' `.resample` + +Grouping by single variables will be unaffected so that `ds.groupby('x')` will be identical to `ds.groupby(x=UniqueGrouper())`. +Similarly, `ds.groupby_bins('x', bins=np.arange(10, 2))` will be unchanged and identical to `ds.groupby(x=BinGrouper(bins=np.arange(10, 2)))`. + +## Detailed description + +All Grouper objects will subclass from a Grouper object +```python +import abc + +class Grouper(abc.ABC): + @abc.abstractmethod + def factorize(self, by: DataArray): + raise NotImplementedError + +class CustomGrouper(Grouper): + def factorize(self, by: DataArray): + ... + return codes, group_indices, unique_coord, full_index + + def weights(self, by: DataArray) -> DataArray: + ... + return weights +``` + +### The `factorize` method +Today, the `factorize` method takes as input the group variable and returns 4 variables (I propose to clean this up below): +1. `codes`: An array of same shape as the `group` with int dtype. NaNs in `group` are coded by `-1` and ignored later. +2. `group_indices` is a list of index location of `group` elements that belong to a single group. +3. `unique_coord` is (usually) a `pandas.Index` object of all unique `group` members present in `group`. +4. `full_index` is a `pandas.Index` of all `group` members. This is different from `unique_coord` for binning and resampling, where not all groups in the output may be represented in the input `group`. For grouping by a categorical variable e.g. `['a', 'b', 'a', 'c']`, `full_index` and `unique_coord` are identical. +There is some redundancy here since `unique_coord` is always equal to or a subset of `full_index`. +We can clean this up (see Implementation below). + +### The `weights` method (?) + +The proposed `weights` method is optional and unimplemented today. +Groupers with `weights` will allow composing `weighted` and `groupby` ([issue](https://github.com/pydata/xarray/issues/3937)). +The `weights` method should return an appropriate array of weights such that the following property is satisfied +```python +gb_sum = ds.groupby(by).sum() + +weights = CustomGrouper.weights(by) +weighted_sum = xr.dot(ds, weights) + +assert_identical(gb_sum, weighted_sum) +``` +For example, the boolean weights for `group=np.array(['a', 'b', 'c', 'a', 'a'])` should be +``` +[[1, 0, 0, 1, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0]] +``` +This is the boolean "summarization matrix" referred to in the classic Iverson (1980, Section 4.3)[^2] and "nub sieve" in [various APLs](https://aplwiki.com/wiki/Nub_Sieve). + +> [!NOTE] +> We can always construct `weights` automatically using `group_indices` from `factorize`, so this is not a required method. + +For a rolling resampling, windowed weights are possible +``` +[[0.5, 1, 0.5, 0, 0], + [0, 0.25, 1, 1, 0], + [0, 0, 0, 1, 1]] +``` + +### The `preferred_chunks` method (?) + +Rechunking support is another optional extension point. +In `flox` I experimented some with automatically rechunking to make a groupby more parallel-friendly ([example 1](https://flox.readthedocs.io/en/latest/generated/flox.rechunk_for_blockwise.html), [example 2](https://flox.readthedocs.io/en/latest/generated/flox.rechunk_for_cohorts.html)). +A great example is for resampling-style groupby reductions, for which `codes` might look like +``` +0001|11122|3333 +``` +where `|` represents chunk boundaries. A simple rechunking to +``` +000|111122|3333 +``` +would make this resampling reduction an embarassingly parallel blockwise problem. + +Similarly consider monthly-mean climatologies for which the month numbers might be +``` +1 2 3 4 5 | 6 7 8 9 10 | 11 12 1 2 3 | 4 5 6 7 8 | 9 10 11 12 | +``` +A slight rechunking to +``` +1 2 3 4 | 5 6 7 8 | 9 10 11 12 | 1 2 3 4 | 5 6 7 8 | 9 10 11 12 | +``` +allows us to reduce `1, 2, 3, 4` separately from `5,6,7,8` and `9, 10, 11, 12` while still being parallel friendly (see the [flox documentation](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts) for more). + +We could attempt to detect these patterns, or we could just have the Grouper take as input `chunks` and return a tuple of "nice" chunk sizes to rechunk to. +```python +def preferred_chunks(self, chunks: ChunksTuple) -> ChunksTuple: + pass +``` +For monthly means, since the period of repetition of labels is 12, the Grouper might choose possible chunk sizes of `((2,),(3,),(4,),(6,))`. +For resampling, the Grouper could choose to resample to a multiple or an even fraction of the resampling frequency. + +## Related work + +Pandas has [Grouper objects](https://pandas.pydata.org/docs/reference/api/pandas.Grouper.html#pandas-grouper) that represent the GroupBy instruction. +However, these objects do not appear to be extension points, unlike the Grouper objects proposed here. +Instead, Pandas' `ExtensionArray` has a [`factorize`](https://pandas.pydata.org/docs/reference/api/pandas.api.extensions.ExtensionArray.factorize.html) method. + +Composing rolling with time resampling is a common workload: +1. Polars has [`group_by_dynamic`](https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html) which appears to be like the proposed `RollingResampler`. +2. scikit-downscale provides [`PaddedDOYGrouper`]( +https://github.com/pangeo-data/scikit-downscale/blob/e16944a32b44f774980fa953ea18e29a628c71b8/skdownscale/pointwise_models/groupers.py#L19) + +## Implementation Proposal + +1. Get rid of `squeeze` [issue](https://github.com/pydata/xarray/issues/2157): [PR](https://github.com/pydata/xarray/pull/8506) +2. Merge existing two class implementation to a single Grouper class + 1. This design was implemented in [this PR](https://github.com/pydata/xarray/pull/7206) to account for some annoying data dependencies. + 2. See [PR](https://github.com/pydata/xarray/pull/8509) +3. Clean up what's returned by `factorize` methods. + 1. A solution here might be to have `group_indices: Mapping[int, Sequence[int]]` be a mapping from group index in `full_index` to a sequence of integers. + 2. Return a `namedtuple` or `dataclass` from existing Grouper factorize methods to facilitate API changes in the future. +4. Figure out what to pass to `factorize` + 1. Xarray eagerly reshapes nD variables to 1D. This is an implementation detail we need not expose. + 2. When grouping by an unindexed variable Xarray passes a `_DummyGroup` object. This seems like something we don't want in the public interface. We could special case "internal" Groupers to preserve the optimizations in `UniqueGrouper`. +5. Grouper objects will exposed under the `xr.groupers` Namespace. At first these will include `UniqueGrouper`, `BinGrouper`, and `TimeResampler`. + +## Alternatives + +One major design choice made here was to adopt the syntax `ds.groupby(x=BinGrouper(...))` instead of `ds.groupby(BinGrouper('x', ...))`. +This allows reuse of Grouper objects, example +```python +grouper = BinGrouper(...) +ds.groupby(x=grouper, y=grouper) +``` +but requires that all variables being grouped by (`x` and `y` above) are present in Dataset `ds`. This does not seem like a bad requirement. +Importantly `Grouper` instances will be copied internally so that they can safely cache state that might be shared between `factorize` and `weights`. + +Today, it is possible to `ds.groupby(DataArray, ...)`. This syntax will still be supported. + +## Discussion + +This proposal builds on these discussions: +1. https://github.com/xarray-contrib/flox/issues/191#issuecomment-1328898836 +2. https://github.com/pydata/xarray/issues/6610 + +## Copyright + +This document has been placed in the public domain. + +## References and footnotes + +[^1]: Wickham, H. (2011). The split-apply-combine strategy for data analysis. https://vita.had.co.nz/papers/plyr.html +[^2]: Iverson, K.E. (1980). Notation as a tool of thought. Commun. ACM 23, 8 (Aug. 1980), 444–465. https://doi.org/10.1145/358896.358899 diff --git a/test/fixtures/whole_applications/xarray/design_notes/named_array_design_doc.md b/test/fixtures/whole_applications/xarray/design_notes/named_array_design_doc.md new file mode 100644 index 0000000..074f8cf --- /dev/null +++ b/test/fixtures/whole_applications/xarray/design_notes/named_array_design_doc.md @@ -0,0 +1,371 @@ +# named-array Design Document + +## Abstract + +Despite the wealth of scientific libraries in the Python ecosystem, there is a gap for a lightweight, efficient array structure with named dimensions that can provide convenient broadcasting and indexing. + +Existing solutions like Xarray's Variable, [Pytorch Named Tensor](https://github.com/pytorch/pytorch/issues/60832), [Levanter](https://crfm.stanford.edu/2023/06/16/levanter-1_0-release.html), and [Larray](https://larray.readthedocs.io/en/stable/tutorial/getting_started.html) have their own strengths and weaknesses. Xarray's Variable is an efficient data structure, but it depends on the relatively heavy-weight library Pandas, which limits its use in other projects. Pytorch Named Tensor offers named dimensions, but it lacks support for many operations, making it less user-friendly. Levanter is a powerful tool with a named tensor module (Haliax) that makes deep learning code easier to read, understand, and write, but it is not as lightweight or generic as desired. Larry offers labeled N-dimensional arrays, but it may not provide the level of seamless interoperability with other scientific Python libraries that some users need. + +named-array aims to solve these issues by exposing the core functionality of Xarray's Variable class as a standalone package. + +## Motivation and Scope + +The Python ecosystem boasts a wealth of scientific libraries that enable efficient computations on large, multi-dimensional arrays. Libraries like PyTorch, Xarray, and NumPy have revolutionized scientific computing by offering robust data structures for array manipulations. Despite this wealth of tools, a gap exists in the Python landscape for a lightweight, efficient array structure with named dimensions that can provide convenient broadcasting and indexing. + +Xarray internally maintains a data structure that meets this need, referred to as [`xarray.Variable`](https://docs.xarray.dev/en/latest/generated/xarray.Variable.html) . However, Xarray's dependency on Pandas, a relatively heavy-weight library, restricts other projects from leveraging this efficient data structure (, , ). + +We propose the creation of a standalone Python package, "named-array". This package is envisioned to be a version of the `xarray.Variable` data structure, cleanly separated from the heavier dependencies of Xarray. named-array will provide a lightweight, user-friendly array-like data structure with named dimensions, facilitating convenient indexing and broadcasting. The package will use existing scientific Python community standards such as established array protocols and the new [Python array API standard](https://data-apis.org/array-api/latest), allowing users to wrap multiple duck-array objects, including, but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +The development of named-array is projected to meet a key community need and expected to broaden Xarray's user base. By making the core `xarray.Variable` more accessible, we anticipate an increase in contributors and a reduction in the developer burden on current Xarray maintainers. + +### Goals + +1. **Simple and minimal**: named-array will expose Xarray's [Variable class](https://docs.xarray.dev/en/stable/internals/variable-objects.html) as a standalone object (`NamedArray`) with named axes (dimensions) and arbitrary metadata (attributes) but without coordinate labels. This will make it a lightweight, efficient array data structure that allows convenient broadcasting and indexing. + +2. **Interoperability**: named-array will follow established scientific Python community standards and in doing so, will allow it to wrap multiple duck-array objects, including but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +3. **Community Engagement**: By making the core `xarray.Variable` more accessible, we open the door to increased adoption of this fundamental data structure. As such, we hope to see an increase in contributors and reduction in the developer burden on current Xarray maintainers. + +### Non-Goals + +1. **Extensive Data Analysis**: named-array will not provide extensive data analysis features like statistical functions, data cleaning, or visualization. Its primary focus is on providing a data structure that allows users to use dimension names for descriptive array manipulations. + +2. **Support for I/O**: named-array will not bundle file reading functions. Instead users will be expected to handle I/O and then wrap those arrays with the new named-array data structure. + +## Backward Compatibility + +The creation of named-array is intended to separate the `xarray.Variable` from Xarray into a standalone package. This allows it to be used independently, without the need for Xarray's dependencies, like Pandas. This separation has implications for backward compatibility. + +Since the new named-array is envisioned to contain the core features of Xarray's variable, existing code using Variable from Xarray should be able to switch to named-array with minimal changes. However, there are several potential issues related to backward compatibility: + +* **API Changes**: as the Variable is decoupled from Xarray and moved into named-array, some changes to the API may be necessary. These changes might include differences in function signature, etc. These changes could break existing code that relies on the current API and associated utility functions (e.g. `as_variable()`). The `xarray.Variable` object will subclass `NamedArray`, and provide the existing interface for compatibility. + +## Detailed Description + +named-array aims to provide a lightweight, efficient array structure with named dimensions, or axes, that enables convenient broadcasting and indexing. The primary component of named-array is a standalone version of the xarray.Variable data structure, which was previously a part of the Xarray library. +The xarray.Variable data structure in named-array will maintain the core features of its counterpart in Xarray, including: + +* **Named Axes (Dimensions)**: Each axis of the array can be given a name, providing a descriptive and intuitive way to reference the dimensions of the array. + +* **Arbitrary Metadata (Attributes)**: named-array will support the attachment of arbitrary metadata to arrays as a dict, providing a mechanism to store additional information about the data that the array represents. + +* **Convenient Broadcasting and Indexing**: With named dimensions, broadcasting and indexing operations become more intuitive and less error-prone. + +The named-array package is designed to be interoperable with other scientific Python libraries. It will follow established scientific Python community standards and use standard array protocols, as well as the new data-apis standard. This allows named-array to wrap multiple duck-array objects, including, but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +## Implementation + +* **Decoupling**: making `variable.py` agnostic to Xarray internals by decoupling it from the rest of the library. This will make the code more modular and easier to maintain. However, this will also make the code more complex, as we will need to define a clear interface for how the functionality in `variable.py` interacts with the rest of the library, particularly the ExplicitlyIndexed subclasses used to enable lazy indexing of data on disk. +* **Move Xarray's internal lazy indexing classes to follow standard Array Protocols**: moving the lazy indexing classes like `ExplicitlyIndexed` to use standard array protocols will be a key step in decoupling. It will also potentially improve interoperability with other libraries that use these protocols, and prepare these classes [for eventual movement out](https://github.com/pydata/xarray/issues/5081) of the Xarray code base. However, this will also require significant changes to the code, and we will need to ensure that all existing functionality is preserved. + * Use [https://data-apis.org/array-api-compat/](https://data-apis.org/array-api-compat/) to handle compatibility issues? +* **Leave lazy indexing classes in Xarray for now** +* **Preserve support for Dask collection protocols**: named-array will preserve existing support for the dask collections protocol namely the __dask_***__ methods +* **Preserve support for ChunkManagerEntrypoint?** Opening variables backed by dask vs cubed arrays currently is [handled within Variable.chunk](https://github.com/pydata/xarray/blob/92c8b33eb464b09d6f8277265b16cae039ab57ee/xarray/core/variable.py#L1272C15-L1272C15). If we are preserving dask support it would be nice to preserve general chunked array type support, but this currently requires an entrypoint. + +### Plan + +1. Create a new baseclass for `xarray.Variable` to its own module e.g. `xarray.core.base_variable` +2. Remove all imports of internal Xarray classes and utils from `base_variable.py`. `base_variable.Variable` should not depend on anything in xarray.core + * Will require moving the lazy indexing classes (subclasses of ExplicitlyIndexed) to be standards compliant containers.` + * an array-api compliant container that provides **array_namespace**` + * Support `.oindex` and `.vindex` for explicit indexing + * Potentially implement this by introducing a new compliant wrapper object? + * Delete the `NON_NUMPY_SUPPORTED_ARRAY_TYPES` variable which special-cases ExplicitlyIndexed and `pd.Index.` + * `ExplicitlyIndexed` class and subclasses should provide `.oindex` and `.vindex` for indexing by `Variable.__getitem__.`: `oindex` and `vindex` were proposed in [NEP21](https://numpy.org/neps/nep-0021-advanced-indexing.html), but have not been implemented yet + * Delete the ExplicitIndexer objects (`BasicIndexer`, `VectorizedIndexer`, `OuterIndexer`) + * Remove explicit support for `pd.Index`. When provided with a `pd.Index` object, Variable will coerce to an array using `np.array(pd.Index)`. For Xarray's purposes, Xarray can use `as_variable` to explicitly wrap these in PandasIndexingAdapter and pass them to `Variable.__init__`. +3. Define a minimal variable interface that the rest of Xarray can use: + 1. `dims`: tuple of dimension names + 2. `data`: numpy/dask/duck arrays` + 3. `attrs``: dictionary of attributes + +4. Implement basic functions & methods for manipulating these objects. These methods will be a cleaned-up subset (for now) of functionality on xarray.Variable, with adaptations inspired by the [Python array API](https://data-apis.org/array-api/2022.12/API_specification/index.html). +5. Existing Variable structures + 1. Keep Variable object which subclasses the new structure that adds the `.encoding` attribute and potentially other methods needed for easy refactoring. + 2. IndexVariable will remain in xarray.core.variable and subclass the new named-array data structure pending future deletion. +6. Docstrings and user-facing APIs will need to be updated to reflect the changed methods on Variable objects. + +Further implementation details are in Appendix: [Implementation Details](#appendix-implementation-details). + +## Plan for decoupling lazy indexing functionality from NamedArray + +Today's implementation Xarray's lazy indexing functionality uses three private objects: `*Indexer`, `*IndexingAdapter`, `*Array`. +These objects are needed for two reason: +1. We need to translate from Xarray (NamedArray) indexing rules to bare array indexing rules. + - `*Indexer` objects track the type of indexing - basic, orthogonal, vectorized +2. Not all arrays support the same indexing rules, so we need `*Indexing` adapters + 1. Indexing Adapters today implement `__getitem__` and use type of `*Indexer` object to do appropriate conversions. +3. We also want to support lazy indexing of on-disk arrays. + 1. These again support different types of indexing, so we have `explicit_indexing_adapter` that understands `*Indexer` objects. + +### Goals +1. We would like to keep the lazy indexing array objects, and backend array objects within Xarray. Thus NamedArray cannot treat these objects specially. +2. A key source of confusion (and coupling) is that both lazy indexing arrays and indexing adapters, both handle Indexer objects, and both subclass `ExplicitlyIndexedNDArrayMixin`. These are however conceptually different. + +### Proposal + +1. The `NumpyIndexingAdapter`, `DaskIndexingAdapter`, and `ArrayApiIndexingAdapter` classes will need to migrate to Named Array project since we will want to support indexing of numpy, dask, and array-API arrays appropriately. +2. The `as_indexable` function which wraps an array with the appropriate adapter will also migrate over to named array. +3. Lazy indexing arrays will implement `__getitem__` for basic indexing, `.oindex` for orthogonal indexing, and `.vindex` for vectorized indexing. +4. IndexingAdapter classes will similarly implement `__getitem__`, `oindex`, and `vindex`. +5. `NamedArray.__getitem__` (and `__setitem__`) will still use `*Indexer` objects internally (for e.g. in `NamedArray._broadcast_indexes`), but use `.oindex`, `.vindex` on the underlying indexing adapters. +6. We will move the `*Indexer` and `*IndexingAdapter` classes to Named Array. These will be considered private in the long-term. +7. `as_indexable` will no longer special case `ExplicitlyIndexed` objects (we can special case a new `IndexingAdapter` mixin class that will be private to NamedArray). To handle Xarray's lazy indexing arrays, we will introduce a new `ExplicitIndexingAdapter` which will wrap any array with any of `.oindex` of `.vindex` implemented. + 1. This will be the last case in the if-chain that is, we will try to wrap with all other `IndexingAdapter` objects before using `ExplicitIndexingAdapter` as a fallback. This Adapter will be used for the lazy indexing arrays, and backend arrays. + 2. As with other indexing adapters (point 4 above), this `ExplicitIndexingAdapter` will only implement `__getitem__` and will understand `*Indexer` objects. +8. For backwards compatibility with external backends, we will have to gracefully deprecate `indexing.explicit_indexing_adapter` which translates from Xarray's indexing rules to the indexing supported by the backend. + 1. We could split `explicit_indexing_adapter` in to 3: + - `basic_indexing_adapter`, `outer_indexing_adapter` and `vectorized_indexing_adapter` for public use. + 2. Implement fall back `.oindex`, `.vindex` properties on `BackendArray` base class. These will simply rewrap the `key` tuple with the appropriate `*Indexer` object, and pass it on to `__getitem__` or `__setitem__`. These methods will also raise DeprecationWarning so that external backends will know to migrate to `.oindex`, and `.vindex` over the next year. + +THe most uncertain piece here is maintaining backward compatibility with external backends. We should first migrate a single internal backend, and test out the proposed approach. + +## Project Timeline and Milestones + +We have identified the following milestones for the completion of this project: + +1. **Write and publish a design document**: this document will explain the purpose of named-array, the intended audience, and the features it will provide. It will also describe the architecture of named-array and how it will be implemented. This will ensure early community awareness and engagement in the project to promote subsequent uptake. +2. **Refactor `variable.py` to `base_variable.py`** and remove internal Xarray imports. +3. **Break out the package and create continuous integration infrastructure**: this will entail breaking out the named-array project into a Python package and creating a continuous integration (CI) system. This will help to modularize the code and make it easier to manage. Building a CI system will help ensure that codebase changes do not break existing functionality. +4. Incrementally add new functions & methods to the new package, ported from xarray. This will start to make named-array useful on its own. +5. Refactor the existing Xarray codebase to rely on the newly created package (named-array): This will help to demonstrate the usefulness of the new package, and also provide an example for others who may want to use it. +6. Expand tests, add documentation, and write a blog post: expanding the test suite will help to ensure that the code is reliable and that changes do not introduce bugs. Adding documentation will make it easier for others to understand and use the project. +7. Finally, we will write a series of blog posts on [xarray.dev](https://xarray.dev/) to promote the project and attract more contributors. + * Toward the end of the process, write a few blog posts that demonstrate the use of the newly available data structure + * pick the same example applications used by other implementations/applications (e.g. Pytorch, sklearn, and Levanter) to show how it can work. + +## Related Work + +1. [GitHub - deepmind/graphcast](https://github.com/deepmind/graphcast) +2. [Getting Started — LArray 0.34 documentation](https://larray.readthedocs.io/en/stable/tutorial/getting_started.html) +3. [Levanter — Legible, Scalable, Reproducible Foundation Models with JAX](https://crfm.stanford.edu/2023/06/16/levanter-1_0-release.html) +4. [google/xarray-tensorstore](https://github.com/google/xarray-tensorstore) +5. [State of Torch Named Tensors · Issue #60832 · pytorch/pytorch · GitHub](https://github.com/pytorch/pytorch/issues/60832) + * Incomplete support: Many primitive operations result in errors, making it difficult to use NamedTensors in Practice. Users often have to resort to removing the names from tensors to avoid these errors. + * Lack of active development: the development of the NamedTensor feature in PyTorch is not currently active due a lack of bandwidth for resolving ambiguities in the design. + * Usability issues: the current form of NamedTensor is not user-friendly and sometimes raises errors, making it difficult for users to incorporate NamedTensors into their workflows. +6. [Scikit-learn Enhancement Proposals (SLEPs) 8, 12, 14](https://github.com/scikit-learn/enhancement_proposals/pull/18) + * Some of the key points and limitations discussed in these proposals are: + * Inconsistency in feature name handling: Scikit-learn currently lacks a consistent and comprehensive way to handle and propagate feature names through its pipelines and estimators ([SLEP 8](https://github.com/scikit-learn/enhancement_proposals/pull/18),[SLEP 12](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep012/proposal.html)). + * Memory intensive for large feature sets: storing and propagating feature names can be memory intensive, particularly in cases where the entire "dictionary" becomes the features, such as in NLP use cases ([SLEP 8](https://github.com/scikit-learn/enhancement_proposals/pull/18),[GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)) + * Sparse matrices: sparse data structures present a challenge for feature name propagation. For instance, the sparse data structure functionality in Pandas 1.0 only supports converting directly to the coordinate format (COO), which can be an issue with transformers such as the OneHotEncoder.transform that has been optimized to construct a CSR matrix ([SLEP 14](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep014/proposal.html)) + * New Data structures: the introduction of new data structures, such as "InputArray" or "DataArray" could lead to more burden for third-party estimator maintainers and increase the learning curve for users. Xarray's "DataArray" is mentioned as a potential alternative, but the proposal mentions that the conversion from a Pandas dataframe to a Dataset is not lossless ([SLEP 12](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep012/proposal.html),[SLEP 14](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep014/proposal.html),[GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)). + * Dependency on other libraries: solutions that involve using Xarray and/or Pandas to handle feature names come with the challenge of managing dependencies. While a soft dependency approach is suggested, this means users would be able to have/enable the feature only if they have the dependency installed. Xarra-lite's integration with other scientific Python libraries could potentially help with this issue ([GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)). + +## References and Previous Discussion + +* [[Proposal] Expose Variable without Pandas dependency · Issue #3981 · pydata/xarray · GitHub](https://github.com/pydata/xarray/issues/3981) +* [https://github.com/pydata/xarray/issues/3981#issuecomment-985051449](https://github.com/pydata/xarray/issues/3981#issuecomment-985051449) +* [Lazy indexing arrays as a stand-alone package · Issue #5081 · pydata/xarray · GitHub](https://github.com/pydata/xarray/issues/5081) + +### Appendix: Engagement with the Community + +We plan to publicize this document on : + +* [x] `Xarray dev call` +* [ ] `Scientific Python discourse` +* [ ] `Xarray Github` +* [ ] `Twitter` +* [ ] `Respond to NamedTensor and Scikit-Learn issues?` +* [ ] `Pangeo Discourse` +* [ ] `Numpy, SciPy email lists?` +* [ ] `Xarray blog` + +Additionally, We plan on writing a series of blog posts to effectively showcase the implementation and potential of the newly available functionality. To illustrate this, we will use the same example applications as other established libraries (such as Pytorch, sklearn), providing practical demonstrations of how these new data structures can be leveraged. + +### Appendix: API Surface + +Questions: + +1. Document Xarray indexing rules +2. Document use of .oindex and .vindex protocols +3. Do we use `.mean` and `.nanmean` or `.mean(skipna=...)`? + * Default behavior in named-array should mirror NumPy / the array API standard, not pandas. + * nanmean is not (yet) in the [array API](https://github.com/pydata/xarray/pull/7424#issuecomment-1373979208). There are a handful of other key functions (e.g., median) that are are also missing. I think that should be OK, as long as what we support is a strict superset of the array API. +4. What methods need to be exposed on Variable? + * `Variable.concat` classmethod: create two functions, one as the equivalent of `np.stack` and other for `np.concat` + * `.rolling_window` and `.coarsen_reshape` ? + * `named-array.apply_ufunc`: used in astype, clip, quantile, isnull, notnull` + +#### methods to be preserved from xarray.Variable + +```python +# Sorting + Variable.argsort + Variable.searchsorted + +# NaN handling + Variable.fillna + Variable.isnull + Variable.notnull + +# Lazy data handling + Variable.chunk # Could instead have accessor interface and recommend users use `Variable.dask.chunk` and `Variable.cubed.chunk`? + Variable.to_numpy() + Variable.as_numpy() + +# Xarray-specific + Variable.get_axis_num + Variable.isel + Variable.to_dict + +# Reductions + Variable.reduce + Variable.all + Variable.any + Variable.argmax + Variable.argmin + Variable.count + Variable.max + Variable.mean + Variable.median + Variable.min + Variable.prod + Variable.quantile + Variable.std + Variable.sum + Variable.var + +# Accumulate + Variable.cumprod + Variable.cumsum + +# numpy-like Methods + Variable.astype + Variable.copy + Variable.clip + Variable.round + Variable.item + Variable.where + +# Reordering/Reshaping + Variable.squeeze + Variable.pad + Variable.roll + Variable.shift + +``` + +#### methods to be renamed from xarray.Variable + +```python +# Xarray-specific + Variable.concat # create two functions, one as the equivalent of `np.stack` and other for `np.concat` + + # Given how niche these are, these would be better as functions than methods. + # We could also keep these in Xarray, at least for now. If we don't think people will use functionality outside of Xarray it probably is not worth the trouble of porting it (including documentation, etc). + Variable.coarsen # This should probably be called something like coarsen_reduce. + Variable.coarsen_reshape + Variable.rolling_window + + Variable.set_dims # split this into broadcas_to and expand_dims + + +# Reordering/Reshaping + Variable.stack # To avoid confusion with np.stack, let's call this stack_dims. + Variable.transpose # Could consider calling this permute_dims, like the [array API standard](https://data-apis.org/array-api/2022.12/API_specification/manipulation_functions.html#objects-in-api) + Variable.unstack # Likewise, maybe call this unstack_dims? +``` + +#### methods to be removed from xarray.Variable + +```python +# Testing + Variable.broadcast_equals + Variable.equals + Variable.identical + Variable.no_conflicts + +# Lazy data handling + Variable.compute # We can probably omit this method for now, too, given that dask.compute() uses a protocol. The other concern is that different array libraries have different notions of "compute" and this one is rather Dask specific, including conversion from Dask to NumPy arrays. For example, in JAX every operation executes eagerly, but in a non-blocking fashion, and you need to call jax.block_until_ready() to ensure computation is finished. + Variable.load # Could remove? compute vs load is a common source of confusion. + +# Xarray-specific + Variable.to_index + Variable.to_index_variable + Variable.to_variable + Variable.to_base_variable + Variable.to_coord + + Variable.rank # Uses bottleneck. Delete? Could use https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.rankdata.html instead + + +# numpy-like Methods + Variable.conjugate # .conj is enough + Variable.__array_wrap__ # This is a very old NumPy protocol for duck arrays. We don't need it now that we have `__array_ufunc__` and `__array_function__` + +# Encoding + Variable.reset_encoding + +``` + +#### Attributes to be preserved from xarray.Variable + +```python +# Properties + Variable.attrs + Variable.chunks + Variable.data + Variable.dims + Variable.dtype + + Variable.nbytes + Variable.ndim + Variable.shape + Variable.size + Variable.sizes + + Variable.T + Variable.real + Variable.imag + Variable.conj +``` + +#### Attributes to be renamed from xarray.Variable + +```python +``` + +#### Attributes to be removed from xarray.Variable + +```python + + Variable.values # Probably also remove -- this is a legacy from before Xarray supported dask arrays. ".data" is enough. + +# Encoding + Variable.encoding + +``` + +### Appendix: Implementation Details + +* Merge in VariableArithmetic's parent classes: AbstractArray, NdimSizeLenMixin with the new data structure.. + +```python +class VariableArithmetic( + ImplementsArrayReduce, + IncludeReduceMethods, + IncludeCumMethods, + IncludeNumpySameMethods, + SupportsArithmetic, + VariableOpsMixin, +): + __slots__ = () + # prioritize our operations over those of numpy.ndarray (priority=0) + __array_priority__ = 50 + +``` + +* Move over `_typed_ops.VariableOpsMixin` +* Build a list of utility functions used elsewhere : Which of these should become public API? + * `broadcast_variables`: `dataset.py`, `dataarray.py`,`missing.py` + * This could be just called "broadcast" in named-array. + * `Variable._getitem_with_mask` : `alignment.py` + * keep this method/function as private and inside Xarray. +* The Variable constructor will need to be rewritten to no longer accept tuples, encodings, etc. These details should be handled at the Xarray data structure level. +* What happens to `duck_array_ops?` +* What about Variable.chunk and "chunk managers"? + * Could this functionality be left in Xarray proper for now? Alternative array types like JAX also have some notion of "chunks" for parallel arrays, but the details differ in a number of ways from the Dask/Cubed. + * Perhaps variable.chunk/load methods should become functions defined in xarray that convert Variable objects. This is easy so long as xarray can reach in and replace .data + +* Utility functions like `as_variable` should be moved out of `base_variable.py` so they can convert BaseVariable objects to/from DataArray or Dataset containing explicitly indexed arrays. diff --git a/test/fixtures/whole_applications/xarray/doc/Makefile b/test/fixtures/whole_applications/xarray/doc/Makefile new file mode 100644 index 0000000..8b08d3a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/Makefile @@ -0,0 +1,248 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXATUOBUILD = sphinx-autobuild +PAPER = +BUILDDIR = _build + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . + +.PHONY: help +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " rtdhtml Build html using same settings used on ReadtheDocs" + @echo " livehtml Make standalone HTML files and rebuild the documentation when a change is detected. Also includes a livereload enabled web server" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " applehelp to make an Apple Help Book" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " epub3 to make an epub3" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + @echo " coverage to run coverage check of the documentation (if enabled)" + @echo " dummy to check syntax errors of document sources" + +.PHONY: clean +clean: + rm -rf $(BUILDDIR)/* + rm -rf generated/* + rm -rf auto_gallery/ + +.PHONY: html +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +.PHONY: rtdhtml +rtdhtml: + $(SPHINXBUILD) -T -j auto -E -W --keep-going -b html -d $(BUILDDIR)/doctrees -D language=en . $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + + +.PHONY: livehtml +livehtml: + # @echo "$(SPHINXATUOBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html" + $(SPHINXATUOBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + +.PHONY: dirhtml +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +.PHONY: singlehtml +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +.PHONY: html-noplot +html-noplot: + $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +.PHONY: pickle +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +.PHONY: json +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +.PHONY: htmlhelp +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +.PHONY: qthelp +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/xarray.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/xarray.qhc" + +.PHONY: applehelp +applehelp: + $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp + @echo + @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." + @echo "N.B. You won't be able to view it unless you put it in" \ + "~/Library/Documentation/Help or install it in your application" \ + "bundle." + +.PHONY: devhelp +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/xarray" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/xarray" + @echo "# devhelp" + +.PHONY: epub +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +.PHONY: epub3 +epub3: + $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 + @echo + @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." + +.PHONY: latex +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +.PHONY: latexpdf +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +.PHONY: latexpdfja +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +.PHONY: text +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +.PHONY: man +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +.PHONY: texinfo +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +.PHONY: info +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +.PHONY: gettext +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +.PHONY: changes +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +.PHONY: linkcheck +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +.PHONY: doctest +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +.PHONY: coverage +coverage: + $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage + @echo "Testing of coverage in the sources finished, look at the " \ + "results in $(BUILDDIR)/coverage/python.txt." + +.PHONY: xml +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +.PHONY: pseudoxml +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." + +.PHONY: dummy +dummy: + $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy + @echo + @echo "Build finished. Dummy builder generates no files." diff --git a/test/fixtures/whole_applications/xarray/doc/README.rst b/test/fixtures/whole_applications/xarray/doc/README.rst new file mode 100644 index 0000000..c1b6c63 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/README.rst @@ -0,0 +1,6 @@ +:orphan: + +xarray +------ + +You can find information about building the docs at our `Contributing page `_. diff --git a/test/fixtures/whole_applications/xarray/doc/_static/.gitignore b/test/fixtures/whole_applications/xarray/doc/_static/.gitignore new file mode 100644 index 0000000..5ea6e27 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/.gitignore @@ -0,0 +1,5 @@ +examples*.png +*.log +*.pdf +*.fbd_latexmk +*.aux diff --git a/test/fixtures/whole_applications/xarray/doc/_static/advanced_selection_interpolation.svg b/test/fixtures/whole_applications/xarray/doc/_static/advanced_selection_interpolation.svg new file mode 100644 index 0000000..096563a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/advanced_selection_interpolation.svg @@ -0,0 +1,731 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + y + x + + + + + z + + + + + + + + + + + + + + + + + + + + + + + + + + + + y + x + + + + + z + + + + + + + + + Advanced indexing + Advanced interpolation + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/ci.png b/test/fixtures/whole_applications/xarray/doc/_static/ci.png new file mode 100644 index 0000000..aec900b Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/ci.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/dask_array.png b/test/fixtures/whole_applications/xarray/doc/_static/dask_array.png new file mode 100644 index 0000000..7ddb6e4 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/dask_array.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/dataset-diagram.png b/test/fixtures/whole_applications/xarray/doc/_static/dataset-diagram.png new file mode 100644 index 0000000..be9aa8d Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/dataset-diagram.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/index_api.svg b/test/fixtures/whole_applications/xarray/doc/_static/index_api.svg new file mode 100644 index 0000000..87013d2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/index_api.svg @@ -0,0 +1,97 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/index_contribute.svg b/test/fixtures/whole_applications/xarray/doc/_static/index_contribute.svg new file mode 100644 index 0000000..399f1d7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/index_contribute.svg @@ -0,0 +1,76 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/index_getting_started.svg b/test/fixtures/whole_applications/xarray/doc/_static/index_getting_started.svg new file mode 100644 index 0000000..d1c7b08 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/index_getting_started.svg @@ -0,0 +1,66 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/index_user_guide.svg b/test/fixtures/whole_applications/xarray/doc/_static/index_user_guide.svg new file mode 100644 index 0000000..bff2482 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/index_user_guide.svg @@ -0,0 +1,67 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Icon_Final.png b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Icon_Final.png new file mode 100644 index 0000000..6c0bae4 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Icon_Final.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Icon_Final.svg b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Icon_Final.svg new file mode 100644 index 0000000..689b207 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Icon_Final.svg @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.png b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.png new file mode 100644 index 0000000..68701ee Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.svg b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.svg new file mode 100644 index 0000000..a803e93 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.svg @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_RGB_Final.png b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_RGB_Final.png new file mode 100644 index 0000000..823ff8d Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_RGB_Final.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_RGB_Final.svg b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_RGB_Final.svg new file mode 100644 index 0000000..86e1b48 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/logos/Xarray_Logo_RGB_Final.svg @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/doc/_static/numfocus_logo.png b/test/fixtures/whole_applications/xarray/doc/_static/numfocus_logo.png new file mode 100644 index 0000000..af3c842 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/numfocus_logo.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/opendap-prism-tmax.png b/test/fixtures/whole_applications/xarray/doc/_static/opendap-prism-tmax.png new file mode 100644 index 0000000..7ff778a Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/opendap-prism-tmax.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/style.css b/test/fixtures/whole_applications/xarray/doc/_static/style.css new file mode 100644 index 0000000..bd0b13c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_static/style.css @@ -0,0 +1,270 @@ +table.colwidths-given { + table-layout: fixed; + width: 100%; +} +table.docutils td { + white-space: unset; + word-wrap: break-word; +} + +.bd-header-announcement { + background-color: var(--pst-color-info-bg); +} + +/* Reduce left and right margins */ + +.container, .container-lg, .container-md, .container-sm, .container-xl { + max-width: 1350px !important; +} + + +/* Copied from +https://github.com/bokeh/bokeh/blob/branch-2.4/sphinx/source/bokeh/static/custom.css +*/ + +:root { + /* Logo image height + all the paddings/margins make the navbar height. */ + --navbar-height: calc(30px + 0.3125rem * 2 + 0.5rem * 2); +} + +.bd-search { + position: relative; + padding-bottom: 20px; +} + +@media (min-width: 768px) { + .search-front-page { + width: 50%; + } +} + +/* minimal copy paste from bootstrap docs css to get sidebars working */ + +.bd-toc { + -ms-flex-order: 2; + order: 2; + padding-top: 1.5rem; + padding-bottom: 1.5rem; + /* font-size: 0.875rem; */ + /* add scrolling sidebar */ + height: calc(100vh - 2rem); + overflow-y: auto; +} + +@supports ((position: -webkit-sticky) or (position: sticky)) { + .bd-toc { + position: -webkit-sticky; + position: sticky; + top: 4rem; + height: calc(100vh - 4rem); + overflow-y: auto; + } +} + +.section-nav { + padding-left: 0; + border-left: 1px solid #eee; + border-bottom: none; +} + +.section-nav ul { + padding-left: 1rem; +} + +.toc-entry { + display: block; +} + +.toc-entry a { + display: block; + padding: 0.125rem 1.5rem; + color: #77757a; +} + +.toc-entry a:hover { + color: rgba(0, 0, 0, 0.85); + text-decoration: none; +} + +.bd-sidebar { + -ms-flex-order: 0; + order: 0; + border-bottom: 1px solid rgba(0, 0, 0, 0.1); +} + +@media (min-width: 768px) { + .bd-sidebar { + border-right: 1px solid rgba(0, 0, 0, 0.1); + } + @supports ((position: -webkit-sticky) or (position: sticky)) { + .bd-sidebar { + position: -webkit-sticky; + position: sticky; + top: var(--navbar-height); + z-index: 1000; + height: calc(100vh - var(--navbar-height)); + } + } +} + +@media (min-width: 1200px) { + .bd-sidebar { + -ms-flex: 0 1 320px; + flex: 0 1 320px; + } +} + +.bd-links { + padding-top: 1rem; + padding-bottom: 1rem; + margin-right: -15px; + margin-left: -15px; +} + +@media (min-width: 768px) { + @supports ((position: -webkit-sticky) or (position: sticky)) { + .bd-links { + max-height: calc(100vh - 9rem); + overflow-y: auto; + } + } +} + +@media (min-width: 768px) { + .bd-links { + display: block !important; + } +} + +.bd-sidenav { + display: none; +} + +.bd-toc-link { + display: block; + padding: 0.25rem 1.5rem; + font-weight: 400; + color: rgba(0, 0, 0, 0.65); +} + +.bd-toc-link:hover { + color: rgba(0, 0, 0, 0.85); + text-decoration: none; +} + +.bd-toc-item.active { + margin-bottom: 1rem; +} + +.bd-toc-item.active:not(:first-child) { + margin-top: 1rem; +} + +.bd-toc-item.active > .bd-toc-link { + color: rgba(0, 0, 0, 0.85); +} + +.bd-toc-item.active > .bd-toc-link:hover { + background-color: transparent; +} + +.bd-toc-item.active > .bd-sidenav { + display: block; +} + +.bd-sidebar .nav > li > a { + display: block; + padding: 0.25rem 1.5rem; + font-size: 90%; +} + +.bd-sidebar .nav > li > a:hover { + text-decoration: none; + background-color: transparent; +} + +.bd-sidebar .nav > .active > a, +.bd-sidebar .nav > .active:hover > a { + font-weight: 400; + /* adjusted from original + color: rgba(0, 0, 0, 0.85); + background-color: transparent; */ +} + +.bd-sidebar .nav > li > ul { + list-style: none; + padding: 0.25rem 1.5rem; +} + +.bd-sidebar .nav > li > ul > li > a { + display: block; + padding: 0.25rem 1.5rem; + font-size: 90%; +} + +.bd-sidebar .nav > li > ul > .active > a, +.bd-sidebar .nav > li > ul > .active:hover > a { + font-weight: 400; +} + +dt:target { + background-color: initial; +} + +/* Offsetting anchored elements within the main content to adjust for fixed header + https://github.com/pandas-dev/pandas-sphinx-theme/issues/6 */ +main *:target::before { + display: block; + content: ''; + height: var(--navbar-height); + margin-top: calc(-1 * var(--navbar-height)); +} + +body { + width: 100%; +} + +/* adjust toc font sizes to improve overview */ +.toc-h2 { + font-size: 0.85rem; +} + +.toc-h3 { + font-size: 0.75rem; +} + +.toc-h4 { + font-size: 0.65rem; +} + +.toc-entry > .nav-link.active { + font-weight: 400; + color: #542437; + background-color: transparent; + border-left: 2px solid #563d7c; +} + +.nav-link:hover { + border-style: none; +} + +/* Collapsing of the TOC sidebar while scrolling */ + +/* Nav: hide second level (shown on .active) */ +.bd-toc .nav .nav { + display: none; +} + +.bd-toc .nav > .active > ul { + display: block; +} + +/* Main index page overview cards */ + +.sd-card-img-top { + width: 33% !important; + display: block; + margin-left: auto; + margin-right: auto; + margin-top: 10px; +} diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/ERA5-GRIB-example.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/ERA5-GRIB-example.png new file mode 100644 index 0000000..412dd28 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/ERA5-GRIB-example.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/ROMS_ocean_model.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/ROMS_ocean_model.png new file mode 100644 index 0000000..9333335 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/ROMS_ocean_model.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/area_weighted_temperature.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/area_weighted_temperature.png new file mode 100644 index 0000000..7d3604d Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/area_weighted_temperature.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/monthly-means.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/monthly-means.png new file mode 100644 index 0000000..da56918 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/monthly-means.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/multidimensional-coords.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/multidimensional-coords.png new file mode 100644 index 0000000..b0d893d Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/multidimensional-coords.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/toy-weather-data.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/toy-weather-data.png new file mode 100644 index 0000000..64ac0a4 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/toy-weather-data.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/visualization_gallery.png b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/visualization_gallery.png new file mode 100644 index 0000000..9e6c243 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/thumbnails/visualization_gallery.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_static/view-docs.png b/test/fixtures/whole_applications/xarray/doc/_static/view-docs.png new file mode 100644 index 0000000..2e79ff6 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/_static/view-docs.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor.rst b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor.rst new file mode 100644 index 0000000..4ba745c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessor:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_attribute.rst b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_attribute.rst new file mode 100644 index 0000000..b5ad65d --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_attribute.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorattribute:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_callable.rst b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_callable.rst new file mode 100644 index 0000000..7a33018 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_callable.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorcallable:: {{ (module.split('.')[1:] + [objname]) | join('.') }}.__call__ diff --git a/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_method.rst b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_method.rst new file mode 100644 index 0000000..aefbba6 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/_templates/autosummary/accessor_method.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessormethod:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/test/fixtures/whole_applications/xarray/doc/api-hidden.rst b/test/fixtures/whole_applications/xarray/doc/api-hidden.rst new file mode 100644 index 0000000..d9c8964 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/api-hidden.rst @@ -0,0 +1,695 @@ +.. Generate API reference pages, but don't display these in tables. +.. This extra page is a work around for sphinx not having any support for +.. hiding an autosummary table. + +:orphan: + +.. currentmodule:: xarray + +.. autosummary:: + :toctree: generated/ + + Coordinates.from_pandas_multiindex + Coordinates.get + Coordinates.items + Coordinates.keys + Coordinates.values + Coordinates.dims + Coordinates.dtypes + Coordinates.variables + Coordinates.xindexes + Coordinates.indexes + Coordinates.to_dataset + Coordinates.to_index + Coordinates.update + Coordinates.assign + Coordinates.merge + Coordinates.copy + Coordinates.equals + Coordinates.identical + + core.coordinates.DatasetCoordinates.get + core.coordinates.DatasetCoordinates.items + core.coordinates.DatasetCoordinates.keys + core.coordinates.DatasetCoordinates.values + core.coordinates.DatasetCoordinates.dims + core.coordinates.DatasetCoordinates.dtypes + core.coordinates.DatasetCoordinates.variables + core.coordinates.DatasetCoordinates.xindexes + core.coordinates.DatasetCoordinates.indexes + core.coordinates.DatasetCoordinates.to_dataset + core.coordinates.DatasetCoordinates.to_index + core.coordinates.DatasetCoordinates.update + core.coordinates.DatasetCoordinates.assign + core.coordinates.DatasetCoordinates.merge + core.coordinates.DatasetCoordinates.copy + core.coordinates.DatasetCoordinates.equals + core.coordinates.DatasetCoordinates.identical + + core.rolling.DatasetCoarsen.boundary + core.rolling.DatasetCoarsen.coord_func + core.rolling.DatasetCoarsen.obj + core.rolling.DatasetCoarsen.side + core.rolling.DatasetCoarsen.trim_excess + core.rolling.DatasetCoarsen.windows + + core.rolling.DatasetRolling.center + core.rolling.DatasetRolling.dim + core.rolling.DatasetRolling.min_periods + core.rolling.DatasetRolling.obj + core.rolling.DatasetRolling.rollings + core.rolling.DatasetRolling.window + + core.weighted.DatasetWeighted.obj + core.weighted.DatasetWeighted.weights + + Dataset.load_store + Dataset.dump_to_store + + DataArray.astype + DataArray.item + + core.coordinates.DataArrayCoordinates.get + core.coordinates.DataArrayCoordinates.items + core.coordinates.DataArrayCoordinates.keys + core.coordinates.DataArrayCoordinates.values + core.coordinates.DataArrayCoordinates.dims + core.coordinates.DataArrayCoordinates.dtypes + core.coordinates.DataArrayCoordinates.variables + core.coordinates.DataArrayCoordinates.xindexes + core.coordinates.DataArrayCoordinates.indexes + core.coordinates.DataArrayCoordinates.to_dataset + core.coordinates.DataArrayCoordinates.to_index + core.coordinates.DataArrayCoordinates.update + core.coordinates.DataArrayCoordinates.assign + core.coordinates.DataArrayCoordinates.merge + core.coordinates.DataArrayCoordinates.copy + core.coordinates.DataArrayCoordinates.equals + core.coordinates.DataArrayCoordinates.identical + + core.rolling.DataArrayCoarsen.boundary + core.rolling.DataArrayCoarsen.coord_func + core.rolling.DataArrayCoarsen.obj + core.rolling.DataArrayCoarsen.side + core.rolling.DataArrayCoarsen.trim_excess + core.rolling.DataArrayCoarsen.windows + + core.rolling.DataArrayRolling.center + core.rolling.DataArrayRolling.dim + core.rolling.DataArrayRolling.min_periods + core.rolling.DataArrayRolling.obj + core.rolling.DataArrayRolling.window + core.rolling.DataArrayRolling.window_labels + + core.weighted.DataArrayWeighted.obj + core.weighted.DataArrayWeighted.weights + + core.accessor_dt.DatetimeAccessor.ceil + core.accessor_dt.DatetimeAccessor.floor + core.accessor_dt.DatetimeAccessor.round + core.accessor_dt.DatetimeAccessor.strftime + core.accessor_dt.DatetimeAccessor.calendar + core.accessor_dt.DatetimeAccessor.date + core.accessor_dt.DatetimeAccessor.day + core.accessor_dt.DatetimeAccessor.dayofweek + core.accessor_dt.DatetimeAccessor.dayofyear + core.accessor_dt.DatetimeAccessor.days_in_month + core.accessor_dt.DatetimeAccessor.daysinmonth + core.accessor_dt.DatetimeAccessor.hour + core.accessor_dt.DatetimeAccessor.is_leap_year + core.accessor_dt.DatetimeAccessor.is_month_end + core.accessor_dt.DatetimeAccessor.is_month_start + core.accessor_dt.DatetimeAccessor.is_quarter_end + core.accessor_dt.DatetimeAccessor.is_quarter_start + core.accessor_dt.DatetimeAccessor.is_year_end + core.accessor_dt.DatetimeAccessor.is_year_start + core.accessor_dt.DatetimeAccessor.isocalendar + core.accessor_dt.DatetimeAccessor.microsecond + core.accessor_dt.DatetimeAccessor.minute + core.accessor_dt.DatetimeAccessor.month + core.accessor_dt.DatetimeAccessor.nanosecond + core.accessor_dt.DatetimeAccessor.quarter + core.accessor_dt.DatetimeAccessor.season + core.accessor_dt.DatetimeAccessor.second + core.accessor_dt.DatetimeAccessor.time + core.accessor_dt.DatetimeAccessor.week + core.accessor_dt.DatetimeAccessor.weekday + core.accessor_dt.DatetimeAccessor.weekofyear + core.accessor_dt.DatetimeAccessor.year + + core.accessor_dt.TimedeltaAccessor.ceil + core.accessor_dt.TimedeltaAccessor.floor + core.accessor_dt.TimedeltaAccessor.round + core.accessor_dt.TimedeltaAccessor.days + core.accessor_dt.TimedeltaAccessor.microseconds + core.accessor_dt.TimedeltaAccessor.nanoseconds + core.accessor_dt.TimedeltaAccessor.seconds + + core.accessor_str.StringAccessor.capitalize + core.accessor_str.StringAccessor.casefold + core.accessor_str.StringAccessor.cat + core.accessor_str.StringAccessor.center + core.accessor_str.StringAccessor.contains + core.accessor_str.StringAccessor.count + core.accessor_str.StringAccessor.decode + core.accessor_str.StringAccessor.encode + core.accessor_str.StringAccessor.endswith + core.accessor_str.StringAccessor.extract + core.accessor_str.StringAccessor.extractall + core.accessor_str.StringAccessor.find + core.accessor_str.StringAccessor.findall + core.accessor_str.StringAccessor.format + core.accessor_str.StringAccessor.get + core.accessor_str.StringAccessor.get_dummies + core.accessor_str.StringAccessor.index + core.accessor_str.StringAccessor.isalnum + core.accessor_str.StringAccessor.isalpha + core.accessor_str.StringAccessor.isdecimal + core.accessor_str.StringAccessor.isdigit + core.accessor_str.StringAccessor.islower + core.accessor_str.StringAccessor.isnumeric + core.accessor_str.StringAccessor.isspace + core.accessor_str.StringAccessor.istitle + core.accessor_str.StringAccessor.isupper + core.accessor_str.StringAccessor.join + core.accessor_str.StringAccessor.len + core.accessor_str.StringAccessor.ljust + core.accessor_str.StringAccessor.lower + core.accessor_str.StringAccessor.lstrip + core.accessor_str.StringAccessor.match + core.accessor_str.StringAccessor.normalize + core.accessor_str.StringAccessor.pad + core.accessor_str.StringAccessor.partition + core.accessor_str.StringAccessor.repeat + core.accessor_str.StringAccessor.replace + core.accessor_str.StringAccessor.rfind + core.accessor_str.StringAccessor.rindex + core.accessor_str.StringAccessor.rjust + core.accessor_str.StringAccessor.rpartition + core.accessor_str.StringAccessor.rsplit + core.accessor_str.StringAccessor.rstrip + core.accessor_str.StringAccessor.slice + core.accessor_str.StringAccessor.slice_replace + core.accessor_str.StringAccessor.split + core.accessor_str.StringAccessor.startswith + core.accessor_str.StringAccessor.strip + core.accessor_str.StringAccessor.swapcase + core.accessor_str.StringAccessor.title + core.accessor_str.StringAccessor.translate + core.accessor_str.StringAccessor.upper + core.accessor_str.StringAccessor.wrap + core.accessor_str.StringAccessor.zfill + + Variable.all + Variable.any + Variable.argmax + Variable.argmin + Variable.argsort + Variable.astype + Variable.broadcast_equals + Variable.chunk + Variable.clip + Variable.coarsen + Variable.compute + Variable.concat + Variable.conj + Variable.conjugate + Variable.copy + Variable.count + Variable.cumprod + Variable.cumsum + Variable.equals + Variable.fillna + Variable.get_axis_num + Variable.identical + Variable.isel + Variable.isnull + Variable.item + Variable.load + Variable.max + Variable.mean + Variable.median + Variable.min + Variable.no_conflicts + Variable.notnull + Variable.pad + Variable.prod + Variable.quantile + Variable.rank + Variable.reduce + Variable.roll + Variable.rolling_window + Variable.round + Variable.searchsorted + Variable.set_dims + Variable.shift + Variable.squeeze + Variable.stack + Variable.std + Variable.sum + Variable.to_base_variable + Variable.to_coord + Variable.to_dict + Variable.to_index + Variable.to_index_variable + Variable.to_variable + Variable.transpose + Variable.unstack + Variable.var + Variable.where + Variable.T + Variable.attrs + Variable.chunks + Variable.data + Variable.dims + Variable.dtype + Variable.encoding + Variable.drop_encoding + Variable.imag + Variable.nbytes + Variable.ndim + Variable.real + Variable.shape + Variable.size + Variable.sizes + Variable.values + + IndexVariable.all + IndexVariable.any + IndexVariable.argmax + IndexVariable.argmin + IndexVariable.argsort + IndexVariable.astype + IndexVariable.broadcast_equals + IndexVariable.chunk + IndexVariable.clip + IndexVariable.coarsen + IndexVariable.compute + IndexVariable.concat + IndexVariable.conj + IndexVariable.conjugate + IndexVariable.copy + IndexVariable.count + IndexVariable.cumprod + IndexVariable.cumsum + IndexVariable.equals + IndexVariable.fillna + IndexVariable.get_axis_num + IndexVariable.get_level_variable + IndexVariable.identical + IndexVariable.isel + IndexVariable.isnull + IndexVariable.item + IndexVariable.load + IndexVariable.max + IndexVariable.mean + IndexVariable.median + IndexVariable.min + IndexVariable.no_conflicts + IndexVariable.notnull + IndexVariable.pad + IndexVariable.prod + IndexVariable.quantile + IndexVariable.rank + IndexVariable.reduce + IndexVariable.roll + IndexVariable.rolling_window + IndexVariable.round + IndexVariable.searchsorted + IndexVariable.set_dims + IndexVariable.shift + IndexVariable.squeeze + IndexVariable.stack + IndexVariable.std + IndexVariable.sum + IndexVariable.to_base_variable + IndexVariable.to_coord + IndexVariable.to_dict + IndexVariable.to_index + IndexVariable.to_index_variable + IndexVariable.to_variable + IndexVariable.transpose + IndexVariable.unstack + IndexVariable.var + IndexVariable.where + IndexVariable.T + IndexVariable.attrs + IndexVariable.chunks + IndexVariable.data + IndexVariable.dims + IndexVariable.dtype + IndexVariable.encoding + IndexVariable.imag + IndexVariable.level_names + IndexVariable.name + IndexVariable.nbytes + IndexVariable.ndim + IndexVariable.real + IndexVariable.shape + IndexVariable.size + IndexVariable.sizes + IndexVariable.values + + + NamedArray.all + NamedArray.any + NamedArray.attrs + NamedArray.broadcast_to + NamedArray.chunks + NamedArray.chunksizes + NamedArray.copy + NamedArray.count + NamedArray.cumprod + NamedArray.cumsum + NamedArray.data + NamedArray.dims + NamedArray.dtype + NamedArray.expand_dims + NamedArray.get_axis_num + NamedArray.max + NamedArray.mean + NamedArray.median + NamedArray.min + NamedArray.nbytes + NamedArray.ndim + NamedArray.prod + NamedArray.reduce + NamedArray.shape + NamedArray.size + NamedArray.sizes + NamedArray.std + NamedArray.sum + NamedArray.var + + + plot.plot + plot.line + plot.step + plot.hist + plot.contour + plot.contourf + plot.imshow + plot.pcolormesh + plot.scatter + plot.surface + + CFTimeIndex.all + CFTimeIndex.any + CFTimeIndex.append + CFTimeIndex.argsort + CFTimeIndex.argmax + CFTimeIndex.argmin + CFTimeIndex.asof + CFTimeIndex.asof_locs + CFTimeIndex.astype + CFTimeIndex.calendar + CFTimeIndex.ceil + CFTimeIndex.contains + CFTimeIndex.copy + CFTimeIndex.days_in_month + CFTimeIndex.delete + CFTimeIndex.difference + CFTimeIndex.drop + CFTimeIndex.drop_duplicates + CFTimeIndex.droplevel + CFTimeIndex.dropna + CFTimeIndex.duplicated + CFTimeIndex.equals + CFTimeIndex.factorize + CFTimeIndex.fillna + CFTimeIndex.floor + CFTimeIndex.format + CFTimeIndex.get_indexer + CFTimeIndex.get_indexer_for + CFTimeIndex.get_indexer_non_unique + CFTimeIndex.get_level_values + CFTimeIndex.get_loc + CFTimeIndex.get_slice_bound + CFTimeIndex.get_value + CFTimeIndex.groupby + CFTimeIndex.holds_integer + CFTimeIndex.identical + CFTimeIndex.insert + CFTimeIndex.intersection + CFTimeIndex.is_ + CFTimeIndex.is_boolean + CFTimeIndex.is_categorical + CFTimeIndex.is_floating + CFTimeIndex.is_integer + CFTimeIndex.is_interval + CFTimeIndex.is_numeric + CFTimeIndex.is_object + CFTimeIndex.isin + CFTimeIndex.isna + CFTimeIndex.isnull + CFTimeIndex.item + CFTimeIndex.join + CFTimeIndex.map + CFTimeIndex.max + CFTimeIndex.memory_usage + CFTimeIndex.min + CFTimeIndex.notna + CFTimeIndex.notnull + CFTimeIndex.nunique + CFTimeIndex.putmask + CFTimeIndex.ravel + CFTimeIndex.reindex + CFTimeIndex.rename + CFTimeIndex.repeat + CFTimeIndex.round + CFTimeIndex.searchsorted + CFTimeIndex.set_names + CFTimeIndex.shift + CFTimeIndex.slice_indexer + CFTimeIndex.slice_locs + CFTimeIndex.sort + CFTimeIndex.sort_values + CFTimeIndex.sortlevel + CFTimeIndex.strftime + CFTimeIndex.symmetric_difference + CFTimeIndex.take + CFTimeIndex.to_datetimeindex + CFTimeIndex.to_flat_index + CFTimeIndex.to_frame + CFTimeIndex.to_list + CFTimeIndex.to_numpy + CFTimeIndex.to_series + CFTimeIndex.tolist + CFTimeIndex.transpose + CFTimeIndex.union + CFTimeIndex.unique + CFTimeIndex.value_counts + CFTimeIndex.view + CFTimeIndex.where + + CFTimeIndex.T + CFTimeIndex.array + CFTimeIndex.asi8 + CFTimeIndex.date_type + CFTimeIndex.day + CFTimeIndex.dayofweek + CFTimeIndex.dayofyear + CFTimeIndex.dtype + CFTimeIndex.empty + CFTimeIndex.freq + CFTimeIndex.has_duplicates + CFTimeIndex.hasnans + CFTimeIndex.hour + CFTimeIndex.inferred_type + CFTimeIndex.is_monotonic_increasing + CFTimeIndex.is_monotonic_decreasing + CFTimeIndex.is_unique + CFTimeIndex.microsecond + CFTimeIndex.minute + CFTimeIndex.month + CFTimeIndex.name + CFTimeIndex.names + CFTimeIndex.nbytes + CFTimeIndex.ndim + CFTimeIndex.nlevels + CFTimeIndex.second + CFTimeIndex.shape + CFTimeIndex.size + CFTimeIndex.values + CFTimeIndex.year + + Index.from_variables + Index.concat + Index.stack + Index.unstack + Index.create_variables + Index.to_pandas_index + Index.isel + Index.sel + Index.join + Index.reindex_like + Index.equals + Index.roll + Index.rename + Index.copy + + backends.NetCDF4DataStore.close + backends.NetCDF4DataStore.encode + backends.NetCDF4DataStore.encode_attribute + backends.NetCDF4DataStore.encode_variable + backends.NetCDF4DataStore.get_attrs + backends.NetCDF4DataStore.get_dimensions + backends.NetCDF4DataStore.get_encoding + backends.NetCDF4DataStore.get_variables + backends.NetCDF4DataStore.load + backends.NetCDF4DataStore.open + backends.NetCDF4DataStore.open_store_variable + backends.NetCDF4DataStore.prepare_variable + backends.NetCDF4DataStore.set_attribute + backends.NetCDF4DataStore.set_attributes + backends.NetCDF4DataStore.set_dimension + backends.NetCDF4DataStore.set_dimensions + backends.NetCDF4DataStore.set_variable + backends.NetCDF4DataStore.set_variables + backends.NetCDF4DataStore.store + backends.NetCDF4DataStore.store_dataset + backends.NetCDF4DataStore.sync + backends.NetCDF4DataStore.autoclose + backends.NetCDF4DataStore.ds + backends.NetCDF4DataStore.format + backends.NetCDF4DataStore.is_remote + backends.NetCDF4DataStore.lock + + backends.NetCDF4BackendEntrypoint.description + backends.NetCDF4BackendEntrypoint.url + backends.NetCDF4BackendEntrypoint.guess_can_open + backends.NetCDF4BackendEntrypoint.open_dataset + + backends.H5NetCDFStore.autoclose + backends.H5NetCDFStore.close + backends.H5NetCDFStore.encode + backends.H5NetCDFStore.encode_attribute + backends.H5NetCDFStore.encode_variable + backends.H5NetCDFStore.format + backends.H5NetCDFStore.get_attrs + backends.H5NetCDFStore.get_dimensions + backends.H5NetCDFStore.get_encoding + backends.H5NetCDFStore.get_variables + backends.H5NetCDFStore.is_remote + backends.H5NetCDFStore.load + backends.H5NetCDFStore.lock + backends.H5NetCDFStore.open + backends.H5NetCDFStore.open_store_variable + backends.H5NetCDFStore.prepare_variable + backends.H5NetCDFStore.set_attribute + backends.H5NetCDFStore.set_attributes + backends.H5NetCDFStore.set_dimension + backends.H5NetCDFStore.set_dimensions + backends.H5NetCDFStore.set_variable + backends.H5NetCDFStore.set_variables + backends.H5NetCDFStore.store + backends.H5NetCDFStore.store_dataset + backends.H5NetCDFStore.sync + backends.H5NetCDFStore.ds + + backends.H5netcdfBackendEntrypoint.description + backends.H5netcdfBackendEntrypoint.url + backends.H5netcdfBackendEntrypoint.guess_can_open + backends.H5netcdfBackendEntrypoint.open_dataset + + backends.PydapDataStore.close + backends.PydapDataStore.get_attrs + backends.PydapDataStore.get_dimensions + backends.PydapDataStore.get_encoding + backends.PydapDataStore.get_variables + backends.PydapDataStore.load + backends.PydapDataStore.open + backends.PydapDataStore.open_store_variable + + backends.PydapBackendEntrypoint.description + backends.PydapBackendEntrypoint.url + backends.PydapBackendEntrypoint.guess_can_open + backends.PydapBackendEntrypoint.open_dataset + + backends.ScipyDataStore.close + backends.ScipyDataStore.encode + backends.ScipyDataStore.encode_attribute + backends.ScipyDataStore.encode_variable + backends.ScipyDataStore.get_attrs + backends.ScipyDataStore.get_dimensions + backends.ScipyDataStore.get_encoding + backends.ScipyDataStore.get_variables + backends.ScipyDataStore.load + backends.ScipyDataStore.open_store_variable + backends.ScipyDataStore.prepare_variable + backends.ScipyDataStore.set_attribute + backends.ScipyDataStore.set_attributes + backends.ScipyDataStore.set_dimension + backends.ScipyDataStore.set_dimensions + backends.ScipyDataStore.set_variable + backends.ScipyDataStore.set_variables + backends.ScipyDataStore.store + backends.ScipyDataStore.store_dataset + backends.ScipyDataStore.sync + backends.ScipyDataStore.ds + + backends.ScipyBackendEntrypoint.description + backends.ScipyBackendEntrypoint.url + backends.ScipyBackendEntrypoint.guess_can_open + backends.ScipyBackendEntrypoint.open_dataset + + backends.ZarrStore.close + backends.ZarrStore.encode_attribute + backends.ZarrStore.encode_variable + backends.ZarrStore.get_attrs + backends.ZarrStore.get_dimensions + backends.ZarrStore.get_variables + backends.ZarrStore.open_group + backends.ZarrStore.open_store_variable + backends.ZarrStore.set_attributes + backends.ZarrStore.set_dimensions + backends.ZarrStore.set_variables + backends.ZarrStore.store + backends.ZarrStore.sync + backends.ZarrStore.ds + + backends.ZarrBackendEntrypoint.description + backends.ZarrBackendEntrypoint.url + backends.ZarrBackendEntrypoint.guess_can_open + backends.ZarrBackendEntrypoint.open_dataset + + backends.StoreBackendEntrypoint.description + backends.StoreBackendEntrypoint.url + backends.StoreBackendEntrypoint.guess_can_open + backends.StoreBackendEntrypoint.open_dataset + + backends.FileManager.acquire + backends.FileManager.acquire_context + backends.FileManager.close + + backends.CachingFileManager.acquire + backends.CachingFileManager.acquire_context + backends.CachingFileManager.close + + backends.DummyFileManager.acquire + backends.DummyFileManager.acquire_context + backends.DummyFileManager.close + + backends.BackendArray + backends.BackendEntrypoint.guess_can_open + backends.BackendEntrypoint.open_dataset + + core.indexing.IndexingSupport + core.indexing.explicit_indexing_adapter + core.indexing.BasicIndexer + core.indexing.OuterIndexer + core.indexing.VectorizedIndexer + core.indexing.LazilyIndexedArray + core.indexing.LazilyVectorizedIndexedArray + + conventions.decode_cf_variables + + coding.variables.UnsignedIntegerCoder + coding.variables.CFMaskCoder + coding.variables.CFScaleOffsetCoder + + coding.strings.CharacterArrayCoder + coding.strings.EncodedStringCoder + + coding.times.CFTimedeltaCoder + coding.times.CFDatetimeCoder diff --git a/test/fixtures/whole_applications/xarray/doc/api.rst b/test/fixtures/whole_applications/xarray/doc/api.rst new file mode 100644 index 0000000..a8f8ea7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/api.rst @@ -0,0 +1,1178 @@ +.. currentmodule:: xarray + +.. _api: + +############# +API reference +############# + +This page provides an auto-generated summary of xarray's API. For more details +and examples, refer to the relevant chapters in the main part of the +documentation. + +See also: :ref:`public api` + +Top-level functions +=================== + +.. autosummary:: + :toctree: generated/ + + apply_ufunc + align + broadcast + concat + merge + combine_by_coords + combine_nested + where + infer_freq + full_like + zeros_like + ones_like + cov + corr + cross + dot + polyval + map_blocks + show_versions + set_options + get_options + unify_chunks + +Dataset +======= + +Creating a dataset +------------------ + +.. autosummary:: + :toctree: generated/ + + Dataset + decode_cf + +Attributes +---------- + +.. autosummary:: + :toctree: generated/ + + Dataset.dims + Dataset.sizes + Dataset.dtypes + Dataset.data_vars + Dataset.coords + Dataset.attrs + Dataset.encoding + Dataset.indexes + Dataset.chunks + Dataset.chunksizes + Dataset.nbytes + +Dictionary interface +-------------------- + +Datasets implement the mapping interface with keys given by variable names +and values given by ``DataArray`` objects. + +.. autosummary:: + :toctree: generated/ + + Dataset.__getitem__ + Dataset.__setitem__ + Dataset.__delitem__ + Dataset.update + Dataset.get + Dataset.items + Dataset.keys + Dataset.values + +Dataset contents +---------------- + +.. autosummary:: + :toctree: generated/ + + Dataset.copy + Dataset.assign + Dataset.assign_coords + Dataset.assign_attrs + Dataset.pipe + Dataset.merge + Dataset.rename + Dataset.rename_vars + Dataset.rename_dims + Dataset.swap_dims + Dataset.expand_dims + Dataset.drop_vars + Dataset.drop_indexes + Dataset.drop_duplicates + Dataset.drop_dims + Dataset.drop_encoding + Dataset.set_coords + Dataset.reset_coords + Dataset.convert_calendar + Dataset.interp_calendar + Dataset.get_index + +Comparisons +----------- + +.. autosummary:: + :toctree: generated/ + + Dataset.equals + Dataset.identical + Dataset.broadcast_equals + +Indexing +-------- + +.. autosummary:: + :toctree: generated/ + + Dataset.loc + Dataset.isel + Dataset.sel + Dataset.drop_sel + Dataset.drop_isel + Dataset.head + Dataset.tail + Dataset.thin + Dataset.squeeze + Dataset.interp + Dataset.interp_like + Dataset.reindex + Dataset.reindex_like + Dataset.set_index + Dataset.reset_index + Dataset.set_xindex + Dataset.reorder_levels + Dataset.query + +Missing value handling +---------------------- + +.. autosummary:: + :toctree: generated/ + + Dataset.isnull + Dataset.notnull + Dataset.combine_first + Dataset.count + Dataset.dropna + Dataset.fillna + Dataset.ffill + Dataset.bfill + Dataset.interpolate_na + Dataset.where + Dataset.isin + +Computation +----------- + +.. autosummary:: + :toctree: generated/ + + Dataset.map + Dataset.reduce + Dataset.groupby + Dataset.groupby_bins + Dataset.rolling + Dataset.rolling_exp + Dataset.cumulative + Dataset.weighted + Dataset.coarsen + Dataset.resample + Dataset.diff + Dataset.quantile + Dataset.differentiate + Dataset.integrate + Dataset.map_blocks + Dataset.polyfit + Dataset.curvefit + Dataset.eval + +Aggregation +----------- + +.. autosummary:: + :toctree: generated/ + + Dataset.all + Dataset.any + Dataset.argmax + Dataset.argmin + Dataset.count + Dataset.idxmax + Dataset.idxmin + Dataset.max + Dataset.min + Dataset.mean + Dataset.median + Dataset.prod + Dataset.sum + Dataset.std + Dataset.var + Dataset.cumsum + Dataset.cumprod + +ndarray methods +--------------- + +.. autosummary:: + :toctree: generated/ + + Dataset.argsort + Dataset.astype + Dataset.clip + Dataset.conj + Dataset.conjugate + Dataset.imag + Dataset.round + Dataset.real + Dataset.rank + +Reshaping and reorganizing +-------------------------- + +.. autosummary:: + :toctree: generated/ + + Dataset.transpose + Dataset.stack + Dataset.unstack + Dataset.to_stacked_array + Dataset.shift + Dataset.roll + Dataset.pad + Dataset.sortby + Dataset.broadcast_like + +DataArray +========= + +.. autosummary:: + :toctree: generated/ + + DataArray + +Attributes +---------- + +.. autosummary:: + :toctree: generated/ + + DataArray.values + DataArray.data + DataArray.coords + DataArray.dims + DataArray.sizes + DataArray.name + DataArray.attrs + DataArray.encoding + DataArray.indexes + DataArray.chunksizes + +ndarray attributes +------------------ + +.. autosummary:: + :toctree: generated/ + + DataArray.ndim + DataArray.nbytes + DataArray.shape + DataArray.size + DataArray.dtype + DataArray.chunks + + +DataArray contents +------------------ + +.. autosummary:: + :toctree: generated/ + + DataArray.assign_coords + DataArray.assign_attrs + DataArray.pipe + DataArray.rename + DataArray.swap_dims + DataArray.expand_dims + DataArray.drop_vars + DataArray.drop_indexes + DataArray.drop_duplicates + DataArray.drop_encoding + DataArray.reset_coords + DataArray.copy + DataArray.convert_calendar + DataArray.interp_calendar + DataArray.get_index + DataArray.astype + DataArray.item + +Indexing +-------- + +.. autosummary:: + :toctree: generated/ + + DataArray.__getitem__ + DataArray.__setitem__ + DataArray.loc + DataArray.isel + DataArray.sel + DataArray.drop_sel + DataArray.drop_isel + DataArray.head + DataArray.tail + DataArray.thin + DataArray.squeeze + DataArray.interp + DataArray.interp_like + DataArray.reindex + DataArray.reindex_like + DataArray.set_index + DataArray.reset_index + DataArray.set_xindex + DataArray.reorder_levels + DataArray.query + +Missing value handling +---------------------- + +.. autosummary:: + :toctree: generated/ + + DataArray.isnull + DataArray.notnull + DataArray.combine_first + DataArray.count + DataArray.dropna + DataArray.fillna + DataArray.ffill + DataArray.bfill + DataArray.interpolate_na + DataArray.where + DataArray.isin + +Comparisons +----------- + +.. autosummary:: + :toctree: generated/ + + DataArray.equals + DataArray.identical + DataArray.broadcast_equals + +Computation +----------- + +.. autosummary:: + :toctree: generated/ + + DataArray.reduce + DataArray.groupby + DataArray.groupby_bins + DataArray.rolling + DataArray.rolling_exp + DataArray.cumulative + DataArray.weighted + DataArray.coarsen + DataArray.resample + DataArray.get_axis_num + DataArray.diff + DataArray.dot + DataArray.quantile + DataArray.differentiate + DataArray.integrate + DataArray.polyfit + DataArray.map_blocks + DataArray.curvefit + +Aggregation +----------- + +.. autosummary:: + :toctree: generated/ + + DataArray.all + DataArray.any + DataArray.argmax + DataArray.argmin + DataArray.count + DataArray.idxmax + DataArray.idxmin + DataArray.max + DataArray.min + DataArray.mean + DataArray.median + DataArray.prod + DataArray.sum + DataArray.std + DataArray.var + DataArray.cumsum + DataArray.cumprod + +ndarray methods +--------------- + +.. autosummary:: + :toctree: generated/ + + DataArray.argsort + DataArray.clip + DataArray.conj + DataArray.conjugate + DataArray.imag + DataArray.searchsorted + DataArray.round + DataArray.real + DataArray.T + DataArray.rank + + +String manipulation +------------------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor.rst + + DataArray.str + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.str.capitalize + DataArray.str.casefold + DataArray.str.cat + DataArray.str.center + DataArray.str.contains + DataArray.str.count + DataArray.str.decode + DataArray.str.encode + DataArray.str.endswith + DataArray.str.extract + DataArray.str.extractall + DataArray.str.find + DataArray.str.findall + DataArray.str.format + DataArray.str.get + DataArray.str.get_dummies + DataArray.str.index + DataArray.str.isalnum + DataArray.str.isalpha + DataArray.str.isdecimal + DataArray.str.isdigit + DataArray.str.islower + DataArray.str.isnumeric + DataArray.str.isspace + DataArray.str.istitle + DataArray.str.isupper + DataArray.str.join + DataArray.str.len + DataArray.str.ljust + DataArray.str.lower + DataArray.str.lstrip + DataArray.str.match + DataArray.str.normalize + DataArray.str.pad + DataArray.str.partition + DataArray.str.repeat + DataArray.str.replace + DataArray.str.rfind + DataArray.str.rindex + DataArray.str.rjust + DataArray.str.rpartition + DataArray.str.rsplit + DataArray.str.rstrip + DataArray.str.slice + DataArray.str.slice_replace + DataArray.str.split + DataArray.str.startswith + DataArray.str.strip + DataArray.str.swapcase + DataArray.str.title + DataArray.str.translate + DataArray.str.upper + DataArray.str.wrap + DataArray.str.zfill + +Datetimelike properties +----------------------- + +**Datetime properties**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.year + DataArray.dt.month + DataArray.dt.day + DataArray.dt.hour + DataArray.dt.minute + DataArray.dt.second + DataArray.dt.microsecond + DataArray.dt.nanosecond + DataArray.dt.dayofweek + DataArray.dt.weekday + DataArray.dt.dayofyear + DataArray.dt.quarter + DataArray.dt.days_in_month + DataArray.dt.daysinmonth + DataArray.dt.season + DataArray.dt.time + DataArray.dt.date + DataArray.dt.calendar + DataArray.dt.is_month_start + DataArray.dt.is_month_end + DataArray.dt.is_quarter_end + DataArray.dt.is_year_start + DataArray.dt.is_leap_year + +**Datetime methods**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.dt.floor + DataArray.dt.ceil + DataArray.dt.isocalendar + DataArray.dt.round + DataArray.dt.strftime + +**Timedelta properties**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.days + DataArray.dt.seconds + DataArray.dt.microseconds + DataArray.dt.nanoseconds + DataArray.dt.total_seconds + +**Timedelta methods**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.dt.floor + DataArray.dt.ceil + DataArray.dt.round + + +Reshaping and reorganizing +-------------------------- + +.. autosummary:: + :toctree: generated/ + + DataArray.transpose + DataArray.stack + DataArray.unstack + DataArray.to_unstacked_dataset + DataArray.shift + DataArray.roll + DataArray.pad + DataArray.sortby + DataArray.broadcast_like + +IO / Conversion +=============== + +Dataset methods +--------------- + +.. autosummary:: + :toctree: generated/ + + load_dataset + open_dataset + open_mfdataset + open_zarr + save_mfdataset + Dataset.as_numpy + Dataset.from_dataframe + Dataset.from_dict + Dataset.to_dataarray + Dataset.to_dataframe + Dataset.to_dask_dataframe + Dataset.to_dict + Dataset.to_netcdf + Dataset.to_pandas + Dataset.to_zarr + Dataset.chunk + Dataset.close + Dataset.compute + Dataset.filter_by_attrs + Dataset.info + Dataset.load + Dataset.persist + Dataset.unify_chunks + +DataArray methods +----------------- + +.. autosummary:: + :toctree: generated/ + + load_dataarray + open_dataarray + DataArray.as_numpy + DataArray.from_dict + DataArray.from_iris + DataArray.from_series + DataArray.to_dask_dataframe + DataArray.to_dataframe + DataArray.to_dataset + DataArray.to_dict + DataArray.to_index + DataArray.to_iris + DataArray.to_masked_array + DataArray.to_netcdf + DataArray.to_numpy + DataArray.to_pandas + DataArray.to_series + DataArray.to_zarr + DataArray.chunk + DataArray.close + DataArray.compute + DataArray.persist + DataArray.load + DataArray.unify_chunks + +Coordinates objects +=================== + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + + core.coordinates.DatasetCoordinates + core.coordinates.DatasetCoordinates.dtypes + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + + core.coordinates.DataArrayCoordinates + core.coordinates.DataArrayCoordinates.dtypes + +Plotting +======== + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + Dataset.plot.scatter + Dataset.plot.quiver + Dataset.plot.streamplot + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_callable.rst + + DataArray.plot + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.plot.contourf + DataArray.plot.contour + DataArray.plot.hist + DataArray.plot.imshow + DataArray.plot.line + DataArray.plot.pcolormesh + DataArray.plot.step + DataArray.plot.scatter + DataArray.plot.surface + + +Faceting +-------- +.. autosummary:: + :toctree: generated/ + + plot.FacetGrid + plot.FacetGrid.add_colorbar + plot.FacetGrid.add_legend + plot.FacetGrid.add_quiverkey + plot.FacetGrid.map + plot.FacetGrid.map_dataarray + plot.FacetGrid.map_dataarray_line + plot.FacetGrid.map_dataset + plot.FacetGrid.map_plot1d + plot.FacetGrid.set_axis_labels + plot.FacetGrid.set_ticks + plot.FacetGrid.set_titles + plot.FacetGrid.set_xlabels + plot.FacetGrid.set_ylabels + + + +GroupBy objects +=============== + +.. currentmodule:: xarray.core.groupby + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + + DatasetGroupBy + DatasetGroupBy.map + DatasetGroupBy.reduce + DatasetGroupBy.assign + DatasetGroupBy.assign_coords + DatasetGroupBy.first + DatasetGroupBy.last + DatasetGroupBy.fillna + DatasetGroupBy.quantile + DatasetGroupBy.where + DatasetGroupBy.all + DatasetGroupBy.any + DatasetGroupBy.count + DatasetGroupBy.cumsum + DatasetGroupBy.cumprod + DatasetGroupBy.max + DatasetGroupBy.mean + DatasetGroupBy.median + DatasetGroupBy.min + DatasetGroupBy.prod + DatasetGroupBy.std + DatasetGroupBy.sum + DatasetGroupBy.var + DatasetGroupBy.dims + DatasetGroupBy.groups + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + + DataArrayGroupBy + DataArrayGroupBy.map + DataArrayGroupBy.reduce + DataArrayGroupBy.assign_coords + DataArrayGroupBy.first + DataArrayGroupBy.last + DataArrayGroupBy.fillna + DataArrayGroupBy.quantile + DataArrayGroupBy.where + DataArrayGroupBy.all + DataArrayGroupBy.any + DataArrayGroupBy.count + DataArrayGroupBy.cumsum + DataArrayGroupBy.cumprod + DataArrayGroupBy.max + DataArrayGroupBy.mean + DataArrayGroupBy.median + DataArrayGroupBy.min + DataArrayGroupBy.prod + DataArrayGroupBy.std + DataArrayGroupBy.sum + DataArrayGroupBy.var + DataArrayGroupBy.dims + DataArrayGroupBy.groups + + +Rolling objects +=============== + +.. currentmodule:: xarray.core.rolling + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + + DatasetRolling + DatasetRolling.construct + DatasetRolling.reduce + DatasetRolling.argmax + DatasetRolling.argmin + DatasetRolling.count + DatasetRolling.max + DatasetRolling.mean + DatasetRolling.median + DatasetRolling.min + DatasetRolling.prod + DatasetRolling.std + DatasetRolling.sum + DatasetRolling.var + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + + DataArrayRolling + DataArrayRolling.__iter__ + DataArrayRolling.construct + DataArrayRolling.reduce + DataArrayRolling.argmax + DataArrayRolling.argmin + DataArrayRolling.count + DataArrayRolling.max + DataArrayRolling.mean + DataArrayRolling.median + DataArrayRolling.min + DataArrayRolling.prod + DataArrayRolling.std + DataArrayRolling.sum + DataArrayRolling.var + +Coarsen objects +=============== + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + + DatasetCoarsen + DatasetCoarsen.all + DatasetCoarsen.any + DatasetCoarsen.construct + DatasetCoarsen.count + DatasetCoarsen.max + DatasetCoarsen.mean + DatasetCoarsen.median + DatasetCoarsen.min + DatasetCoarsen.prod + DatasetCoarsen.reduce + DatasetCoarsen.std + DatasetCoarsen.sum + DatasetCoarsen.var + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + + DataArrayCoarsen + DataArrayCoarsen.all + DataArrayCoarsen.any + DataArrayCoarsen.construct + DataArrayCoarsen.count + DataArrayCoarsen.max + DataArrayCoarsen.mean + DataArrayCoarsen.median + DataArrayCoarsen.min + DataArrayCoarsen.prod + DataArrayCoarsen.reduce + DataArrayCoarsen.std + DataArrayCoarsen.sum + DataArrayCoarsen.var + +Exponential rolling objects +=========================== + +.. currentmodule:: xarray.core.rolling_exp + +.. autosummary:: + :toctree: generated/ + + RollingExp + RollingExp.mean + RollingExp.sum + +Weighted objects +================ + +.. currentmodule:: xarray.core.weighted + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + + DatasetWeighted + DatasetWeighted.mean + DatasetWeighted.quantile + DatasetWeighted.sum + DatasetWeighted.std + DatasetWeighted.var + DatasetWeighted.sum_of_weights + DatasetWeighted.sum_of_squares + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + + DataArrayWeighted + DataArrayWeighted.mean + DataArrayWeighted.quantile + DataArrayWeighted.sum + DataArrayWeighted.std + DataArrayWeighted.var + DataArrayWeighted.sum_of_weights + DataArrayWeighted.sum_of_squares + +Resample objects +================ + +.. currentmodule:: xarray.core.resample + +Dataset +------- + +.. autosummary:: + :toctree: generated/ + + DatasetResample + DatasetResample.asfreq + DatasetResample.backfill + DatasetResample.interpolate + DatasetResample.nearest + DatasetResample.pad + DatasetResample.all + DatasetResample.any + DatasetResample.apply + DatasetResample.assign + DatasetResample.assign_coords + DatasetResample.bfill + DatasetResample.count + DatasetResample.ffill + DatasetResample.fillna + DatasetResample.first + DatasetResample.last + DatasetResample.map + DatasetResample.max + DatasetResample.mean + DatasetResample.median + DatasetResample.min + DatasetResample.prod + DatasetResample.quantile + DatasetResample.reduce + DatasetResample.std + DatasetResample.sum + DatasetResample.var + DatasetResample.where + DatasetResample.dims + DatasetResample.groups + + +DataArray +--------- + +.. autosummary:: + :toctree: generated/ + + DataArrayResample + DataArrayResample.asfreq + DataArrayResample.backfill + DataArrayResample.interpolate + DataArrayResample.nearest + DataArrayResample.pad + DataArrayResample.all + DataArrayResample.any + DataArrayResample.apply + DataArrayResample.assign_coords + DataArrayResample.bfill + DataArrayResample.count + DataArrayResample.ffill + DataArrayResample.fillna + DataArrayResample.first + DataArrayResample.last + DataArrayResample.map + DataArrayResample.max + DataArrayResample.mean + DataArrayResample.median + DataArrayResample.min + DataArrayResample.prod + DataArrayResample.quantile + DataArrayResample.reduce + DataArrayResample.std + DataArrayResample.sum + DataArrayResample.var + DataArrayResample.where + DataArrayResample.dims + DataArrayResample.groups + +Accessors +========= + +.. currentmodule:: xarray + +.. autosummary:: + :toctree: generated/ + + core.accessor_dt.DatetimeAccessor + core.accessor_dt.TimedeltaAccessor + core.accessor_str.StringAccessor + +Custom Indexes +============== +.. autosummary:: + :toctree: generated/ + + CFTimeIndex + +Creating custom indexes +----------------------- +.. autosummary:: + :toctree: generated/ + + cftime_range + date_range + date_range_like + +Tutorial +======== + +.. autosummary:: + :toctree: generated/ + + tutorial.open_dataset + tutorial.load_dataset + +Testing +======= + +.. autosummary:: + :toctree: generated/ + + testing.assert_equal + testing.assert_identical + testing.assert_allclose + testing.assert_chunks_equal + +Hypothesis Testing Strategies +============================= + +.. currentmodule:: xarray + +See the :ref:`documentation page on testing ` for a guide on how to use these strategies. + +.. warning:: + These strategies should be considered highly experimental, and liable to change at any time. + +.. autosummary:: + :toctree: generated/ + + testing.strategies.supported_dtypes + testing.strategies.names + testing.strategies.dimension_names + testing.strategies.dimension_sizes + testing.strategies.attrs + testing.strategies.variables + testing.strategies.unique_subset_of + +Exceptions +========== + +.. autosummary:: + :toctree: generated/ + + MergeError + SerializationWarning + +Advanced API +============ + +.. autosummary:: + :toctree: generated/ + + Coordinates + Dataset.variables + DataArray.variable + Variable + IndexVariable + as_variable + Index + IndexSelResult + Context + register_dataset_accessor + register_dataarray_accessor + Dataset.set_close + backends.BackendArray + backends.BackendEntrypoint + backends.list_engines + backends.refresh_engines + +Default, pandas-backed indexes built-in Xarray: + + indexes.PandasIndex + indexes.PandasMultiIndex + +These backends provide a low-level interface for lazily loading data from +external file-formats or protocols, and can be manually invoked to create +arguments for the ``load_store`` and ``dump_to_store`` Dataset methods: + +.. autosummary:: + :toctree: generated/ + + backends.NetCDF4DataStore + backends.H5NetCDFStore + backends.PydapDataStore + backends.ScipyDataStore + backends.ZarrStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager + +These BackendEntrypoints provide a basic interface to the most commonly +used filetypes in the xarray universe. + +.. autosummary:: + :toctree: generated/ + + backends.NetCDF4BackendEntrypoint + backends.H5netcdfBackendEntrypoint + backends.PydapBackendEntrypoint + backends.ScipyBackendEntrypoint + backends.StoreBackendEntrypoint + backends.ZarrBackendEntrypoint + +Deprecated / Pending Deprecation +================================ + +.. autosummary:: + :toctree: generated/ + + Dataset.drop + DataArray.drop + Dataset.apply + core.groupby.DataArrayGroupBy.apply + core.groupby.DatasetGroupBy.apply + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.weekofyear + DataArray.dt.week diff --git a/test/fixtures/whole_applications/xarray/doc/conf.py b/test/fixtures/whole_applications/xarray/doc/conf.py new file mode 100644 index 0000000..80b2444 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/conf.py @@ -0,0 +1,464 @@ +# +# xarray documentation build configuration file, created by +# sphinx-quickstart on Thu Feb 6 18:57:54 2014. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + + +import datetime +import inspect +import os +import pathlib +import subprocess +import sys +from contextlib import suppress +from textwrap import dedent, indent + +import sphinx_autosummary_accessors +import yaml +from sphinx.application import Sphinx +from sphinx.util import logging + +import xarray + +LOGGER = logging.getLogger("conf") + +allowed_failures = set() + +print("python exec:", sys.executable) +print("sys.path:", sys.path) + +if "CONDA_DEFAULT_ENV" in os.environ or "conda" in sys.executable: + print("conda environment:") + subprocess.run([os.environ.get("CONDA_EXE", "conda"), "list"]) +else: + print("pip environment:") + subprocess.run([sys.executable, "-m", "pip", "list"]) + +print(f"xarray: {xarray.__version__}, {xarray.__file__}") + +with suppress(ImportError): + import matplotlib + + matplotlib.use("Agg") + +try: + import cartopy # noqa: F401 +except ImportError: + allowed_failures.update( + [ + "gallery/plot_cartopy_facetgrid.py", + ] + ) + +nbsphinx_allow_errors = False + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "IPython.sphinxext.ipython_directive", + "IPython.sphinxext.ipython_console_highlighting", + "nbsphinx", + "sphinx_autosummary_accessors", + "sphinx.ext.linkcode", + "sphinxext.opengraph", + "sphinx_copybutton", + "sphinxext.rediraffe", + "sphinx_design", + "sphinx_inline_tabs", +] + + +extlinks = { + "issue": ("https://github.com/pydata/xarray/issues/%s", "GH%s"), + "pull": ("https://github.com/pydata/xarray/pull/%s", "PR%s"), +} + +# sphinx-copybutton configurations +copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " +copybutton_prompt_is_regexp = True + +# nbsphinx configurations + +nbsphinx_timeout = 600 +nbsphinx_execute = "always" +nbsphinx_prolog = """ +{% set docname = env.doc2path(env.docname, base=None) %} + +You can run this notebook in a `live session `_ |Binder| or view it `on Github `_. + +.. |Binder| image:: https://mybinder.org/badge.svg + :target: https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/{{ docname }} +""" + +autosummary_generate = True +autodoc_typehints = "none" + +# Napoleon configurations + +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_use_param = False +napoleon_use_rtype = False +napoleon_preprocess_types = True +napoleon_type_aliases = { + # general terms + "sequence": ":term:`sequence`", + "iterable": ":term:`iterable`", + "callable": ":py:func:`callable`", + "dict_like": ":term:`dict-like `", + "dict-like": ":term:`dict-like `", + "path-like": ":term:`path-like `", + "mapping": ":term:`mapping`", + "file-like": ":term:`file-like `", + # special terms + # "same type as caller": "*same type as caller*", # does not work, yet + # "same type as values": "*same type as values*", # does not work, yet + # stdlib type aliases + "MutableMapping": "~collections.abc.MutableMapping", + "sys.stdout": ":obj:`sys.stdout`", + "timedelta": "~datetime.timedelta", + "string": ":class:`string `", + # numpy terms + "array_like": ":term:`array_like`", + "array-like": ":term:`array-like `", + "scalar": ":term:`scalar`", + "array": ":term:`array`", + "hashable": ":term:`hashable `", + # matplotlib terms + "color-like": ":py:func:`color-like `", + "matplotlib colormap name": ":doc:`matplotlib colormap name `", + "matplotlib axes object": ":py:class:`matplotlib axes object `", + "colormap": ":py:class:`colormap `", + # objects without namespace: xarray + "DataArray": "~xarray.DataArray", + "Dataset": "~xarray.Dataset", + "Variable": "~xarray.Variable", + "DatasetGroupBy": "~xarray.core.groupby.DatasetGroupBy", + "DataArrayGroupBy": "~xarray.core.groupby.DataArrayGroupBy", + # objects without namespace: numpy + "ndarray": "~numpy.ndarray", + "MaskedArray": "~numpy.ma.MaskedArray", + "dtype": "~numpy.dtype", + "ComplexWarning": "~numpy.ComplexWarning", + # objects without namespace: pandas + "Index": "~pandas.Index", + "MultiIndex": "~pandas.MultiIndex", + "CategoricalIndex": "~pandas.CategoricalIndex", + "TimedeltaIndex": "~pandas.TimedeltaIndex", + "DatetimeIndex": "~pandas.DatetimeIndex", + "Series": "~pandas.Series", + "DataFrame": "~pandas.DataFrame", + "Categorical": "~pandas.Categorical", + "Path": "~~pathlib.Path", + # objects with abbreviated namespace (from pandas) + "pd.Index": "~pandas.Index", + "pd.NaT": "~pandas.NaT", +} + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] + +# The suffix of source filenames. +# source_suffix = ".rst" + + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "xarray" +copyright = f"2014-{datetime.datetime.now().year}, xarray Developers" + +# The short X.Y version. +version = xarray.__version__.split("+")[0] +# The full version, including alpha/beta/rc tags. +release = xarray.__version__ + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +today_fmt = "%Y-%m-%d" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ["_build", "**.ipynb_checkpoints"] + + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + + +# -- Options for HTML output ---------------------------------------------- +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "sphinx_book_theme" +html_title = "" + +html_context = { + "github_user": "pydata", + "github_repo": "xarray", + "github_version": "main", + "doc_path": "doc", +} + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = dict( + # analytics_id='' this is configured in rtfd.io + # canonical_url="", + repository_url="https://github.com/pydata/xarray", + repository_branch="main", + navigation_with_keys=False, # pydata/pydata-sphinx-theme#1492 + path_to_docs="doc", + use_edit_page_button=True, + use_repository_button=True, + use_issues_button=True, + home_page_in_toc=False, + extra_footer="""

Xarray is a fiscally sponsored project of NumFOCUS, + a nonprofit dedicated to supporting the open-source scientific computing community.
+ Theme by the Executable Book Project

""", + twitter_url="https://twitter.com/xarray_dev", + icon_links=[], # workaround for pydata/pydata-sphinx-theme#1220 + announcement="Xarray's 2024 User Survey is live now. Please take ~5 minutes to fill it out and help us improve Xarray.", +) + + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +html_logo = "_static/logos/Xarray_Logo_RGB_Final.svg" + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +html_favicon = "_static/logos/Xarray_Icon_Final.svg" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] +html_css_files = ["style.css"] + + +# configuration for sphinxext.opengraph +ogp_site_url = "https://docs.xarray.dev/en/latest/" +ogp_image = "https://docs.xarray.dev/en/stable/_static/logos/Xarray_Logo_RGB_Final.png" +ogp_custom_meta_tags = [ + '', + '', + '', +] + +# Redirects for pages that were moved to new locations + +rediraffe_redirects = { + "terminology.rst": "user-guide/terminology.rst", + "data-structures.rst": "user-guide/data-structures.rst", + "indexing.rst": "user-guide/indexing.rst", + "interpolation.rst": "user-guide/interpolation.rst", + "computation.rst": "user-guide/computation.rst", + "groupby.rst": "user-guide/groupby.rst", + "reshaping.rst": "user-guide/reshaping.rst", + "combining.rst": "user-guide/combining.rst", + "time-series.rst": "user-guide/time-series.rst", + "weather-climate.rst": "user-guide/weather-climate.rst", + "pandas.rst": "user-guide/pandas.rst", + "io.rst": "user-guide/io.rst", + "dask.rst": "user-guide/dask.rst", + "plotting.rst": "user-guide/plotting.rst", + "duckarrays.rst": "user-guide/duckarrays.rst", + "related-projects.rst": "ecosystem.rst", + "faq.rst": "getting-started-guide/faq.rst", + "why-xarray.rst": "getting-started-guide/why-xarray.rst", + "installing.rst": "getting-started-guide/installing.rst", + "quick-overview.rst": "getting-started-guide/quick-overview.rst", +} + +# Sometimes the savefig directory doesn't exist and needs to be created +# https://github.com/ipython/ipython/issues/8733 +# becomes obsolete when we can pin ipython>=5.2; see ci/requirements/doc.yml +ipython_savefig_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "_build", "html", "_static" +) +if not os.path.exists(ipython_savefig_dir): + os.makedirs(ipython_savefig_dir) + + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +html_last_updated_fmt = today_fmt + +# Output file base name for HTML help builder. +htmlhelp_basename = "xarraydoc" + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "cftime": ("https://unidata.github.io/cftime", None), + "cubed": ("https://cubed-dev.github.io/cubed/", None), + "dask": ("https://docs.dask.org/en/latest", None), + "datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None), + "flox": ("https://flox.readthedocs.io/en/latest/", None), + "hypothesis": ("https://hypothesis.readthedocs.io/en/latest/", None), + "iris": ("https://scitools-iris.readthedocs.io/en/latest", None), + "matplotlib": ("https://matplotlib.org/stable/", None), + "numba": ("https://numba.readthedocs.io/en/stable/", None), + "numpy": ("https://numpy.org/doc/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "python": ("https://docs.python.org/3/", None), + "scipy": ("https://docs.scipy.org/doc/scipy", None), + "sparse": ("https://sparse.pydata.org/en/latest/", None), + "xarray-tutorial": ("https://tutorial.xarray.dev/", None), + "zarr": ("https://zarr.readthedocs.io/en/latest/", None), +} + + +# based on numpy doc/source/conf.py +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + obj = getattr(obj, part) + except AttributeError: + return None + + try: + fn = inspect.getsourcefile(inspect.unwrap(obj)) + except TypeError: + fn = None + if not fn: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except OSError: + lineno = None + + if lineno: + linespec = f"#L{lineno}-L{lineno + len(source) - 1}" + else: + linespec = "" + + fn = os.path.relpath(fn, start=os.path.dirname(xarray.__file__)) + + if "+" in xarray.__version__: + return f"https://github.com/pydata/xarray/blob/main/xarray/{fn}{linespec}" + else: + return ( + f"https://github.com/pydata/xarray/blob/" + f"v{xarray.__version__}/xarray/{fn}{linespec}" + ) + + +def html_page_context(app, pagename, templatename, context, doctree): + # Disable edit button for docstring generated pages + if "generated" in pagename: + context["theme_use_edit_page_button"] = False + + +def update_gallery(app: Sphinx): + """Update the gallery page.""" + + LOGGER.info("Updating gallery page...") + + gallery = yaml.safe_load(pathlib.Path(app.srcdir, "gallery.yml").read_bytes()) + + for key in gallery: + items = [ + f""" + .. grid-item-card:: + :text-align: center + :link: {item['path']} + + .. image:: {item['thumbnail']} + :alt: {item['title']} + +++ + {item['title']} + """ + for item in gallery[key] + ] + + items_md = indent(dedent("\n".join(items)), prefix=" ") + markdown = f""" +.. grid:: 1 2 2 2 + :gutter: 2 + + {items_md} + """ + pathlib.Path(app.srcdir, f"{key}-gallery.txt").write_text(markdown) + LOGGER.info(f"{key} gallery page updated.") + LOGGER.info("Gallery page updated.") + + +def update_videos(app: Sphinx): + """Update the videos page.""" + + LOGGER.info("Updating videos page...") + + videos = yaml.safe_load(pathlib.Path(app.srcdir, "videos.yml").read_bytes()) + + items = [] + for video in videos: + authors = " | ".join(video["authors"]) + item = f""" +.. grid-item-card:: {" ".join(video["title"].split())} + :text-align: center + + .. raw:: html + + {video['src']} + +++ + {authors} + """ + items.append(item) + + items_md = indent(dedent("\n".join(items)), prefix=" ") + markdown = f""" +.. grid:: 1 2 2 2 + :gutter: 2 + + {items_md} + """ + pathlib.Path(app.srcdir, "videos-gallery.txt").write_text(markdown) + LOGGER.info("Videos page updated.") + + +def setup(app: Sphinx): + app.connect("html-page-context", html_page_context) + app.connect("builder-inited", update_gallery) + app.connect("builder-inited", update_videos) diff --git a/test/fixtures/whole_applications/xarray/doc/contributing.rst b/test/fixtures/whole_applications/xarray/doc/contributing.rst new file mode 100644 index 0000000..c3dc484 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/contributing.rst @@ -0,0 +1,1077 @@ +.. _contributing: + +********************** +Contributing to xarray +********************** + +.. note:: + + Large parts of this document came from the `Pandas Contributing + Guide `_. + +Overview +======== + +We welcome your skills and enthusiasm at the xarray project!. There are numerous opportunities to +contribute beyond just writing code. +All contributions, including bug reports, bug fixes, documentation improvements, enhancement suggestions, +and other ideas are welcome. + +If you have any questions on the process or how to fix something feel free to ask us! +The recommended place to ask a question is on `GitHub Discussions `_ +, but we also have a `Discord `_ and a +`mailing list `_. There is also a +`"python-xarray" tag on Stack Overflow `_ which we monitor for questions. + +We also have a biweekly community call, details of which are announced on the +`Developers meeting `_. +You are very welcome to join! Though we would love to hear from you, there is no expectation to +contribute during the meeting either - you are always welcome to just sit in and listen. + +This project is a community effort, and everyone is welcome to contribute. Everyone within the community +is expected to abide by our `code of conduct `_. + +Where to start? +=============== + +If you are brand new to *xarray* or open-source development, we recommend going +through the `GitHub "issues" tab `_ +to find issues that interest you. +Some issues are particularly suited for new contributors by the label `Documentation `_ +and `good first issue +`_ where you could start out. +These are well documented issues, that do not require a deep understanding of the internals of xarray. + +Once you've found an interesting issue, you can return here to get your development environment setup. +The xarray project does not assign issues. Issues are "assigned" by opening a Pull Request(PR). + +.. _contributing.bug_reports: + +Bug reports and enhancement requests +==================================== + +Bug reports are an important part of making *xarray* more stable. Having a complete bug +report will allow others to reproduce the bug and provide insight into fixing. + +Trying out the bug-producing code on the *main* branch is often a worthwhile exercise +to confirm that the bug still exists. It is also worth searching existing bug reports and +pull requests to see if the issue has already been reported and/or fixed. + +Submitting a bug report +----------------------- + +If you find a bug in the code or documentation, do not hesitate to submit a ticket to the +`Issue Tracker `_. +You are also welcome to post feature requests or pull requests. + +If you are reporting a bug, please use the provided template which includes the following: + +#. Include a short, self-contained Python snippet reproducing the problem. + You can format the code nicely by using `GitHub Flavored Markdown + `_:: + + ```python + import xarray as xr + ds = xr.Dataset(...) + + ... + ``` + +#. Include the full version string of *xarray* and its dependencies. You can use the + built in function:: + + ```python + import xarray as xr + xr.show_versions() + + ... + ``` + +#. Explain why the current behavior is wrong/not desired and what you expect instead. + +The issue will then show up to the *xarray* community and be open to comments/ideas from others. + +See this `stackoverflow article for tips on writing a good bug report `_ . + + +.. _contributing.github: + +Now that you have an issue you want to fix, enhancement to add, or documentation +to improve, you need to learn how to work with GitHub and the *xarray* code base. + +.. _contributing.version_control: + +Version control, Git, and GitHub +================================ + +The code is hosted on `GitHub `_. To +contribute you will need to sign up for a `free GitHub account +`_. We use `Git `_ for +version control to allow many people to work together on the project. + +Some great resources for learning Git: + +* the `GitHub help pages `_. +* the `NumPy's documentation `_. +* Matthew Brett's `Pydagogue `_. + +Getting started with Git +------------------------ + +`GitHub has instructions for setting up Git `__ including installing git, +setting up your SSH key, and configuring git. All these steps need to be completed before +you can work seamlessly between your local repository and GitHub. + +.. note:: + + The following instructions assume you want to learn how to interact with github via the git command-line utility, + but contributors who are new to git may find it easier to use other tools instead such as + `Github Desktop `_. + +Development workflow +==================== + +To keep your work well organized, with readable history, and in turn make it easier for project +maintainers to see what you've done, and why you did it, we recommend you to follow workflow: + +1. `Create an account `_ on GitHub if you do not already have one. + +2. You will need your own fork to work on the code. Go to the `xarray project + page `_ and hit the ``Fork`` button near the top of the page. + This creates a copy of the code under your account on the GitHub server. + +3. Clone your fork to your machine:: + + git clone https://github.com/your-user-name/xarray.git + cd xarray + git remote add upstream https://github.com/pydata/xarray.git + + This creates the directory `xarray` and connects your repository to + the upstream (main project) *xarray* repository. + +Creating a development environment +---------------------------------- + +To test out code changes locally, you'll need to build *xarray* from source, which requires you to +`create a local development environment `_. + +Update the ``main`` branch +-------------------------- + +First make sure you have followed `Setting up xarray for development +`_ + +Before starting a new set of changes, fetch all changes from ``upstream/main``, and start a new +feature branch from that. From time to time you should fetch the upstream changes from GitHub: :: + + git fetch upstream + git merge upstream/main + +This will combine your commits with the latest *xarray* git ``main``. If this +leads to merge conflicts, you must resolve these before submitting your pull +request. If you have uncommitted changes, you will need to ``git stash`` them +prior to updating. This will effectively store your changes, which can be +reapplied after updating. + +Create a new feature branch +--------------------------- + +Create a branch to save your changes, even before you start making changes. You want your +``main branch`` to contain only production-ready code:: + + git checkout -b shiny-new-feature + +This changes your working directory to the ``shiny-new-feature`` branch. Keep any changes in this +branch specific to one bug or feature so it is clear what the branch brings to *xarray*. You can have +many "shiny-new-features" and switch in between them using the ``git checkout`` command. + +Generally, you will want to keep your feature branches on your public GitHub fork of xarray. To do this, +you ``git push`` this new branch up to your GitHub repo. Generally (if you followed the instructions in +these pages, and by default), git will have a link to your fork of the GitHub repo, called ``origin``. +You push up to your own fork with: :: + + git push origin shiny-new-feature + +In git >= 1.7 you can ensure that the link is correctly set by using the ``--set-upstream`` option: :: + + git push --set-upstream origin shiny-new-feature + +From now on git will know that ``shiny-new-feature`` is related to the ``shiny-new-feature branch`` in the GitHub repo. + +The editing workflow +-------------------- + +1. Make some changes + +2. See which files have changed with ``git status``. You'll see a listing like this one: :: + + # On branch shiny-new-feature + # Changed but not updated: + # (use "git add ..." to update what will be committed) + # (use "git checkout -- ..." to discard changes in working directory) + # + # modified: README + +3. Check what the actual changes are with ``git diff``. + +4. Build the `documentation run `_ +for the documentation changes. + +`Run the test suite `_ for code changes. + +Commit and push your changes +---------------------------- + +1. To commit all modified files into the local copy of your repo, do ``git commit -am 'A commit message'``. + +2. To push the changes up to your forked repo on GitHub, do a ``git push``. + +Open a pull request +------------------- + +When you're ready or need feedback on your code, open a Pull Request (PR) so that the xarray developers can +give feedback and eventually include your suggested code into the ``main`` branch. +`Pull requests (PRs) on GitHub `_ +are the mechanism for contributing to xarray's code and documentation. + +Enter a title for the set of changes with some explanation of what you've done. +Follow the PR template, which looks like this. :: + + [ ]Closes #xxxx + [ ]Tests added + [ ]User visible changes (including notable bug fixes) are documented in whats-new.rst + [ ]New functions/methods are listed in api.rst + +Mention anything you'd like particular attention for - such as a complicated change or some code you are not happy with. +If you don't think your request is ready to be merged, just say so in your pull request message and use +the "Draft PR" feature of GitHub. This is a good way of getting some preliminary code review. + +.. _contributing.dev_env: + +Creating a development environment +================================== + +To test out code changes locally, you'll need to build *xarray* from source, which +requires a Python environment. If you're making documentation changes, you can +skip to :ref:`contributing.documentation` but you won't be able to build the +documentation locally before pushing your changes. + +.. note:: + + For small changes, such as fixing a typo, you don't necessarily need to build and test xarray locally. + If you make your changes then :ref:`commit and push them to a new branch `, + xarray's automated :ref:`continuous integration tests ` will run and check your code in various ways. + You can then try to fix these problems by committing and pushing more commits to the same branch. + + You can also avoid building the documentation locally by instead :ref:`viewing the updated documentation via the CI `. + + To speed up this feedback loop or for more complex development tasks you should build and test xarray locally. + + +.. _contributing.dev_python: + +Creating a Python Environment +----------------------------- + +Before starting any development, you'll need to create an isolated xarray +development environment: + +- Install either `Anaconda `_ or `miniconda + `_ +- Make sure your conda is up to date (``conda update conda``) +- Make sure that you have :ref:`cloned the repository ` +- ``cd`` to the *xarray* source directory + +We'll now kick off a two-step process: + +1. Install the build dependencies +2. Build and install xarray + +.. code-block:: sh + + # Create and activate the build environment + conda create -c conda-forge -n xarray-tests python=3.10 + + # This is for Linux and MacOS + conda env update -f ci/requirements/environment.yml + + # On windows, use environment-windows.yml instead + conda env update -f ci/requirements/environment-windows.yml + + conda activate xarray-tests + + # or with older versions of Anaconda: + source activate xarray-tests + + # Build and install xarray + pip install -e . + +At this point you should be able to import *xarray* from your locally +built version: + +.. code-block:: sh + + $ python # start an interpreter + >>> import xarray + >>> xarray.__version__ + '0.10.0+dev46.g015daca' + +This will create the new environment, and not touch any of your existing environments, +nor any existing Python installation. + +To view your environments:: + + conda info -e + +To return to your root environment:: + + conda deactivate + +See the full `conda docs here `__. + +Install pre-commit hooks +------------------------ + +We highly recommend that you setup `pre-commit `_ hooks to automatically +run all the above tools every time you make a git commit. To install the hooks:: + + python -m pip install pre-commit + pre-commit install + +This can be done by running: :: + + pre-commit run + +from the root of the xarray repository. You can skip the pre-commit checks with +``git commit --no-verify``. + +.. _contributing.documentation: + +Contributing to the documentation +================================= + +If you're not the developer type, contributing to the documentation is still of +huge value. You don't even have to be an expert on *xarray* to do so! In fact, +there are sections of the docs that are worse off after being written by +experts. If something in the docs doesn't make sense to you, updating the +relevant section after you figure it out is a great way to ensure it will help +the next person. + +.. contents:: Documentation: + :local: + + +About the *xarray* documentation +-------------------------------- + +The documentation is written in **reStructuredText**, which is almost like writing +in plain English, and built using `Sphinx `__. The +Sphinx Documentation has an excellent `introduction to reST +`__. Review the Sphinx docs to perform more +complex changes to the documentation as well. + +Some other important things to know about the docs: + +- The *xarray* documentation consists of two parts: the docstrings in the code + itself and the docs in this folder ``xarray/doc/``. + + The docstrings are meant to provide a clear explanation of the usage of the + individual functions, while the documentation in this folder consists of + tutorial-like overviews per topic together with some other information + (what's new, installation, etc). + +- The docstrings follow the **NumPy Docstring Standard**, which is used widely + in the Scientific Python community. This standard specifies the format of + the different sections of the docstring. Refer to the `documentation for the Numpy docstring format + `_ + for a detailed explanation, or look at some of the existing functions to + extend it in a similar manner. + +- The tutorials make heavy use of the `ipython directive + `_ sphinx extension. + This directive lets you put code in the documentation which will be run + during the doc build. For example: + + .. code:: rst + + .. ipython:: python + + x = 2 + x**3 + + will be rendered as:: + + In [1]: x = 2 + + In [2]: x**3 + Out[2]: 8 + + Almost all code examples in the docs are run (and the output saved) during the + doc build. This approach means that code examples will always be up to date, + but it does make building the docs a bit more complex. + +- Our API documentation in ``doc/api.rst`` houses the auto-generated + documentation from the docstrings. For classes, there are a few subtleties + around controlling which methods and attributes have pages auto-generated. + + Every method should be included in a ``toctree`` in ``api.rst``, else Sphinx + will emit a warning. + + +How to build the *xarray* documentation +--------------------------------------- + +Requirements +~~~~~~~~~~~~ +Make sure to follow the instructions on :ref:`creating a development environment above `, but +to build the docs you need to use the environment file ``ci/requirements/doc.yml``. +You should also use this environment and these steps if you want to view changes you've made to the docstrings. + +.. code-block:: sh + + # Create and activate the docs environment + conda env create -f ci/requirements/doc.yml + conda activate xarray-docs + + # or with older versions of Anaconda: + source activate xarray-docs + + # Build and install a local, editable version of xarray + pip install -e . + +Building the documentation +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To build the documentation run:: + + cd doc/ + make html + +Then you can find the HTML output files in the folder ``xarray/doc/_build/html/``. + +To see what the documentation now looks like with your changes, you can view the HTML build locally by opening the files in your local browser. +For example, if you normally use Google Chrome as your browser, you could enter:: + + google-chrome _build/html/quick-overview.html + +in the terminal, running from within the ``doc/`` folder. +You should now see a new tab pop open in your local browser showing the ``quick-overview`` page of the documentation. +The different pages of this local build of the documentation are linked together, +so you can browse the whole documentation by following links the same way you would on the officially-hosted xarray docs site. + +The first time you build the docs, it will take quite a while because it has to run +all the code examples and build all the generated docstring pages. In subsequent +evocations, Sphinx will try to only build the pages that have been modified. + +If you want to do a full clean build, do:: + + make clean + make html + +Writing ReST pages +------------------ + +Most documentation is either in the docstrings of individual classes and methods, in explicit +``.rst`` files, or in examples and tutorials. All of these use the +`ReST `_ syntax and are processed by +`Sphinx `_. + +This section contains additional information and conventions how ReST is used in the +xarray documentation. + +Section formatting +~~~~~~~~~~~~~~~~~~ + +We aim to follow the recommendations from the +`Python documentation `_ +and the `Sphinx reStructuredText documentation `_ +for section markup characters, + +- ``*`` with overline, for chapters + +- ``=``, for heading + +- ``-``, for sections + +- ``~``, for subsections + +- ``**`` text ``**``, for **bold** text + +Referring to other documents and sections +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +`Sphinx `_ allows internal +`references `_ between documents. + +Documents can be linked with the ``:doc:`` directive: + +:: + + See the :doc:`/getting-started-guide/installing` + + See the :doc:`/getting-started-guide/quick-overview` + +will render as: + +See the `Installation `_ + +See the `Quick Overview `_ + +Including figures and files +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Image files can be directly included in pages with the ``image::`` directive. + +.. _contributing.code: + +Contributing to the code base +============================= + +.. contents:: Code Base: + :local: + +Code standards +-------------- + +Writing good code is not just about what you write. It is also about *how* you +write it. During :ref:`Continuous Integration ` testing, several +tools will be run to check your code for stylistic errors. +Generating any warnings will cause the test to fail. +Thus, good style is a requirement for submitting code to *xarray*. + +In addition, because a lot of people use our library, it is important that we +do not make sudden changes to the code that could have the potential to break +a lot of user code as a result, that is, we need it to be as *backwards compatible* +as possible to avoid mass breakages. + +Code Formatting +~~~~~~~~~~~~~~~ + +xarray uses several tools to ensure a consistent code format throughout the project: + +- `Black `_ for standardized + code formatting, +- `blackdoc `_ for + standardized code formatting in documentation, +- `ruff `_ for code quality checks and standardized order in imports +- `absolufy-imports `_ for absolute instead of relative imports from different files, +- `mypy `_ for static type checking on `type hints + `_. + +We highly recommend that you setup `pre-commit hooks `_ +to automatically run all the above tools every time you make a git commit. This +can be done by running:: + + pre-commit install + +from the root of the xarray repository. You can skip the pre-commit checks +with ``git commit --no-verify``. + + +Backwards Compatibility +~~~~~~~~~~~~~~~~~~~~~~~ + +Please try to maintain backwards compatibility. *xarray* has a growing number of users with +lots of existing code, so don't break it if at all possible. If you think breakage is +required, clearly state why as part of the pull request. + +Be especially careful when changing function and method signatures, because any change +may require a deprecation warning. For example, if your pull request means that the +argument ``old_arg`` to ``func`` is no longer valid, instead of simply raising an error if +a user passes ``old_arg``, we would instead catch it: + +.. code-block:: python + + def func(new_arg, old_arg=None): + if old_arg is not None: + from warnings import warn + + warn( + "`old_arg` has been deprecated, and in the future will raise an error." + "Please use `new_arg` from now on.", + DeprecationWarning, + ) + + # Still do what the user intended here + +This temporary check would then be removed in a subsequent version of xarray. +This process of first warning users before actually breaking their code is known as a +"deprecation cycle", and makes changes significantly easier to handle both for users +of xarray, and for developers of other libraries that depend on xarray. + + +.. _contributing.ci: + +Testing With Continuous Integration +----------------------------------- + +The *xarray* test suite runs automatically via the +`GitHub Actions `__, +continuous integration service, once your pull request is submitted. + +A pull-request will be considered for merging when you have an all 'green' build. If any +tests are failing, then you will get a red 'X', where you can click through to see the +individual failed tests. This is an example of a green build. + +.. image:: _static/ci.png + +.. note:: + + Each time you push to your PR branch, a new run of the tests will be + triggered on the CI. If they haven't already finished, tests for any older + commits on the same branch will be automatically cancelled. + +.. _contributing.tdd: + + +Test-driven development/code writing +------------------------------------ + +*xarray* is serious about testing and strongly encourages contributors to embrace +`test-driven development (TDD) `_. +This development process "relies on the repetition of a very short development cycle: +first the developer writes an (initially failing) automated test case that defines a desired +improvement or new function, then produces the minimum amount of code to pass that test." +So, before actually writing any code, you should write your tests. Often the test can be +taken from the original GitHub issue. However, it is always worth considering additional +use cases and writing corresponding tests. + +Adding tests is one of the most common requests after code is pushed to *xarray*. Therefore, +it is worth getting in the habit of writing tests ahead of time so that this is never an issue. + +Like many packages, *xarray* uses `pytest +`_ and the convenient +extensions in `numpy.testing +`_. + +Writing tests +~~~~~~~~~~~~~ + +All tests should go into the ``tests`` subdirectory of the specific package. +This folder contains many current examples of tests, and we suggest looking to these for +inspiration. + +The ``xarray.testing`` module has many special ``assert`` functions that +make it easier to make statements about whether DataArray or Dataset objects are +equivalent. The easiest way to verify that your code is correct is to +explicitly construct the result you expect, then compare the actual result to +the expected correct result:: + + def test_constructor_from_0d(): + expected = Dataset({None: ([], 0)})[None] + actual = DataArray(0) + assert_identical(expected, actual) + +Transitioning to ``pytest`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*xarray* existing test structure is *mostly* class-based, meaning that you will +typically find tests wrapped in a class. + +.. code-block:: python + + class TestReallyCoolFeature: ... + +Going forward, we are moving to a more *functional* style using the +`pytest `__ framework, which offers a richer +testing framework that will facilitate testing and developing. Thus, instead of +writing test classes, we will write test functions like this: + +.. code-block:: python + + def test_really_cool_feature(): ... + +Using ``pytest`` +~~~~~~~~~~~~~~~~ + +Here is an example of a self-contained set of tests that illustrate multiple +features that we like to use. + +- functional style: tests are like ``test_*`` and *only* take arguments that are either + fixtures or parameters +- ``pytest.mark`` can be used to set metadata on test functions, e.g. ``skip`` or ``xfail``. +- using ``parametrize``: allow testing of multiple cases +- to set a mark on a parameter, ``pytest.param(..., marks=...)`` syntax should be used +- ``fixture``, code for object construction, on a per-test basis +- using bare ``assert`` for scalars and truth-testing +- ``assert_equal`` and ``assert_identical`` from the ``xarray.testing`` module for xarray object comparisons. +- the typical pattern of constructing an ``expected`` and comparing versus the ``result`` + +We would name this file ``test_cool_feature.py`` and put in an appropriate place in the +``xarray/tests/`` structure. + +.. code-block:: python + + import pytest + import numpy as np + import xarray as xr + from xarray.testing import assert_equal + + + @pytest.mark.parametrize("dtype", ["int8", "int16", "int32", "int64"]) + def test_dtypes(dtype): + assert str(np.dtype(dtype)) == dtype + + + @pytest.mark.parametrize( + "dtype", + [ + "float32", + pytest.param("int16", marks=pytest.mark.skip), + pytest.param( + "int32", marks=pytest.mark.xfail(reason="to show how it works") + ), + ], + ) + def test_mark(dtype): + assert str(np.dtype(dtype)) == "float32" + + + @pytest.fixture + def dataarray(): + return xr.DataArray([1, 2, 3]) + + + @pytest.fixture(params=["int8", "int16", "int32", "int64"]) + def dtype(request): + return request.param + + + def test_series(dataarray, dtype): + result = dataarray.astype(dtype) + assert result.dtype == dtype + + expected = xr.DataArray(np.array([1, 2, 3], dtype=dtype)) + assert_equal(result, expected) + + + +A test run of this yields + +.. code-block:: shell + + ((xarray) $ pytest test_cool_feature.py -v + ================================= test session starts ================================== + platform darwin -- Python 3.10.6, pytest-7.2.0, pluggy-1.0.0 -- + cachedir: .pytest_cache + plugins: hypothesis-6.56.3, cov-4.0.0 + collected 11 items + + xarray/tests/test_cool_feature.py::test_dtypes[int8] PASSED [ 9%] + xarray/tests/test_cool_feature.py::test_dtypes[int16] PASSED [ 18%] + xarray/tests/test_cool_feature.py::test_dtypes[int32] PASSED [ 27%] + xarray/tests/test_cool_feature.py::test_dtypes[int64] PASSED [ 36%] + xarray/tests/test_cool_feature.py::test_mark[float32] PASSED [ 45%] + xarray/tests/test_cool_feature.py::test_mark[int16] SKIPPED (unconditional skip) [ 54%] + xarray/tests/test_cool_feature.py::test_mark[int32] XFAIL (to show how it works) [ 63%] + xarray/tests/test_cool_feature.py::test_series[int8] PASSED [ 72%] + xarray/tests/test_cool_feature.py::test_series[int16] PASSED [ 81%] + xarray/tests/test_cool_feature.py::test_series[int32] PASSED [ 90%] + xarray/tests/test_cool_feature.py::test_series[int64] PASSED [100%] + + + ==================== 9 passed, 1 skipped, 1 xfailed in 1.83 seconds ==================== + +Tests that we have ``parametrized`` are now accessible via the test name, for +example we could run these with ``-k int8`` to sub-select *only* those tests +which match ``int8``. + + +.. code-block:: shell + + ((xarray) bash-3.2$ pytest test_cool_feature.py -v -k int8 + ================================== test session starts ================================== + platform darwin -- Python 3.10.6, pytest-7.2.0, pluggy-1.0.0 -- + cachedir: .pytest_cache + plugins: hypothesis-6.56.3, cov-4.0.0 + collected 11 items + + test_cool_feature.py::test_dtypes[int8] PASSED + test_cool_feature.py::test_series[int8] PASSED + + +Running the test suite +---------------------- + +The tests can then be run directly inside your Git clone (without having to +install *xarray*) by typing:: + + pytest xarray + +The tests suite is exhaustive and takes a few minutes. Often it is +worth running only a subset of tests first around your changes before running the +entire suite. + +The easiest way to do this is with:: + + pytest xarray/path/to/test.py -k regex_matching_test_name + +Or with one of the following constructs:: + + pytest xarray/tests/[test-module].py + pytest xarray/tests/[test-module].py::[TestClass] + pytest xarray/tests/[test-module].py::[TestClass]::[test_method] + +Using `pytest-xdist `_, one can +speed up local testing on multicore machines, by running pytest with the optional -n argument:: + + pytest xarray -n 4 + +This can significantly reduce the time it takes to locally run tests before +submitting a pull request. + +For more, see the `pytest `_ documentation. + +Running the performance test suite +---------------------------------- + +Performance matters and it is worth considering whether your code has introduced +performance regressions. *xarray* is starting to write a suite of benchmarking tests +using `asv `__ +to enable easy monitoring of the performance of critical *xarray* operations. +These benchmarks are all found in the ``xarray/asv_bench`` directory. + +To use all features of asv, you will need either ``conda`` or +``virtualenv``. For more details please check the `asv installation +webpage `_. + +To install asv:: + + python -m pip install asv + +If you need to run a benchmark, change your directory to ``asv_bench/`` and run:: + + asv continuous -f 1.1 upstream/main HEAD + +You can replace ``HEAD`` with the name of the branch you are working on, +and report benchmarks that changed by more than 10%. +The command uses ``conda`` by default for creating the benchmark +environments. If you want to use virtualenv instead, write:: + + asv continuous -f 1.1 -E virtualenv upstream/main HEAD + +The ``-E virtualenv`` option should be added to all ``asv`` commands +that run benchmarks. The default value is defined in ``asv.conf.json``. + +Running the full benchmark suite can take up to one hour and use up a few GBs of RAM. +Usually it is sufficient to paste only a subset of the results into the pull +request to show that the committed changes do not cause unexpected performance +regressions. You can run specific benchmarks using the ``-b`` flag, which +takes a regular expression. For example, this will only run tests from a +``xarray/asv_bench/benchmarks/groupby.py`` file:: + + asv continuous -f 1.1 upstream/main HEAD -b ^groupby + +If you want to only run a specific group of tests from a file, you can do it +using ``.`` as a separator. For example:: + + asv continuous -f 1.1 upstream/main HEAD -b groupby.GroupByMethods + +will only run the ``GroupByMethods`` benchmark defined in ``groupby.py``. + +You can also run the benchmark suite using the version of *xarray* +already installed in your current Python environment. This can be +useful if you do not have ``virtualenv`` or ``conda``, or are using the +``setup.py develop`` approach discussed above; for the in-place build +you need to set ``PYTHONPATH``, e.g. +``PYTHONPATH="$PWD/.." asv [remaining arguments]``. +You can run benchmarks using an existing Python +environment by:: + + asv run -e -E existing + +or, to use a specific Python interpreter,:: + + asv run -e -E existing:python3.10 + +This will display stderr from the benchmarks, and use your local +``python`` that comes from your ``$PATH``. + +Learn `how to write a benchmark and how to use asv from the documentation `_ . + + +.. + TODO: uncomment once we have a working setup + see https://github.com/pydata/xarray/pull/5066 + + The *xarray* benchmarking suite is run remotely and the results are + available `here `_. + +Documenting your code +--------------------- + +Changes should be reflected in the release notes located in ``doc/whats-new.rst``. +This file contains an ongoing change log for each release. Add an entry to this file to +document your fix, enhancement or (unavoidable) breaking change. Make sure to include the +GitHub issue number when adding your entry (using ``:issue:`1234```, where ``1234`` is the +issue/pull request number). + +If your code is an enhancement, it is most likely necessary to add usage +examples to the existing documentation. This can be done by following the :ref:`guidelines for contributing to the documentation `. + +.. _contributing.changes: + +Contributing your changes to *xarray* +===================================== + +.. _contributing.committing: + +Committing your code +-------------------- + +Keep style fixes to a separate commit to make your pull request more readable. + +Once you've made changes, you can see them by typing:: + + git status + +If you have created a new file, it is not being tracked by git. Add it by typing:: + + git add path/to/file-to-be-added.py + +Doing 'git status' again should give something like:: + + # On branch shiny-new-feature + # + # modified: /relative/path/to/file-you-added.py + # + +The following defines how a commit message should ideally be structured: + +* A subject line with `< 72` chars. +* One blank line. +* Optionally, a commit message body. + +Please reference the relevant GitHub issues in your commit message using ``GH1234`` or +``#1234``. Either style is fine, but the former is generally preferred. + +Now you can commit your changes in your local repository:: + + git commit -m + + +.. _contributing.pushing: + +Pushing your changes +-------------------- + +When you want your changes to appear publicly on your GitHub page, push your +forked feature branch's commits:: + + git push origin shiny-new-feature + +Here ``origin`` is the default name given to your remote repository on GitHub. +You can see the remote repositories:: + + git remote -v + +If you added the upstream repository as described above you will see something +like:: + + origin git@github.com:yourname/xarray.git (fetch) + origin git@github.com:yourname/xarray.git (push) + upstream git://github.com/pydata/xarray.git (fetch) + upstream git://github.com/pydata/xarray.git (push) + +Now your code is on GitHub, but it is not yet a part of the *xarray* project. For that to +happen, a pull request needs to be submitted on GitHub. + +.. _contributing.review: + +Review your code +---------------- + +When you're ready to ask for a code review, file a pull request. Before you do, once +again make sure that you have followed all the guidelines outlined in this document +regarding code style, tests, performance tests, and documentation. You should also +double check your branch changes against the branch it was based on: + +#. Navigate to your repository on GitHub -- https://github.com/your-user-name/xarray +#. Click on ``Branches`` +#. Click on the ``Compare`` button for your feature branch +#. Select the ``base`` and ``compare`` branches, if necessary. This will be ``main`` and + ``shiny-new-feature``, respectively. + +.. _contributing.pr: + +Finally, make the pull request +------------------------------ + +If everything looks good, you are ready to make a pull request. A pull request is how +code from a local repository becomes available to the GitHub community and can be looked +at and eventually merged into the ``main`` version. This pull request and its associated +changes will eventually be committed to the ``main`` branch and available in the next +release. To submit a pull request: + +#. Navigate to your repository on GitHub +#. Click on the ``Pull Request`` button +#. You can then click on ``Commits`` and ``Files Changed`` to make sure everything looks + okay one last time +#. Write a description of your changes in the ``Preview Discussion`` tab +#. Click ``Send Pull Request``. + +This request then goes to the repository maintainers, and they will review +the code. + +If you have made updates to the documentation, you can now see a preview of the updated docs by clicking on "Details" under +the ``docs/readthedocs.org`` check near the bottom of the list of checks that run automatically when submitting a PR, +then clicking on the "View Docs" button on the right (not the big green button, the small black one further down). + +.. image:: _static/view-docs.png + + +If you need to make more changes, you can make them in +your branch, add them to a new commit, push them to GitHub, and the pull request +will automatically be updated. Pushing them to GitHub again is done by:: + + git push origin shiny-new-feature + +This will automatically update your pull request with the latest code and restart the +:ref:`Continuous Integration ` tests. + + +.. _contributing.delete: + +Delete your merged branch (optional) +------------------------------------ + +Once your feature branch is accepted into upstream, you'll probably want to get rid of +the branch. First, update your ``main`` branch to check that the merge was successful:: + + git fetch upstream + git checkout main + git merge upstream/main + +Then you can do:: + + git branch -D shiny-new-feature + +You need to use a upper-case ``-D`` because the branch was squashed into a +single commit before merging. Be careful with this because ``git`` won't warn +you if you accidentally delete an unmerged branch. + +If you didn't delete your branch using GitHub's interface, then it will still exist on +GitHub. To delete it there do:: + + git push origin --delete shiny-new-feature + + +.. _contributing.checklist: + +PR checklist +------------ + +- **Properly comment and document your code.** See `"Documenting your code" `_. +- **Test that the documentation builds correctly** by typing ``make html`` in the ``doc`` directory. This is not strictly necessary, but this may be easier than waiting for CI to catch a mistake. See `"Contributing to the documentation" `_. +- **Test your code**. + + - Write new tests if needed. See `"Test-driven development/code writing" `_. + - Test the code using `Pytest `_. Running all tests (type ``pytest`` in the root directory) takes a while, so feel free to only run the tests you think are needed based on your PR (example: ``pytest xarray/tests/test_dataarray.py``). CI will catch any failing tests. + - By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a ``[test-upstream]`` tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a ``[skip-ci]`` tag to the first line of the commit message. + +- **Properly format your code** and verify that it passes the formatting guidelines set by `Black `_ and `Flake8 `_. See `"Code formatting" `_. You can use `pre-commit `_ to run these automatically on each commit. + + - Run ``pre-commit run --all-files`` in the root directory. This may modify some files. Confirm and commit any formatting changes. + +- **Push your code** and `create a PR on GitHub `_. +- **Use a helpful title for your pull request** by summarizing the main contributions rather than using the latest commit message. If the PR addresses an `issue `_, please `reference it `_. diff --git a/test/fixtures/whole_applications/xarray/doc/developers-meeting.rst b/test/fixtures/whole_applications/xarray/doc/developers-meeting.rst new file mode 100644 index 0000000..153f352 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/developers-meeting.rst @@ -0,0 +1,20 @@ +Developers meeting +------------------ + +Xarray developers meet bi-weekly every other Wednesday. + +The meeting occurs on `Zoom `__. + +Find the `notes for the meeting here `__. + +There is a :issue:`GitHub issue for changes to the meeting<4001>`. + +You can subscribe to this calendar to be notified of changes: + +* `Google Calendar `__ +* `iCal `__ + +.. raw:: html + + + diff --git a/test/fixtures/whole_applications/xarray/doc/ecosystem.rst b/test/fixtures/whole_applications/xarray/doc/ecosystem.rst new file mode 100644 index 0000000..63f60cd --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/ecosystem.rst @@ -0,0 +1,105 @@ +.. _ecosystem: + +Xarray related projects +----------------------- + +Below is a list of existing open source projects that build +functionality upon xarray. See also section :ref:`internals` for more +details on how to build xarray extensions. We also maintain the +`xarray-contrib `_ GitHub organization +as a place to curate projects that build upon xarray. + +Geosciences +~~~~~~~~~~~ + +- `aospy `_: Automated analysis and management of gridded climate data. +- `argopy `_: xarray-based Argo data access, manipulation and visualisation for standard users as well as Argo experts. +- `climpred `_: Analysis of ensemble forecast models for climate prediction. +- `geocube `_: Tool to convert geopandas vector data into rasterized xarray data. +- `GeoWombat `_: Utilities for analysis of remotely sensed and gridded raster data at scale (easily tame Landsat, Sentinel, Quickbird, and PlanetScope). +- `gsw-xarray `_: a wrapper around `gsw `_ that adds CF compliant attributes when possible, units, name. +- `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meteorology data +- `marc_analysis `_: Analysis package for CESM/MARC experiments and output. +- `MetPy `_: A collection of tools in Python for reading, visualizing, and performing calculations with weather data. +- `MPAS-Analysis `_: Analysis for simulations produced with Model for Prediction Across Scales (MPAS) components and the Accelerated Climate Model for Energy (ACME). +- `OGGM `_: Open Global Glacier Model +- `Oocgcm `_: Analysis of large gridded geophysical datasets +- `Open Data Cube `_: Analysis toolkit of continental scale Earth Observation data from satellites. +- `Pangaea: `_: xarray extension for gridded land surface & weather model output). +- `Pangeo `_: A community effort for big data geoscience in the cloud. +- `PyGDX `_: Python 3 package for + accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom + subclass. +- `pyinterp `_: Python 3 package for interpolating geo-referenced data used in the field of geosciences. +- `pyXpcm `_: xarray-based Profile Classification Modelling (PCM), mostly for ocean data. +- `Regionmask `_: plotting and creation of masks of spatial regions +- `rioxarray `_: geospatial xarray extension powered by rasterio +- `salem `_: Adds geolocalised subsetting, masking, and plotting operations to xarray's data structures via accessors. +- `SatPy `_ : Library for reading and manipulating meteorological remote sensing data and writing it to various image and data file formats. +- `SARXarray `_: xarray extension for reading and processing large Synthetic Aperture Radar (SAR) data stacks. +- `Spyfit `_: FTIR spectroscopy of the atmosphere +- `windspharm `_: Spherical + harmonic wind analysis in Python. +- `wradlib `_: An Open Source Library for Weather Radar Data Processing. +- `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-regrid `_: xarray extension for regridding rectilinear data. +- `xarray-simlab `_: xarray extension for computer model simulations. +- `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.) +- `xarray-topo `_: xarray extension for topographic analysis and modelling. +- `xbpch `_: xarray interface for bpch files. +- `xCDAT `_: An extension of xarray for climate data analysis on structured grids. +- `xclim `_: A library for calculating climate science indices with unit handling built from xarray and dask. +- `xESMF `_: Universal regridder for geospatial data. +- `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. +- `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. +- `xnemogcm `_: a package to read `NEMO `_ output files and add attributes to interface with xgcm. + +Machine Learning +~~~~~~~~~~~~~~~~ +- `ArviZ `_: Exploratory analysis of Bayesian models, built on top of xarray. +- `Darts `_: User-friendly modern machine learning for time series in Python. +- `Elm `_: Parallel machine learning on xarray data structures +- `sklearn-xarray (1) `_: Combines scikit-learn and xarray (1). +- `sklearn-xarray (2) `_: Combines scikit-learn and xarray (2). +- `xbatcher `_: Batch Generation from Xarray Datasets. + +Other domains +~~~~~~~~~~~~~ +- `ptsa `_: EEG Time Series Analysis +- `pycalphad `_: Computational Thermodynamics in Python +- `pyomeca `_: Python framework for biomechanical analysis + +Extend xarray capabilities +~~~~~~~~~~~~~~~~~~~~~~~~~~ +- `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions +- `eofs `_: EOF analysis in Python. +- `hypothesis-gufunc `_: Extension to hypothesis. Makes it easy to write unit tests with xarray objects as input. +- `ntv-pandas `_ : A tabular analyzer and a semantic, compact and reversible converter for multidimensional and tabular data +- `nxarray `_: NeXus input/output capability for xarray. +- `xarray-compare `_: xarray extension for data comparison. +- `xarray-dataclasses `_: xarray extension for typed DataArray and Dataset creation. +- `xarray_einstats `_: Statistics, linear algebra and einops for xarray +- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). +- `xeofs `_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data. +- `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API. +- `xrft `_: Fourier transforms for xarray data. +- `xr-scipy `_: A lightweight scipy wrapper for xarray. +- `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. +- `xskillscore `_: Metrics for verifying forecasts. +- `xyzpy `_: Easily generate high dimensional data, including parallelization. + +Visualization +~~~~~~~~~~~~~ +- `datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data. +- `hvplot `_ : A high-level plotting API for the PyData ecosystem built on HoloViews. +- `psyplot `_: Interactive data visualization with python. +- `xarray-leaflet `_: An xarray extension for tiled map plotting based on ipyleaflet. +- `xtrude `_: An xarray extension for 3D terrain visualization based on pydeck. +- `pyvista-xarray `_: xarray DataArray accessor for 3D visualization with `PyVista `_ and DataSet engines for reading VTK data formats. + +Non-Python projects +~~~~~~~~~~~~~~~~~~~ +- `xframe `_: C++ data structures inspired by xarray. +- `AxisArrays `_, `NamedArrays `_ and `YAXArrays.jl `_: similar data structures for Julia. + +More projects can be found at the `"xarray" Github topic `_. diff --git a/test/fixtures/whole_applications/xarray/doc/examples/ERA5-GRIB-example.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/ERA5-GRIB-example.ipynb new file mode 100644 index 0000000..5d09f1a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/ERA5-GRIB-example.ipynb @@ -0,0 +1,124 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GRIB Data Example " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GRIB format is commonly used to disseminate atmospheric model data. With xarray and the cfgrib engine, GRIB data can easily be analyzed and visualized." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To read GRIB data, you can use `xarray.load_dataset`. The only extra code you need is to specify the engine as `cfgrib`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = xr.tutorial.load_dataset(\"era5-2mt-2019-03-uk.grib\", engine=\"cfgrib\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a simple plot of 2-m air temperature in degrees Celsius:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = ds - 273.15\n", + "ds.t2m[0].plot(cmap=plt.cm.coolwarm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With CartoPy, we can create a more detailed plot, using built-in shapefiles to help provide geographic context:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import cartopy.crs as ccrs\n", + "import cartopy\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", + "ax = plt.axes(projection=ccrs.Robinson())\n", + "ax.coastlines(resolution=\"10m\")\n", + "plot = ds.t2m[0].plot(\n", + " cmap=plt.cm.coolwarm, transform=ccrs.PlateCarree(), cbar_kwargs={\"shrink\": 0.6}\n", + ")\n", + "plt.title(\"ERA5 - 2m temperature British Isles March 2019\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we can also pull out a time series for a given location easily:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds.t2m.sel(longitude=0, latitude=51.5).plot()\n", + "plt.title(\"ERA5 - London 2m temperature March 2019\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/ROMS_ocean_model.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/ROMS_ocean_model.ipynb new file mode 100644 index 0000000..d5c7638 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/ROMS_ocean_model.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ROMS Ocean Model Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Regional Ocean Modeling System ([ROMS](http://myroms.org)) is an open source hydrodynamic model that is used for simulating currents and water properties in coastal and estuarine regions. ROMS is one of a few standard ocean models, and it has an active user community.\n", + "\n", + "ROMS uses a regular C-Grid in the horizontal, similar to other structured grid ocean and atmospheric models, and a stretched vertical coordinate (see [the ROMS documentation](https://www.myroms.org/wiki/Vertical_S-coordinate) for more details). Both of these require special treatment when using `xarray` to analyze ROMS ocean model output. This example notebook shows how to create a lazily evaluated vertical coordinate, and make some basic plots. The `xgcm` package is required to do analysis that is aware of the horizontal C-Grid." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import cartopy.crs as ccrs\n", + "import cartopy.feature as cfeature\n", + "import matplotlib.pyplot as plt\n", + "\n", + "%matplotlib inline\n", + "\n", + "import xarray as xr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load a sample ROMS file. This is a subset of a full model available at \n", + "\n", + " http://barataria.tamu.edu/thredds/catalog.html?dataset=txla_hindcast_agg\n", + " \n", + "The subsetting was done using the following command on one of the output files:\n", + "\n", + " #open dataset\n", + " ds = xr.open_dataset('/d2/shared/TXLA_ROMS/output_20yr_obc/2001/ocean_his_0015.nc')\n", + " \n", + " # Turn on chunking to activate dask and parallelize read/write.\n", + " ds = ds.chunk({'ocean_time': 1})\n", + " \n", + " # Pick out some of the variables that will be included as coordinates\n", + " ds = ds.set_coords(['Cs_r', 'Cs_w', 'hc', 'h', 'Vtransform'])\n", + " \n", + " # Select a a subset of variables. Salt will be visualized, zeta is used to \n", + " # calculate the vertical coordinate\n", + " variables = ['salt', 'zeta']\n", + " ds[variables].isel(ocean_time=slice(47, None, 7*24), \n", + " xi_rho=slice(300, None)).to_netcdf('ROMS_example.nc', mode='w')\n", + "\n", + "So, the `ROMS_example.nc` file contains a subset of the grid, one 3D variable, and two time steps." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load in ROMS dataset as an xarray object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load in the file\n", + "ds = xr.tutorial.open_dataset(\"ROMS_example.nc\", chunks={\"ocean_time\": 1})\n", + "\n", + "# This is a way to turn on chunking and lazy evaluation. Opening with mfdataset, or\n", + "# setting the chunking in the open_dataset would also achieve this.\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Add a lazilly calculated vertical coordinates\n", + "\n", + "Write equations to calculate the vertical coordinate. These will be only evaluated when data is requested. Information about the ROMS vertical coordinate can be found (here)[https://www.myroms.org/wiki/Vertical_S-coordinate]\n", + "\n", + "In short, for `Vtransform==2` as used in this example, \n", + "\n", + "$Z_0 = (h_c \\, S + h \\,C) / (h_c + h)$\n", + "\n", + "$z = Z_0 (\\zeta + h) + \\zeta$\n", + "\n", + "where the variables are defined as in the link above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if ds.Vtransform == 1:\n", + " Zo_rho = ds.hc * (ds.s_rho - ds.Cs_r) + ds.Cs_r * ds.h\n", + " z_rho = Zo_rho + ds.zeta * (1 + Zo_rho / ds.h)\n", + "elif ds.Vtransform == 2:\n", + " Zo_rho = (ds.hc * ds.s_rho + ds.Cs_r * ds.h) / (ds.hc + ds.h)\n", + " z_rho = ds.zeta + (ds.zeta + ds.h) * Zo_rho\n", + "\n", + "ds.coords[\"z_rho\"] = z_rho.transpose() # needing transpose seems to be an xarray bug\n", + "ds.salt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### A naive vertical slice\n", + "\n", + "Creating a slice using the s-coordinate as the vertical dimension is typically not very informative." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "ds.salt.isel(xi_rho=50, ocean_time=0).plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can feed coordinate information to the plot method to give a more informative cross-section that uses the depths. Note that we did not need to slice the depth or longitude information separately, this was done automatically as the variable was sliced." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "section = ds.salt.isel(xi_rho=50, eta_rho=slice(0, 167), ocean_time=0)\n", + "section.plot(x=\"lon_rho\", y=\"z_rho\", figsize=(15, 6), clim=(25, 35))\n", + "plt.ylim([-100, 1]);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### A plan view\n", + "\n", + "Now make a naive plan view, without any projection information, just using lon/lat as x/y. This looks OK, but will appear compressed because lon and lat do not have an aspect constrained by the projection." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds.salt.isel(s_rho=-1, ocean_time=0).plot(x=\"lon_rho\", y=\"lat_rho\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And let's use a projection to make it nicer, and add a coast." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "proj = ccrs.LambertConformal(central_longitude=-92, central_latitude=29)\n", + "fig = plt.figure(figsize=(15, 5))\n", + "ax = plt.axes(projection=proj)\n", + "ds.salt.isel(s_rho=-1, ocean_time=0).plot(\n", + " x=\"lon_rho\", y=\"lat_rho\", transform=ccrs.PlateCarree()\n", + ")\n", + "\n", + "coast_10m = cfeature.NaturalEarthFeature(\n", + " \"physical\", \"land\", \"10m\", edgecolor=\"k\", facecolor=\"0.8\"\n", + ")\n", + "ax.add_feature(coast_10m)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/_code/accessor_example.py b/test/fixtures/whole_applications/xarray/doc/examples/_code/accessor_example.py new file mode 100644 index 0000000..ffbacb4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/_code/accessor_example.py @@ -0,0 +1,23 @@ +import xarray as xr + + +@xr.register_dataset_accessor("geo") +class GeoAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + self._center = None + + @property + def center(self): + """Return the geographic center point of this dataset.""" + if self._center is None: + # we can use a cache on our accessor objects, because accessors + # themselves are cached on instances that access them. + lon = self._obj.latitude + lat = self._obj.longitude + self._center = (float(lon.mean()), float(lat.mean())) + return self._center + + def plot(self): + """Plot data on a map.""" + return "plotting!" diff --git a/test/fixtures/whole_applications/xarray/doc/examples/apply_ufunc_vectorize_1d.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/apply_ufunc_vectorize_1d.ipynb new file mode 100644 index 0000000..c2ab727 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -0,0 +1,737 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Applying unvectorized functions with `apply_ufunc`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to conveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", + "\n", + "We will illustrate this using `np.interp`: \n", + "\n", + " Signature: np.interp(x, xp, fp, left=None, right=None, period=None)\n", + " Docstring:\n", + " One-dimensional linear interpolation.\n", + "\n", + " Returns the one-dimensional piecewise linear interpolant to a function\n", + " with given discrete data points (`xp`, `fp`), evaluated at `x`.\n", + "\n", + "and write an `xr_interp` function with signature\n", + "\n", + " xr_interp(xarray_object, dimension_name, new_coordinate_to_interpolate_to)\n", + "\n", + "### Load data\n", + "\n", + "First lets load an example dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:51.659160Z", + "start_time": "2020-01-15T14:45:50.528742Z" + } + }, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "\n", + "xr.set_options(display_style=\"html\") # fancy HTML repr\n", + "\n", + "air = (\n", + " xr.tutorial.load_dataset(\"air_temperature\")\n", + " .air.sortby(\"lat\") # np.interp needs coordinate in ascending order\n", + " .isel(time=slice(4), lon=slice(3))\n", + ") # choose a small subset for convenience\n", + "air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function we will apply is `np.interp` which expects 1D numpy arrays. This functionality is already implemented in xarray so we use that capability to make sure we are not making mistakes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:55.431708Z", + "start_time": "2020-01-15T14:45:55.104701Z" + } + }, + "outputs": [], + "source": [ + "newlat = np.linspace(15, 75, 100)\n", + "air.interp(lat=newlat)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a function that works with one vector of data along `lat` at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:57.889496Z", + "start_time": "2020-01-15T14:45:57.792269Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = interp1d_np(air.isel(time=0, lon=0), air.lat, newlat)\n", + "expected = air.interp(lat=newlat)\n", + "\n", + "# no errors are raised if values are equal to within floating point precision\n", + "np.testing.assert_allclose(expected.isel(time=0, lon=0).values, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### No errors are raised so our interpolation is working.\n", + "\n", + "This function consumes and returns numpy arrays, which means we need to do a lot of work to convert the result back to an xarray object with meaningful metadata. This is where `apply_ufunc` is very useful.\n", + "\n", + "### `apply_ufunc`\n", + "\n", + " Apply a vectorized function for unlabeled arrays on xarray objects.\n", + "\n", + " The function will be mapped over the data variable(s) of the input arguments using \n", + " xarray’s standard rules for labeled computation, including alignment, broadcasting, \n", + " looping over GroupBy/Dataset variables, and merging of coordinates.\n", + " \n", + "`apply_ufunc` has many capabilities but for simplicity this example will focus on the common task of vectorizing 1D functions over nD xarray objects. We will iteratively build up the right set of arguments to `apply_ufunc` and read through many error messages in doing so." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:59.768626Z", + "start_time": "2020-01-15T14:45:59.543808Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`apply_ufunc` needs to know a lot of information about what our function does so that it can reconstruct the outputs. In this case, the size of dimension lat has changed and we need to explicitly specify that this will happen. xarray helpfully tells us that we need to specify the kwarg `exclude_dims`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `exclude_dims`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "```\n", + "exclude_dims : set, optional\n", + " Core dimensions on the inputs to exclude from alignment and\n", + " broadcasting entirely. Any input coordinates along these dimensions\n", + " will be dropped. Each excluded dimension must also appear in\n", + " ``input_core_dims`` for at least one argument. Only dimensions listed\n", + " here are allowed to change size between input and output objects.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:02.187012Z", + "start_time": "2020-01-15T14:46:02.105563Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Core dimensions\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Core dimensions are central to using `apply_ufunc`. In our case, our function expects to receive a 1D vector along `lat` — this is the dimension that is \"core\" to the function's functionality. Multiple core dimensions are possible. `apply_ufunc` needs to know which dimensions of each variable are core dimensions.\n", + "\n", + " input_core_dims : Sequence[Sequence], optional\n", + " List of the same length as ``args`` giving the list of core dimensions\n", + " on each input argument that should not be broadcast. By default, we\n", + " assume there are no core dimensions on any input arguments.\n", + "\n", + " For example, ``input_core_dims=[[], ['time']]`` indicates that all\n", + " dimensions on the first argument and all dimensions other than 'time'\n", + " on the second argument should be broadcast.\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_.\n", + " \n", + " output_core_dims : List[tuple], optional\n", + " List of the same length as the number of output arguments from\n", + " ``func``, giving the list of core dimensions on each output that were\n", + " not broadcast on the inputs. By default, we assume that ``func``\n", + " outputs exactly one array, with axes corresponding to each broadcast\n", + " dimension.\n", + "\n", + " Core dimensions are assumed to appear as the last dimensions of each\n", + " output in the provided order.\n", + " \n", + "Next we specify `\"lat\"` as `input_core_dims` on both `air` and `air.lat`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:05.031672Z", + "start_time": "2020-01-15T14:46:04.947588Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "xarray is telling us that it expected to receive back a numpy array with 0 dimensions but instead received an array with 1 dimension corresponding to `newlat`. We can fix this by specifying `output_core_dims`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:09.325218Z", + "start_time": "2020-01-15T14:46:09.303020Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we get some output! Let's check that this is right\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:11.295440Z", + "start_time": "2020-01-15T14:46:11.226553Z" + } + }, + "outputs": [], + "source": [ + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "No errors are raised so it is right!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorization with `np.vectorize`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now our function currently only works on one vector of data which is not so useful given our 3D dataset.\n", + "Let's try passing the whole dataset. We add a `print` statement so we can see what our function receives." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:13.808646Z", + "start_time": "2020-01-15T14:46:13.680098Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(\n", + " lon=slice(3), time=slice(4)\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's a hard-to-interpret error but our `print` call helpfully printed the shapes of the input data: \n", + "\n", + " data: (10, 53, 25) | x: (25,) | xi: (100,)\n", + "\n", + "We see that `air` has been passed as a 3D numpy array which is not what `np.interp` expects. Instead we want loop over all combinations of `lon` and `time`; and apply our function to each corresponding vector of data along `lat`.\n", + "`apply_ufunc` makes this easy by specifying `vectorize=True`:\n", + "\n", + " vectorize : bool, optional\n", + " If True, then assume ``func`` only takes arrays defined over core\n", + " dimensions as input and vectorize it automatically with\n", + " :py:func:`numpy.vectorize`. This option exists for convenience, but is\n", + " almost always slower than supplying a pre-vectorized function.\n", + " Using this option requires NumPy version 1.12 or newer.\n", + " \n", + "Also see the documentation for `np.vectorize`: https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html. Most importantly\n", + "\n", + " The vectorize function is provided primarily for convenience, not for performance. \n", + " The implementation is essentially a for loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:26.633233Z", + "start_time": "2020-01-15T14:46:26.515209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This unfortunately is another cryptic error from numpy. \n", + "\n", + "Notice that `newlat` is not an xarray object. Let's add a dimension name `new_lat` and modify the call. Note this cannot be `lat` because xarray expects dimensions to be the same size (or broadcastable) among all inputs. `output_core_dims` needs to be modified appropriately. We'll manually rename `new_lat` back to `lat` for easy checking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:30.026663Z", + "start_time": "2020-01-15T14:46:29.893267Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be a set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped = interped.rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims), interped # order of dims is different\n", + ")\n", + "interped" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the printed input shapes are all 1D and correspond to one vector along the `lat` dimension.\n", + "\n", + "The result is now an xarray object with coordinate values copied over from `data`. This is why `apply_ufunc` is so convenient; it takes care of a lot of boilerplate necessary to apply functions that consume and produce numpy arrays to xarray objects.\n", + "\n", + "One final point: `lat` is now the *last* dimension in `interped`. This is a \"property\" of core dimensions: they are moved to the end before being sent to `interp1d_np` as was noted in the docstring for `input_core_dims`\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Parallelization with dask\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So far our function can only handle numpy arrays. A real benefit of `apply_ufunc` is the ability to easily parallelize over dask chunks _when needed_. \n", + "\n", + "We want to apply this function in a vectorized fashion over each chunk of the dask array. This is possible using dask's `blockwise`, `map_blocks`, or `apply_gufunc`. Xarray's `apply_ufunc` wraps dask's `apply_gufunc` and asking it to map the function over chunks using `apply_gufunc` is as simple as specifying `dask=\"parallelized\"`. With this level of flexibility we need to provide dask with some extra information: \n", + " 1. `output_dtypes`: dtypes of all returned objects, and \n", + " 2. `output_sizes`: lengths of any new dimensions. \n", + " \n", + "Here we need to specify `output_dtypes` since `apply_ufunc` can infer the size of the new dimension `new_lat` from the argument corresponding to the third element in `input_core_dims`. Here I choose the chunk sizes to illustrate that `np.vectorize` is still applied so that our function receives 1D vectors even though the blocks are 3D." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:42.469341Z", + "start_time": "2020-01-15T14:48:42.344209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.chunk(\n", + " {\"time\": 2, \"lon\": 2}\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be a set!\n", + " vectorize=True, # loop over non-core dims\n", + " dask=\"parallelized\",\n", + " output_dtypes=[air.dtype], # one per output\n", + ").rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Yay! our function is receiving 1D vectors, so we've successfully parallelized applying a 1D function over a block. If you have a distributed dashboard up, you should see computes happening as equality is checked.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### High performance vectorization: gufuncs, numba & guvectorize\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`np.vectorize` is a very convenient function but is unfortunately slow. It is only marginally faster than writing a for loop in Python and looping. A common way to get around this is to write a base interpolation function that can handle nD arrays in a compiled language like Fortran and then pass that to `apply_ufunc`.\n", + "\n", + "Another option is to use the numba package which provides a very convenient `guvectorize` decorator: https://numba.pydata.org/numba-doc/latest/user/vectorize.html#the-guvectorize-decorator\n", + "\n", + "Any decorated function gets compiled and will loop over any non-core dimension in parallel when necessary. We need to specify some extra information:\n", + "\n", + " 1. Our function cannot return a variable any more. Instead it must receive a variable (the last argument) whose contents the function will modify. So we change from `def interp1d_np(data, x, xi)` to `def interp1d_np_gufunc(data, x, xi, out)`. Our computed results must be assigned to `out`. All values of `out` must be assigned explicitly.\n", + " \n", + " 2. `guvectorize` needs to know the dtypes of the input and output. This is specified in string form as the first argument. Each element of the tuple corresponds to each argument of the function. In this case, we specify `float64` for all inputs and outputs: `\"(float64[:], float64[:], float64[:], float64[:])\"` corresponding to `data, x, xi, out`\n", + " \n", + " 3. Now we need to tell numba the size of the dimensions the function takes as inputs and returns as output i.e. core dimensions. This is done in symbolic form i.e. `data` and `x` are vectors of the same length, say `n`; `xi` and the output `out` have a different length, say `m`. So the second argument is (again as a string)\n", + " `\"(n), (n), (m) -> (m).\"` corresponding again to `data, x, xi, out`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:45.267633Z", + "start_time": "2020-01-15T14:48:44.943939Z" + } + }, + "outputs": [], + "source": [ + "from numba import float64, guvectorize\n", + "\n", + "\n", + "@guvectorize(\"(float64[:], float64[:], float64[:], float64[:])\", \"(n), (n), (m) -> (m)\")\n", + "def interp1d_np_gufunc(data, x, xi, out):\n", + " # numba doesn't really like this.\n", + " # seem to support fstrings so do it the old way\n", + " print(\n", + " \"data: \" + str(data.shape) + \" | x:\" + str(x.shape) + \" | xi: \" + str(xi.shape)\n", + " )\n", + " out[:] = np.interp(xi, x, data)\n", + " # gufuncs don't return data\n", + " # instead you assign to a the last arg\n", + " # return np.interp(xi, x, data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The warnings are about object-mode compilation relating to the `print` statement. This means we don't get much speed up: https://numba.pydata.org/numba-doc/latest/user/performance-tips.html#no-python-mode-vs-object-mode. We'll keep the `print` statement temporarily to make sure that `guvectorize` acts like we want it to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:54.755405Z", + "start_time": "2020-01-15T14:48:54.634724Z" + } + }, + "outputs": [], + "source": [ + "interped = xr.apply_ufunc(\n", + " interp1d_np_gufunc, # first the function\n", + " air.chunk(\n", + " {\"time\": 2, \"lon\": 2}\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be a set!\n", + " # vectorize=True, # not needed since numba takes care of vectorizing\n", + " dask=\"parallelized\",\n", + " output_dtypes=[air.dtype], # one per output\n", + ").rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Yay! Our function is receiving 1D vectors and is working automatically with dask arrays. Finally let's comment out the print line and wrap everything up in a nice reusable function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:49:28.667528Z", + "start_time": "2020-01-15T14:49:28.103914Z" + } + }, + "outputs": [], + "source": [ + "from numba import float64, guvectorize\n", + "\n", + "\n", + "@guvectorize(\n", + " \"(float64[:], float64[:], float64[:], float64[:])\",\n", + " \"(n), (n), (m) -> (m)\",\n", + " nopython=True,\n", + ")\n", + "def interp1d_np_gufunc(data, x, xi, out):\n", + " out[:] = np.interp(xi, x, data)\n", + "\n", + "\n", + "def xr_interp(data, dim, newdim):\n", + " interped = xr.apply_ufunc(\n", + " interp1d_np_gufunc, # first the function\n", + " data, # now arguments in the order expected by 'interp1_np'\n", + " data[dim], # as above\n", + " newdim, # as above\n", + " input_core_dims=[[dim], [dim], [\"__newdim__\"]], # list with one entry per arg\n", + " output_core_dims=[[\"__newdim__\"]], # returned data has one dimension\n", + " exclude_dims=set((dim,)), # dimensions allowed to change size. Must be a set!\n", + " # vectorize=True, # not needed since numba takes care of vectorizing\n", + " dask=\"parallelized\",\n", + " output_dtypes=[\n", + " data.dtype\n", + " ], # one per output; could also be float or np.dtype(\"float64\")\n", + " ).rename({\"__newdim__\": dim})\n", + " interped[dim] = newdim # need to add this manually\n", + "\n", + " return interped\n", + "\n", + "\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims),\n", + " xr_interp(air.chunk({\"time\": 2, \"lon\": 2}), \"lat\", newlat),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This technique is generalizable to any 1D function." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "nbsphinx": { + "allow_errors": true + }, + "org": null, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": false, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/area_weighted_temperature.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/area_weighted_temperature.ipynb new file mode 100644 index 0000000..7299b50 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/area_weighted_temperature.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "toc": true + }, + "source": [ + "

Table of Contents

\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare weighted and unweighted mean temperature\n", + "\n", + "\n", + "Author: [Mathias Hauser](https://github.com/mathause/)\n", + "\n", + "\n", + "We use the `air_temperature` example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the grid cell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:57.222351Z", + "start_time": "2020-03-17T14:43:56.147541Z" + } + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import xarray as xr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data\n", + "\n", + "Load the data, convert to celsius, and resample to daily values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:57.831734Z", + "start_time": "2020-03-17T14:43:57.651845Z" + } + }, + "outputs": [], + "source": [ + "ds = xr.tutorial.load_dataset(\"air_temperature\")\n", + "\n", + "# to celsius\n", + "air = ds.air - 273.15\n", + "\n", + "# resample from 6-hourly to daily values\n", + "air = air.resample(time=\"D\").mean()\n", + "\n", + "air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot the first timestep:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:43:59.887120Z", + "start_time": "2020-03-17T14:43:59.582894Z" + } + }, + "outputs": [], + "source": [ + "projection = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n", + "\n", + "f, ax = plt.subplots(subplot_kw=dict(projection=projection))\n", + "\n", + "air.isel(time=0).plot(transform=ccrs.PlateCarree(), cbar_kwargs=dict(shrink=0.7))\n", + "ax.coastlines()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating weights\n", + "\n", + "For a rectangular grid the cosine of the latitude is proportional to the grid cell area." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:18.777092Z", + "start_time": "2020-03-17T14:44:18.736587Z" + } + }, + "outputs": [], + "source": [ + "weights = np.cos(np.deg2rad(air.lat))\n", + "weights.name = \"weights\"\n", + "weights" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Weighted mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:52.607120Z", + "start_time": "2020-03-17T14:44:52.564674Z" + } + }, + "outputs": [], + "source": [ + "air_weighted = air.weighted(weights)\n", + "air_weighted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:44:54.334279Z", + "start_time": "2020-03-17T14:44:54.280022Z" + } + }, + "outputs": [], + "source": [ + "weighted_mean = air_weighted.mean((\"lon\", \"lat\"))\n", + "weighted_mean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot: comparison with unweighted mean\n", + "\n", + "Note how the weighted mean temperature is higher than the unweighted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-03-17T14:45:08.877307Z", + "start_time": "2020-03-17T14:45:08.673383Z" + } + }, + "outputs": [], + "source": [ + "weighted_mean.plot(label=\"weighted\")\n", + "air.mean((\"lon\", \"lat\")).plot(label=\"unweighted\")\n", + "\n", + "plt.legend()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/blank_template.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/blank_template.ipynb new file mode 100644 index 0000000..bcb15c1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/blank_template.ipynb @@ -0,0 +1,58 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d8f54f6a", + "metadata": {}, + "source": [ + "# Blank template\n", + "\n", + "Use this notebook from Binder to test an issue or reproduce a bug report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41b90ede", + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "ds = xr.tutorial.load_dataset(\"air_temperature\")\n", + "da = ds[\"air\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "effd9aeb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/monthly-means.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/monthly-means.ipynb new file mode 100644 index 0000000..fd31e21 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/monthly-means.ipynb @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calculating Seasonal Averages from Time Series of Monthly Means \n", + "=====\n", + "\n", + "Author: [Joe Hamman](https://github.com/jhamman/)\n", + "\n", + "The data used for this example can be found in the [xarray-data](https://github.com/pydata/xarray-data) repository. You may need to change the path to `rasm.nc` below.\n", + "\n", + "Suppose we have a netCDF or `xarray.Dataset` of monthly mean data and we want to calculate the seasonal average. To do this properly, we need to calculate the weighted average considering that each month has a different number of days." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:35.958210Z", + "start_time": "2018-11-28T20:51:35.936966Z" + } + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Open the `Dataset`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:36.072316Z", + "start_time": "2018-11-28T20:51:36.016594Z" + } + }, + "outputs": [], + "source": [ + "ds = xr.tutorial.open_dataset(\"rasm\").load()\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Now for the heavy lifting:\n", + "We first have to come up with the weights,\n", + "- calculate the month length for each monthly data record\n", + "- calculate weights using `groupby('time.season')`\n", + "\n", + "Finally, we just need to multiply our weights by the `Dataset` and sum along the time dimension. Creating a `DataArray` for the month length is as easy as using the `days_in_month` accessor on the time coordinate. The calendar type, in this case `'noleap'`, is automatically considered in this operation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "month_length = ds.time.dt.days_in_month\n", + "month_length" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:36.132413Z", + "start_time": "2018-11-28T20:51:36.073708Z" + } + }, + "outputs": [], + "source": [ + "# Calculate the weights by grouping by 'time.season'.\n", + "weights = (\n", + " month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n", + ")\n", + "\n", + "# Test that the sum of the weights for each season is 1.0\n", + "np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n", + "\n", + "# Calculate the weighted average\n", + "ds_weighted = (ds * weights).groupby(\"time.season\").sum(dim=\"time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:36.152913Z", + "start_time": "2018-11-28T20:51:36.133997Z" + } + }, + "outputs": [], + "source": [ + "ds_weighted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:36.190765Z", + "start_time": "2018-11-28T20:51:36.154416Z" + } + }, + "outputs": [], + "source": [ + "# only used for comparisons\n", + "ds_unweighted = ds.groupby(\"time.season\").mean(\"time\")\n", + "ds_diff = ds_weighted - ds_unweighted" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:40.264871Z", + "start_time": "2018-11-28T20:51:36.192467Z" + } + }, + "outputs": [], + "source": [ + "# Quick plot to show the results\n", + "notnull = pd.notnull(ds_unweighted[\"Tair\"][0])\n", + "\n", + "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 12))\n", + "for i, season in enumerate((\"DJF\", \"MAM\", \"JJA\", \"SON\")):\n", + " ds_weighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", + " ax=axes[i, 0],\n", + " vmin=-30,\n", + " vmax=30,\n", + " cmap=\"Spectral_r\",\n", + " add_colorbar=True,\n", + " extend=\"both\",\n", + " )\n", + "\n", + " ds_unweighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", + " ax=axes[i, 1],\n", + " vmin=-30,\n", + " vmax=30,\n", + " cmap=\"Spectral_r\",\n", + " add_colorbar=True,\n", + " extend=\"both\",\n", + " )\n", + "\n", + " ds_diff[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", + " ax=axes[i, 2],\n", + " vmin=-0.1,\n", + " vmax=0.1,\n", + " cmap=\"RdBu_r\",\n", + " add_colorbar=True,\n", + " extend=\"both\",\n", + " )\n", + "\n", + " axes[i, 0].set_ylabel(season)\n", + " axes[i, 1].set_ylabel(\"\")\n", + " axes[i, 2].set_ylabel(\"\")\n", + "\n", + "for ax in axes.flat:\n", + " ax.axes.get_xaxis().set_ticklabels([])\n", + " ax.axes.get_yaxis().set_ticklabels([])\n", + " ax.axes.axis(\"tight\")\n", + " ax.set_xlabel(\"\")\n", + "\n", + "axes[0, 0].set_title(\"Weighted by DPM\")\n", + "axes[0, 1].set_title(\"Equal Weighting\")\n", + "axes[0, 2].set_title(\"Difference\")\n", + "\n", + "plt.tight_layout()\n", + "\n", + "fig.suptitle(\"Seasonal Surface Air Temperature\", fontsize=16, y=1.02)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:51:40.284898Z", + "start_time": "2018-11-28T20:51:40.266406Z" + } + }, + "outputs": [], + "source": [ + "# Wrap it into a simple function\n", + "def season_mean(ds, calendar=\"standard\"):\n", + " # Make a DataArray with the number of days in each month, size = len(time)\n", + " month_length = ds.time.dt.days_in_month\n", + "\n", + " # Calculate the weights by grouping by 'time.season'\n", + " weights = (\n", + " month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n", + " )\n", + "\n", + " # Test that the sum of the weights for each season is 1.0\n", + " np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n", + "\n", + " # Calculate the weighted average\n", + " return (ds * weights).groupby(\"time.season\").sum(dim=\"time\")" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/monthly_means_output.png b/test/fixtures/whole_applications/xarray/doc/examples/monthly_means_output.png new file mode 100644 index 0000000..0f391a5 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/doc/examples/monthly_means_output.png differ diff --git a/test/fixtures/whole_applications/xarray/doc/examples/multidimensional-coords.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/multidimensional-coords.ipynb new file mode 100644 index 0000000..a138dff --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/multidimensional-coords.ipynb @@ -0,0 +1,230 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Working with Multidimensional Coordinates\n", + "\n", + "Author: [Ryan Abernathey](https://github.com/rabernat)\n", + "\n", + "Many datasets have _physical coordinates_ which differ from their _logical coordinates_. Xarray provides several ways to plot and analyze such datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:49:56.068395Z", + "start_time": "2018-11-28T20:49:56.035349Z" + } + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "import cartopy.crs as ccrs\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As an example, consider this dataset from the [xarray-data](https://github.com/pydata/xarray-data) repository." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:50:13.629720Z", + "start_time": "2018-11-28T20:50:13.484542Z" + } + }, + "outputs": [], + "source": [ + "ds = xr.tutorial.open_dataset(\"rasm\").load()\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, the _logical coordinates_ are `x` and `y`, while the _physical coordinates_ are `xc` and `yc`, which represent the longitudes and latitudes of the data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:50:15.836061Z", + "start_time": "2018-11-28T20:50:15.768376Z" + } + }, + "outputs": [], + "source": [ + "print(ds.xc.attrs)\n", + "print(ds.yc.attrs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plotting ##\n", + "\n", + "Let's examine these coordinate variables by plotting them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:50:17.928556Z", + "start_time": "2018-11-28T20:50:17.031211Z" + } + }, + "outputs": [], + "source": [ + "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(14, 4))\n", + "ds.xc.plot(ax=ax1)\n", + "ds.yc.plot(ax=ax2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the variables `xc` (longitude) and `yc` (latitude) are two-dimensional scalar fields.\n", + "\n", + "If we try to plot the data variable `Tair`, by default we get the logical coordinates." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:50:20.567749Z", + "start_time": "2018-11-28T20:50:19.999393Z" + } + }, + "outputs": [], + "source": [ + "ds.Tair[0].plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to visualize the data on a conventional latitude-longitude grid, we can take advantage of xarray's ability to apply [cartopy](http://scitools.org.uk/cartopy/index.html) map projections." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:50:31.131708Z", + "start_time": "2018-11-28T20:50:30.444697Z" + } + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(14, 6))\n", + "ax = plt.axes(projection=ccrs.PlateCarree())\n", + "ax.set_global()\n", + "ds.Tair[0].plot.pcolormesh(\n", + " ax=ax, transform=ccrs.PlateCarree(), x=\"xc\", y=\"yc\", add_colorbar=False\n", + ")\n", + "ax.coastlines()\n", + "ax.set_ylim([0, 90]);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multidimensional Groupby ##\n", + "\n", + "The above example allowed us to visualize the data on a regular latitude-longitude grid. But what if we want to do a calculation that involves grouping over one of these physical coordinates (rather than the logical coordinates), for example, calculating the mean temperature at each latitude. This can be achieved using xarray's `groupby` function, which accepts multidimensional variables. By default, `groupby` will use every unique value in the variable, which is probably not what we want. Instead, we can use the `groupby_bins` function to specify the output coordinates of the group. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-28T20:50:43.670463Z", + "start_time": "2018-11-28T20:50:43.245501Z" + } + }, + "outputs": [], + "source": [ + "# define two-degree wide latitude bins\n", + "lat_bins = np.arange(0, 91, 2)\n", + "# define a label for each bin corresponding to the central latitude\n", + "lat_center = np.arange(1, 90, 2)\n", + "# group according to those bins and take the mean\n", + "Tair_lat_mean = ds.Tair.groupby_bins(\"yc\", lat_bins, labels=lat_center).mean(\n", + " dim=xr.ALL_DIMS\n", + ")\n", + "# plot the result\n", + "Tair_lat_mean.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resulting coordinate for the `groupby_bins` operation got the `_bins` suffix appended: `yc_bins`. This help us distinguish it from the original multidimensional variable `yc`.\n", + "\n", + "**Note**: This group-by-latitude approach does not take into account the finite-size geometry of grid cells. It simply bins each value according to the coordinates at the cell center. Xarray has no understanding of grid cells and their geometry. More precise geographic regridding for xarray data is available via the [xesmf](https://xesmf.readthedocs.io) package." + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/visualization_gallery.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/visualization_gallery.ipynb new file mode 100644 index 0000000..e7e9196 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/visualization_gallery.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualization Gallery\n", + "\n", + "This notebook shows common visualization issues encountered in xarray." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import xarray as xr\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load example dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = xr.tutorial.load_dataset(\"air_temperature\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multiple plots and map projections\n", + "\n", + "Control the map projection parameters on multiple axes\n", + "\n", + "This example illustrates how to plot multiple maps and control their extent\n", + "and aspect ratio.\n", + "\n", + "For more details see [this discussion](https://github.com/pydata/xarray/issues/1397#issuecomment-299190567) on github." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "air = ds.air.isel(time=[0, 724]) - 273.15\n", + "\n", + "# This is the map projection we want to plot *onto*\n", + "map_proj = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n", + "\n", + "p = air.plot(\n", + " transform=ccrs.PlateCarree(), # the data's projection\n", + " col=\"time\",\n", + " col_wrap=1, # multiplot settings\n", + " aspect=ds.dims[\"lon\"] / ds.dims[\"lat\"], # for a sensible figsize\n", + " subplot_kws={\"projection\": map_proj},\n", + ") # the plot's projection\n", + "\n", + "# We have to set the map's options on all axes\n", + "for ax in p.axes.flat:\n", + " ax.coastlines()\n", + " ax.set_extent([-160, -30, 5, 75])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Centered colormaps\n", + "\n", + "Xarray's automatic colormaps choice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "air = ds.air.isel(time=0)\n", + "\n", + "f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))\n", + "\n", + "# The first plot (in kelvins) chooses \"viridis\" and uses the data's min/max\n", + "air.plot(ax=ax1, cbar_kwargs={\"label\": \"K\"})\n", + "ax1.set_title(\"Kelvins: default\")\n", + "ax2.set_xlabel(\"\")\n", + "\n", + "# The second plot (in celsius) now chooses \"BuRd\" and centers min/max around 0\n", + "airc = air - 273.15\n", + "airc.plot(ax=ax2, cbar_kwargs={\"label\": \"°C\"})\n", + "ax2.set_title(\"Celsius: default\")\n", + "ax2.set_xlabel(\"\")\n", + "ax2.set_ylabel(\"\")\n", + "\n", + "# The center doesn't have to be 0\n", + "air.plot(ax=ax3, center=273.15, cbar_kwargs={\"label\": \"K\"})\n", + "ax3.set_title(\"Kelvins: center=273.15\")\n", + "\n", + "# Or it can be ignored\n", + "airc.plot(ax=ax4, center=False, cbar_kwargs={\"label\": \"°C\"})\n", + "ax4.set_title(\"Celsius: center=False\")\n", + "ax4.set_ylabel(\"\")\n", + "\n", + "# Make it nice\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Control the plot's colorbar\n", + "\n", + "Use ``cbar_kwargs`` keyword to specify the number of ticks.\n", + "The ``spacing`` kwarg can be used to draw proportional ticks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "air2d = ds.air.isel(time=500)\n", + "\n", + "# Prepare the figure\n", + "f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4))\n", + "\n", + "# Irregular levels to illustrate the use of a proportional colorbar\n", + "levels = [245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 310, 340]\n", + "\n", + "# Plot data\n", + "air2d.plot(ax=ax1, levels=levels)\n", + "air2d.plot(ax=ax2, levels=levels, cbar_kwargs={\"ticks\": levels})\n", + "air2d.plot(\n", + " ax=ax3, levels=levels, cbar_kwargs={\"ticks\": levels, \"spacing\": \"proportional\"}\n", + ")\n", + "\n", + "# Show plots\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multiple lines from a 2d DataArray\n", + "\n", + "Use ``xarray.plot.line`` on a 2d DataArray to plot selections as\n", + "multiple lines.\n", + "\n", + "See ``plotting.multiplelines`` for more details." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "air = ds.air - 273.15 # to celsius\n", + "\n", + "# Prepare the figure\n", + "f, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)\n", + "\n", + "# Selected latitude indices\n", + "isel_lats = [10, 15, 20]\n", + "\n", + "# Temperature vs longitude plot - illustrates the \"hue\" kwarg\n", + "air.isel(time=0, lat=isel_lats).plot.line(ax=ax1, hue=\"lat\")\n", + "ax1.set_ylabel(\"°C\")\n", + "\n", + "# Temperature vs time plot - illustrates the \"x\" and \"add_legend\" kwargs\n", + "air.isel(lon=30, lat=isel_lats).plot.line(ax=ax2, x=\"time\", add_legend=False)\n", + "ax2.set_ylabel(\"\")\n", + "\n", + "# Show\n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/test/fixtures/whole_applications/xarray/doc/examples/weather-data.ipynb b/test/fixtures/whole_applications/xarray/doc/examples/weather-data.ipynb new file mode 100644 index 0000000..f582453 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/examples/weather-data.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Toy weather data\n", + "\n", + "Here is an example of how to easily manipulate a toy weather dataset using\n", + "xarray and other recommended Python libraries:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:43:36.127628Z", + "start_time": "2020-01-27T15:43:36.081733Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "\n", + "import xarray as xr\n", + "\n", + "np.random.seed(123)\n", + "\n", + "xr.set_options(display_style=\"html\")\n", + "\n", + "times = pd.date_range(\"2000-01-01\", \"2001-12-31\", name=\"time\")\n", + "annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))\n", + "\n", + "base = 10 + 15 * annual_cycle.reshape(-1, 1)\n", + "tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3)\n", + "tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3)\n", + "\n", + "ds = xr.Dataset(\n", + " {\n", + " \"tmin\": ((\"time\", \"location\"), tmin_values),\n", + " \"tmax\": ((\"time\", \"location\"), tmax_values),\n", + " },\n", + " {\"time\": times, \"location\": [\"IA\", \"IN\", \"IL\"]},\n", + ")\n", + "\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examine a dataset with pandas and seaborn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Convert to a pandas DataFrame" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:14.160297Z", + "start_time": "2020-01-27T15:47:14.126738Z" + } + }, + "outputs": [], + "source": [ + "df = ds.to_dataframe()\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:32.682065Z", + "start_time": "2020-01-27T15:47:32.652629Z" + } + }, + "outputs": [], + "source": [ + "df.describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize using pandas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:34.617042Z", + "start_time": "2020-01-27T15:47:34.282605Z" + } + }, + "outputs": [], + "source": [ + "ds.mean(dim=\"location\").to_dataframe().plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize using seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:37.643175Z", + "start_time": "2020-01-27T15:47:37.202479Z" + } + }, + "outputs": [], + "source": [ + "sns.pairplot(df.reset_index(), vars=ds.data_vars)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Probability of freeze by calendar month" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:48:11.241224Z", + "start_time": "2020-01-27T15:48:11.211156Z" + } + }, + "outputs": [], + "source": [ + "freeze = (ds[\"tmin\"] <= 0).groupby(\"time.month\").mean(\"time\")\n", + "freeze" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:48:13.131247Z", + "start_time": "2020-01-27T15:48:12.924985Z" + } + }, + "outputs": [], + "source": [ + "freeze.to_pandas().plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Monthly averaging" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:48:08.498259Z", + "start_time": "2020-01-27T15:48:08.210890Z" + } + }, + "outputs": [], + "source": [ + "monthly_avg = ds.resample(time=\"1MS\").mean()\n", + "monthly_avg.sel(location=\"IA\").to_dataframe().plot(style=\"s-\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that ``MS`` here refers to Month-Start; ``M`` labels Month-End (the last day of the month)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calculate monthly anomalies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In climatology, \"anomalies\" refer to the difference between observations and\n", + "typical weather for a particular season. Unlike observations, anomalies should\n", + "not show any seasonal cycle." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:49:34.855086Z", + "start_time": "2020-01-27T15:49:34.406439Z" + } + }, + "outputs": [], + "source": [ + "climatology = ds.groupby(\"time.month\").mean(\"time\")\n", + "anomalies = ds.groupby(\"time.month\") - climatology\n", + "anomalies.mean(\"location\").to_dataframe()[[\"tmin\", \"tmax\"]].plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calculate standardized monthly anomalies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can create standardized anomalies where the difference between the\n", + "observations and the climatological monthly mean is\n", + "divided by the climatological standard deviation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:50:09.144586Z", + "start_time": "2020-01-27T15:50:08.734682Z" + } + }, + "outputs": [], + "source": [ + "climatology_mean = ds.groupby(\"time.month\").mean(\"time\")\n", + "climatology_std = ds.groupby(\"time.month\").std(\"time\")\n", + "stand_anomalies = xr.apply_ufunc(\n", + " lambda x, m, s: (x - m) / s,\n", + " ds.groupby(\"time.month\"),\n", + " climatology_mean,\n", + " climatology_std,\n", + ")\n", + "\n", + "stand_anomalies.mean(\"location\").to_dataframe()[[\"tmin\", \"tmax\"]].plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fill missing values with climatology" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:50:46.192491Z", + "start_time": "2020-01-27T15:50:46.174554Z" + } + }, + "source": [ + "The ``fillna`` method on grouped objects lets you easily fill missing values by group:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:51:40.279299Z", + "start_time": "2020-01-27T15:51:40.220342Z" + } + }, + "outputs": [], + "source": [ + "# throw away the first half of every month\n", + "some_missing = ds.tmin.sel(time=ds[\"time.day\"] > 15).reindex_like(ds)\n", + "filled = some_missing.groupby(\"time.month\").fillna(climatology.tmin)\n", + "both = xr.Dataset({\"some_missing\": some_missing, \"filled\": filled})\n", + "both" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:52:11.815769Z", + "start_time": "2020-01-27T15:52:11.770825Z" + } + }, + "outputs": [], + "source": [ + "df = both.sel(time=\"2000\").mean(\"location\").reset_coords(drop=True).to_dataframe()\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:52:14.867866Z", + "start_time": "2020-01-27T15:52:14.449684Z" + } + }, + "outputs": [], + "source": [ + "df[[\"filled\", \"some_missing\"]].plot()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/fixtures/whole_applications/xarray/doc/gallery.rst b/test/fixtures/whole_applications/xarray/doc/gallery.rst new file mode 100644 index 0000000..61ec45c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/gallery.rst @@ -0,0 +1,35 @@ +Gallery +======= + +Here's a list of examples on how to use xarray. We will be adding more examples soon. +Contributions are highly welcomed and appreciated. So, if you are interested in contributing, please consult the +:doc:`contributing` guide. + + + +Notebook Examples +----------------- + +.. include:: notebooks-examples-gallery.txt + + +.. toctree:: + :maxdepth: 1 + :hidden: + + examples/weather-data + examples/monthly-means + examples/area_weighted_temperature + examples/multidimensional-coords + examples/visualization_gallery + examples/ROMS_ocean_model + examples/ERA5-GRIB-example + examples/apply_ufunc_vectorize_1d + examples/blank_template + + +External Examples +----------------- + + +.. include:: external-examples-gallery.txt diff --git a/test/fixtures/whole_applications/xarray/doc/gallery.yml b/test/fixtures/whole_applications/xarray/doc/gallery.yml new file mode 100644 index 0000000..f831601 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/gallery.yml @@ -0,0 +1,45 @@ +notebooks-examples: + - title: Toy weather data + path: examples/weather-data.html + thumbnail: _static/thumbnails/toy-weather-data.png + + - title: Calculating Seasonal Averages from Timeseries of Monthly Means + path: examples/monthly-means.html + thumbnail: _static/thumbnails/monthly-means.png + + - title: Compare weighted and unweighted mean temperature + path: examples/area_weighted_temperature.html + thumbnail: _static/thumbnails/area_weighted_temperature.png + + - title: Working with Multidimensional Coordinates + path: examples/multidimensional-coords.html + thumbnail: _static/thumbnails/multidimensional-coords.png + + - title: Visualization Gallery + path: examples/visualization_gallery.html + thumbnail: _static/thumbnails/visualization_gallery.png + + - title: GRIB Data Example + path: examples/ERA5-GRIB-example.html + thumbnail: _static/thumbnails/ERA5-GRIB-example.png + + - title: Applying unvectorized functions with apply_ufunc + path: examples/apply_ufunc_vectorize_1d.html + thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg + +external-examples: + - title: Managing raster data with rioxarray + path: https://corteva.github.io/rioxarray/stable/examples/examples.html + thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg + + - title: Xarray and dask on the cloud with Pangeo + path: https://gallery.pangeo.io/ + thumbnail: https://avatars.githubusercontent.com/u/60833341?s=200&v=4 + + - title: Xarray with Dask Arrays + path: https://examples.dask.org/xarray.html_ + thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg + + - title: Project Pythia Foundations Book + path: https://foundations.projectpythia.org/core/xarray.html + thumbnail: https://raw.githubusercontent.com/ProjectPythia/projectpythia.github.io/main/portal/_static/images/logos/pythia_logo-blue-btext-twocolor.svg diff --git a/test/fixtures/whole_applications/xarray/doc/gallery/README.txt b/test/fixtures/whole_applications/xarray/doc/gallery/README.txt new file mode 100644 index 0000000..63f7d47 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/gallery/README.txt @@ -0,0 +1,4 @@ +.. _recipes: + +Gallery +======= diff --git a/test/fixtures/whole_applications/xarray/doc/gallery/plot_cartopy_facetgrid.py b/test/fixtures/whole_applications/xarray/doc/gallery/plot_cartopy_facetgrid.py new file mode 100644 index 0000000..faa1489 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/gallery/plot_cartopy_facetgrid.py @@ -0,0 +1,44 @@ +""" +================================== +Multiple plots and map projections +================================== + +Control the map projection parameters on multiple axes + +This example illustrates how to plot multiple maps and control their extent +and aspect ratio. + +For more details see `this discussion`_ on github. + +.. _this discussion: https://github.com/pydata/xarray/issues/1397#issuecomment-299190567 +""" + +import cartopy.crs as ccrs +import matplotlib.pyplot as plt + +import xarray as xr + +# Load the data +ds = xr.tutorial.load_dataset("air_temperature") +air = ds.air.isel(time=[0, 724]) - 273.15 + +# This is the map projection we want to plot *onto* +map_proj = ccrs.LambertConformal(central_longitude=-95, central_latitude=45) + +p = air.plot( + transform=ccrs.PlateCarree(), # the data's projection + col="time", + col_wrap=1, # multiplot settings + aspect=ds.sizes["lon"] / ds.sizes["lat"], # for a sensible figsize + subplot_kws={"projection": map_proj}, # the plot's projection +) + +# We have to set the map's options on all four axes +for ax in p.axes.flat: + ax.coastlines() + ax.set_extent([-160, -30, 5, 75]) + # Without this aspect attributes the maps will look chaotic and the + # "extent" attribute above will be ignored + ax.set_aspect("equal") + +plt.show() diff --git a/test/fixtures/whole_applications/xarray/doc/gallery/plot_colorbar_center.py b/test/fixtures/whole_applications/xarray/doc/gallery/plot_colorbar_center.py new file mode 100644 index 0000000..da3447a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/gallery/plot_colorbar_center.py @@ -0,0 +1,43 @@ +""" +================== +Centered colormaps +================== + +xarray's automatic colormaps choice + +""" + +import matplotlib.pyplot as plt + +import xarray as xr + +# Load the data +ds = xr.tutorial.load_dataset("air_temperature") +air = ds.air.isel(time=0) + +f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6)) + +# The first plot (in kelvins) chooses "viridis" and uses the data's min/max +air.plot(ax=ax1, cbar_kwargs={"label": "K"}) +ax1.set_title("Kelvins: default") +ax2.set_xlabel("") + +# The second plot (in celsius) now chooses "BuRd" and centers min/max around 0 +airc = air - 273.15 +airc.plot(ax=ax2, cbar_kwargs={"label": "°C"}) +ax2.set_title("Celsius: default") +ax2.set_xlabel("") +ax2.set_ylabel("") + +# The center doesn't have to be 0 +air.plot(ax=ax3, center=273.15, cbar_kwargs={"label": "K"}) +ax3.set_title("Kelvins: center=273.15") + +# Or it can be ignored +airc.plot(ax=ax4, center=False, cbar_kwargs={"label": "°C"}) +ax4.set_title("Celsius: center=False") +ax4.set_ylabel("") + +# Make it nice +plt.tight_layout() +plt.show() diff --git a/test/fixtures/whole_applications/xarray/doc/gallery/plot_control_colorbar.py b/test/fixtures/whole_applications/xarray/doc/gallery/plot_control_colorbar.py new file mode 100644 index 0000000..280e753 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/gallery/plot_control_colorbar.py @@ -0,0 +1,33 @@ +""" +=========================== +Control the plot's colorbar +=========================== + +Use ``cbar_kwargs`` keyword to specify the number of ticks. +The ``spacing`` kwarg can be used to draw proportional ticks. +""" + +import matplotlib.pyplot as plt + +import xarray as xr + +# Load the data +air_temp = xr.tutorial.load_dataset("air_temperature") +air2d = air_temp.air.isel(time=500) + +# Prepare the figure +f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4)) + +# Irregular levels to illustrate the use of a proportional colorbar +levels = [245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 310, 340] + +# Plot data +air2d.plot(ax=ax1, levels=levels) +air2d.plot(ax=ax2, levels=levels, cbar_kwargs={"ticks": levels}) +air2d.plot( + ax=ax3, levels=levels, cbar_kwargs={"ticks": levels, "spacing": "proportional"} +) + +# Show plots +plt.tight_layout() +plt.show() diff --git a/test/fixtures/whole_applications/xarray/doc/getting-started-guide/faq.rst b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/faq.rst new file mode 100644 index 0000000..7f99fa7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/faq.rst @@ -0,0 +1,433 @@ +.. _faq: + +Frequently Asked Questions +========================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + + +Your documentation keeps mentioning pandas. What is pandas? +----------------------------------------------------------- + +pandas_ is a very popular data analysis package in Python +with wide usage in many fields. Our API is heavily inspired by pandas — +this is why there are so many references to pandas. + +.. _pandas: https://pandas.pydata.org + + +Do I need to know pandas to use xarray? +--------------------------------------- + +No! Our API is heavily inspired by pandas so while knowing pandas will let you +become productive more quickly, knowledge of pandas is not necessary to use xarray. + + +Should I use xarray instead of pandas? +-------------------------------------- + +It's not an either/or choice! xarray provides robust support for converting +back and forth between the tabular data-structures of pandas and its own +multi-dimensional data-structures. + +That said, you should only bother with xarray if some aspect of data is +fundamentally multi-dimensional. If your data is unstructured or +one-dimensional, pandas is usually the right choice: it has better performance +for common operations such as ``groupby`` and you'll find far more usage +examples online. + + +Why is pandas not enough? +------------------------- + +pandas is a fantastic library for analysis of low-dimensional labelled data - +if it can be sensibly described as "rows and columns", pandas is probably the +right choice. However, sometimes we want to use higher dimensional arrays +(`ndim > 2`), or arrays for which the order of dimensions (e.g., columns vs +rows) shouldn't really matter. For example, the images of a movie can be +natively represented as an array with four dimensions: time, row, column and +color. + +pandas has historically supported N-dimensional panels, but deprecated them in +version 0.20 in favor of xarray data structures. There are now built-in methods +on both sides to convert between pandas and xarray, allowing for more focused +development effort. Xarray objects have a much richer model of dimensionality - +if you were using Panels: + +- You need to create a new factory type for each dimensionality. +- You can't do math between NDPanels with different dimensionality. +- Each dimension in a NDPanel has a name (e.g., 'labels', 'items', + 'major_axis', etc.) but the dimension names refer to order, not their + meaning. You can't specify an operation as to be applied along the "time" + axis. +- You often have to manually convert collections of pandas arrays + (Series, DataFrames, etc) to have the same number of dimensions. + In contrast, this sort of data structure fits very naturally in an + xarray ``Dataset``. + +You can :ref:`read about switching from Panels to xarray here `. +pandas gets a lot of things right, but many science, engineering and complex +analytics use cases need fully multi-dimensional data structures. + +How do xarray data structures differ from those found in pandas? +---------------------------------------------------------------- + +The main distinguishing feature of xarray's ``DataArray`` over labeled arrays in +pandas is that dimensions can have names (e.g., "time", "latitude", +"longitude"). Names are much easier to keep track of than axis numbers, and +xarray uses dimension names for indexing, aggregation and broadcasting. Not only +can you write ``x.sel(time='2000-01-01')`` and ``x.mean(dim='time')``, but +operations like ``x - x.mean(dim='time')`` always work, no matter the order +of the "time" dimension. You never need to reshape arrays (e.g., with +``np.newaxis``) to align them for arithmetic operations in xarray. + + +Why don't aggregations return Python scalars? +--------------------------------------------- + +Xarray tries hard to be self-consistent: operations on a ``DataArray`` (resp. +``Dataset``) return another ``DataArray`` (resp. ``Dataset``) object. In +particular, operations returning scalar values (e.g. indexing or aggregations +like ``mean`` or ``sum`` applied to all axes) will also return xarray objects. + +Unfortunately, this means we sometimes have to explicitly cast our results from +xarray when using them in other libraries. As an illustration, the following +code fragment + +.. ipython:: python + + arr = xr.DataArray([1, 2, 3]) + pd.Series({"x": arr[0], "mean": arr.mean(), "std": arr.std()}) + +does not yield the pandas DataFrame we expected. We need to specify the type +conversion ourselves: + +.. ipython:: python + + pd.Series({"x": arr[0], "mean": arr.mean(), "std": arr.std()}, dtype=float) + +Alternatively, we could use the ``item`` method or the ``float`` constructor to +convert values one at a time + +.. ipython:: python + + pd.Series({"x": arr[0].item(), "mean": float(arr.mean())}) + + +.. _approach to metadata: + +What is your approach to metadata? +---------------------------------- + +We are firm believers in the power of labeled data! In addition to dimensions +and coordinates, xarray supports arbitrary metadata in the form of global +(Dataset) and variable specific (DataArray) attributes (``attrs``). + +Automatic interpretation of labels is powerful but also reduces flexibility. +With xarray, we draw a firm line between labels that the library understands +(``dims`` and ``coords``) and labels for users and user code (``attrs``). For +example, we do not automatically interpret and enforce units or `CF +conventions`_. (An exception is serialization to and from netCDF files.) + +.. _CF conventions: https://cfconventions.org/latest.html + +An implication of this choice is that we do not propagate ``attrs`` through +most operations unless explicitly flagged (some methods have a ``keep_attrs`` +option, and there is a global flag, accessible with :py:func:`xarray.set_options`, +for setting this to be always True or False). Similarly, xarray does not check +for conflicts between ``attrs`` when combining arrays and datasets, unless +explicitly requested with the option ``compat='identical'``. The guiding +principle is that metadata should not be allowed to get in the way. + +What other netCDF related Python libraries should I know about? +--------------------------------------------------------------- + +`netCDF4-python`__ provides a lower level interface for working with +netCDF and OpenDAP datasets in Python. We use netCDF4-python internally in +xarray, and have contributed a number of improvements and fixes upstream. Xarray +does not yet support all of netCDF4-python's features, such as modifying files +on-disk. + +__ https://unidata.github.io/netcdf4-python/ + +Iris_ (supported by the UK Met office) provides similar tools for in- +memory manipulation of labeled arrays, aimed specifically at weather and +climate data needs. Indeed, the Iris :py:class:`~iris.cube.Cube` was direct +inspiration for xarray's :py:class:`~xarray.DataArray`. Xarray and Iris take very +different approaches to handling metadata: Iris strictly interprets +`CF conventions`_. Iris particularly shines at mapping, thanks to its +integration with Cartopy_. + +.. _Iris: https://scitools-iris.readthedocs.io/en/stable/ +.. _Cartopy: https://scitools.org.uk/cartopy/docs/latest/ + +We think the design decisions we have made for xarray (namely, basing it on +pandas) make it a faster and more flexible data analysis tool. That said, Iris +has some great domain specific functionality, and xarray includes +methods for converting back and forth between xarray and Iris. See +:py:meth:`~xarray.DataArray.to_iris` for more details. + +What other projects leverage xarray? +------------------------------------ + +See section :ref:`ecosystem`. + +How do I open format X file as an xarray dataset? +------------------------------------------------- + +To open format X file in xarray, you need to know the `format of the data `_ you want to read. If the format is supported, you can use the appropriate function provided by xarray. The following table provides functions used for different file formats in xarray, as well as links to other packages that can be used: + +.. csv-table:: + :header: "File Format", "Open via", " Related Packages" + :widths: 15, 45, 15 + + "NetCDF (.nc, .nc4, .cdf)","``open_dataset()`` OR ``open_mfdataset()``", "`netCDF4 `_, `netcdf `_ , `cdms2 `_" + "HDF5 (.h5, .hdf5)","``open_dataset()`` OR ``open_mfdataset()``", "`h5py `_, `pytables `_ " + "GRIB (.grb, .grib)", "``open_dataset()``", "`cfgrib `_, `pygrib `_" + "CSV (.csv)","``open_dataset()``", "`pandas`_ , `dask `_" + "Zarr (.zarr)","``open_dataset()`` OR ``open_mfdataset()``", "`zarr `_ , `dask `_ " + +.. _pandas: https://pandas.pydata.org + +If you are unable to open a file in xarray: + +- You should check that you are having all necessary dependencies installed, including any optional dependencies (like scipy, h5netcdf, cfgrib etc as mentioned below) that may be required for the specific use case. + +- If all necessary dependencies are installed but the file still cannot be opened, you must check if there are any specialized backends available for the specific file format you are working with. You can consult the xarray documentation or the documentation for the file format to determine if a specialized backend is required, and if so, how to install and use it with xarray. + +- If the file format is not supported by xarray or any of its available backends, the user may need to use a different library or tool to work with the file. You can consult the documentation for the file format to determine which tools are recommended for working with it. + +Xarray provides a default engine to read files, which is usually determined by the file extension or type. If you don't specify the engine, xarray will try to guess it based on the file extension or type, and may fall back to a different engine if it cannot determine the correct one. + +Therefore, it's good practice to always specify the engine explicitly, to ensure that the correct backend is used and especially when working with complex data formats or non-standard file extensions. + +:py:func:`xarray.backends.list_engines` is a function in xarray that returns a dictionary of available engines and their BackendEntrypoint objects. + +You can use the `engine` argument to specify the backend when calling ``open_dataset()`` or other reading functions in xarray, as shown below: + +NetCDF +~~~~~~ +If you are reading a netCDF file with a ".nc" extension, the default engine is `netcdf4`. However if you have files with non-standard extensions or if the file format is ambiguous. Specify the engine explicitly, to ensure that the correct backend is used. + +Use :py:func:`~xarray.open_dataset` to open a NetCDF file and return an xarray Dataset object. + +.. code:: python + + import xarray as xr + + # use xarray to open the file and return an xarray.Dataset object using netcdf4 engine + + ds = xr.open_dataset("/path/to/my/file.nc", engine="netcdf4") + + # Print Dataset object + + print(ds) + + # use xarray to open the file and return an xarray.Dataset object using scipy engine + + ds = xr.open_dataset("/path/to/my/file.nc", engine="scipy") + +We recommend installing `scipy` via conda using the below given code: + +:: + + conda install scipy + +HDF5 +~~~~ +Use :py:func:`~xarray.open_dataset` to open an HDF5 file and return an xarray Dataset object. + +You should specify the `engine` keyword argument when reading HDF5 files with xarray, as there are multiple backends that can be used to read HDF5 files, and xarray may not always be able to automatically detect the correct one based on the file extension or file format. + +To read HDF5 files with xarray, you can use the :py:func:`~xarray.open_dataset` function from the `h5netcdf` backend, as follows: + +.. code:: python + + import xarray as xr + + # Open HDF5 file as an xarray Dataset + + ds = xr.open_dataset("path/to/hdf5/file.hdf5", engine="h5netcdf") + + # Print Dataset object + + print(ds) + +We recommend you to install `h5netcdf` library using the below given code: + +:: + + conda install -c conda-forge h5netcdf + +If you want to use the `netCDF4` backend to read a file with a ".h5" extension (which is typically associated with HDF5 file format), you can specify the engine argument as follows: + +.. code:: python + + ds = xr.open_dataset("path/to/file.h5", engine="netcdf4") + +GRIB +~~~~ +You should specify the `engine` keyword argument when reading GRIB files with xarray, as there are multiple backends that can be used to read GRIB files, and xarray may not always be able to automatically detect the correct one based on the file extension or file format. + +Use the :py:func:`~xarray.open_dataset` function from the `cfgrib` package to open a GRIB file as an xarray Dataset. + +.. code:: python + + import xarray as xr + + # define the path to your GRIB file and the engine you want to use to open the file + # use ``open_dataset()`` to open the file with the specified engine and return an xarray.Dataset object + + ds = xr.open_dataset("path/to/your/file.grib", engine="cfgrib") + + # Print Dataset object + + print(ds) + +We recommend installing `cfgrib` via conda using the below given code: + +:: + + conda install -c conda-forge cfgrib + +CSV +~~~ +By default, xarray uses the built-in `pandas` library to read CSV files. In general, you don't need to specify the engine keyword argument when reading CSV files with xarray, as the default `pandas` engine is usually sufficient for most use cases. If you are working with very large CSV files or if you need to perform certain types of data processing that are not supported by the default `pandas` engine, you may want to use a different backend. +In such cases, you can specify the engine argument when reading the CSV file with xarray. + +To read CSV files with xarray, use the :py:func:`~xarray.open_dataset` function and specify the path to the CSV file as follows: + +.. code:: python + + import xarray as xr + import pandas as pd + + # Load CSV file into pandas DataFrame using the "c" engine + + df = pd.read_csv("your_file.csv", engine="c") + + # Convert `:py:func:pandas` DataFrame to xarray.Dataset + + ds = xr.Dataset.from_dataframe(df) + + # Prints the resulting xarray dataset + + print(ds) + +Zarr +~~~~ +When opening a Zarr dataset with xarray, the `engine` is automatically detected based on the file extension or the type of input provided. If the dataset is stored in a directory with a ".zarr" extension, xarray will automatically use the "zarr" engine. + +To read zarr files with xarray, use the :py:func:`~xarray.open_dataset` function and specify the path to the zarr file as follows: + +.. code:: python + + import xarray as xr + + # use xarray to open the file and return an xarray.Dataset object using zarr engine + + ds = xr.open_dataset("path/to/your/file.zarr", engine="zarr") + + # Print Dataset object + + print(ds) + +We recommend installing `zarr` via conda using the below given code: + +:: + + conda install -c conda-forge zarr + +There may be situations where you need to specify the engine manually using the `engine` keyword argument. For example, if you have a Zarr dataset stored in a file with a different extension (e.g., ".npy"), you will need to specify the engine as "zarr" explicitly when opening the dataset. + +Some packages may have additional functionality beyond what is shown here. You can refer to the documentation for each package for more information. + +How does xarray handle missing values? +-------------------------------------- + +**xarray can handle missing values using ``np.NaN``** + +- ``np.NaN`` is used to represent missing values in labeled arrays and datasets. It is a commonly used standard for representing missing or undefined numerical data in scientific computing. ``np.NaN`` is a constant value in NumPy that represents "Not a Number" or missing values. + +- Most of xarray's computation methods are designed to automatically handle missing values appropriately. + + For example, when performing operations like addition or multiplication on arrays that contain missing values, xarray will automatically ignore the missing values and only perform the operation on the valid data. This makes it easy to work with data that may contain missing or undefined values without having to worry about handling them explicitly. + +- Many of xarray's `aggregation methods `_, such as ``sum()``, ``mean()``, ``min()``, ``max()``, and others, have a skipna argument that controls whether missing values (represented by NaN) should be skipped (True) or treated as NaN (False) when performing the calculation. + + By default, ``skipna`` is set to `True`, so missing values are ignored when computing the result. However, you can set ``skipna`` to `False` if you want missing values to be treated as NaN and included in the calculation. + +- On `plotting `_ an xarray dataset or array that contains missing values, xarray will simply leave the missing values as blank spaces in the plot. + +- We have a set of `methods `_ for manipulating missing and filling values. + +How should I cite xarray? +------------------------- + +If you are using xarray and would like to cite it in academic publication, we +would certainly appreciate it. We recommend two citations. + + 1. At a minimum, we recommend citing the xarray overview journal article, + published in the Journal of Open Research Software. + + - Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and + Datasets in Python. Journal of Open Research Software. 5(1), p.10. + DOI: https://doi.org/10.5334/jors.148 + + Here’s an example of a BibTeX entry:: + + @article{hoyer2017xarray, + title = {xarray: {N-D} labeled arrays and datasets in {Python}}, + author = {Hoyer, S. and J. Hamman}, + journal = {Journal of Open Research Software}, + volume = {5}, + number = {1}, + year = {2017}, + publisher = {Ubiquity Press}, + doi = {10.5334/jors.148}, + url = {https://doi.org/10.5334/jors.148} + } + + 2. You may also want to cite a specific version of the xarray package. We + provide a `Zenodo citation and DOI `_ + for this purpose: + + .. image:: https://zenodo.org/badge/doi/10.5281/zenodo.598201.svg + :target: https://doi.org/10.5281/zenodo.598201 + + An example BibTeX entry:: + + @misc{xarray_v0_8_0, + author = {Stephan Hoyer and Clark Fitzgerald and Joe Hamman and others}, + title = {xarray: v0.8.0}, + month = aug, + year = 2016, + doi = {10.5281/zenodo.59499}, + url = {https://doi.org/10.5281/zenodo.59499} + } + +.. _public api: + +What parts of xarray are considered public API? +----------------------------------------------- + +As a rule, only functions/methods documented in our :ref:`api` are considered +part of xarray's public API. Everything else (in particular, everything in +``xarray.core`` that is not also exposed in the top level ``xarray`` namespace) +is considered a private implementation detail that may change at any time. + +Objects that exist to facilitate xarray's fluent interface on ``DataArray`` and +``Dataset`` objects are a special case. For convenience, we document them in +the API docs, but only their methods and the ``DataArray``/``Dataset`` +methods/properties to construct them (e.g., ``.plot()``, ``.groupby()``, +``.str``) are considered public API. Constructors and other details of the +internal classes used to implemented them (i.e., +``xarray.plot.plotting._PlotMethods``, ``xarray.core.groupby.DataArrayGroupBy``, +``xarray.core.accessor_str.StringAccessor``) are not. diff --git a/test/fixtures/whole_applications/xarray/doc/getting-started-guide/index.rst b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/index.rst new file mode 100644 index 0000000..20fd49f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/index.rst @@ -0,0 +1,15 @@ +################ +Getting Started +################ + +The getting started guide aims to get you using xarray productively as quickly as possible. +It is designed as an entry point for new users, and it provided an introduction to xarray's main concepts. + +.. toctree:: + :maxdepth: 2 + :hidden: + + why-xarray + installing + quick-overview + faq diff --git a/test/fixtures/whole_applications/xarray/doc/getting-started-guide/installing.rst b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/installing.rst new file mode 100644 index 0000000..ca12ae6 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/installing.rst @@ -0,0 +1,174 @@ +.. _installing: + +Installation +============ + +Required dependencies +--------------------- + +- Python (3.9 or later) +- `numpy `__ (1.23 or later) +- `packaging `__ (22 or later) +- `pandas `__ (1.5 or later) + +.. _optional-dependencies: + +Optional dependencies +--------------------- + +.. note:: + + If you are using pip to install xarray, optional dependencies can be installed by + specifying *extras*. :ref:`installation-instructions` for both pip and conda + are given below. + +For netCDF and IO +~~~~~~~~~~~~~~~~~ + +- `netCDF4 `__: recommended if you + want to use xarray for reading or writing netCDF files +- `scipy `__: used as a fallback for reading/writing netCDF3 +- `pydap `__: used as a fallback for accessing OPeNDAP +- `h5netcdf `__: an alternative library for + reading and writing netCDF4 files that does not use the netCDF-C libraries +- `zarr `__: for chunked, compressed, N-dimensional arrays. +- `cftime `__: recommended if you + want to encode/decode datetimes for non-standard calendars or dates before + year 1678 or after year 2262. +- `iris `__: for conversion to and from iris' + Cube objects + +For accelerating xarray +~~~~~~~~~~~~~~~~~~~~~~~ + +- `scipy `__: necessary to enable the interpolation features for + xarray objects +- `bottleneck `__: speeds up + NaN-skipping and rolling window aggregations by a large factor +- `numbagg `_: for exponential rolling + window operations + +For parallel computing +~~~~~~~~~~~~~~~~~~~~~~ + +- `dask.array `__: required for :ref:`dask`. + +For plotting +~~~~~~~~~~~~ + +- `matplotlib `__: required for :ref:`plotting` +- `cartopy `__: recommended for :ref:`plot-maps` +- `seaborn `__: for better + color palettes +- `nc-time-axis `__: for plotting + cftime.datetime objects + +Alternative data containers +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- `sparse `_: for sparse arrays +- `pint `_: for units of measure +- Any numpy-like objects that support + `NEP-18 `_. + Note that while such libraries theoretically should work, they are untested. + Integration tests are in the process of being written for individual libraries. + + +.. _mindeps_policy: + +Minimum dependency versions +--------------------------- +Xarray adopts a rolling policy regarding the minimum supported version of its +dependencies: + +- **Python:** 30 months + (`NEP-29 `_) +- **numpy:** 18 months + (`NEP-29 `_) +- **all other libraries:** 12 months + +This means the latest minor (X.Y) version from N months prior. Patch versions (x.y.Z) +are not pinned, and only the latest available at the moment of publishing the xarray +release is guaranteed to work. + +You can see the actual minimum tested versions: + +``_ + +.. _installation-instructions: + +Instructions +------------ + +Xarray itself is a pure Python package, but its dependencies are not. The +easiest way to get everything installed is to use conda_. To install xarray +with its recommended dependencies using the conda command line tool:: + + $ conda install -c conda-forge xarray dask netCDF4 bottleneck + +.. _conda: https://docs.conda.io + +If you require other :ref:`optional-dependencies` add them to the line above. + +We recommend using the community maintained `conda-forge `__ channel, +as some of the dependencies are difficult to build. New releases may also appear in conda-forge before +being updated in the default channel. + +If you don't use conda, be sure you have the required dependencies (numpy and +pandas) installed first. Then, install xarray with pip:: + + $ python -m pip install xarray + +We also maintain other dependency sets for different subsets of functionality:: + + $ python -m pip install "xarray[io]" # Install optional dependencies for handling I/O + $ python -m pip install "xarray[accel]" # Install optional dependencies for accelerating xarray + $ python -m pip install "xarray[parallel]" # Install optional dependencies for dask arrays + $ python -m pip install "xarray[viz]" # Install optional dependencies for visualization + $ python -m pip install "xarray[complete]" # Install all the above + +The above commands should install most of the `optional dependencies`_. However, +some packages which are either not listed on PyPI or require extra +installation steps are excluded. To know which dependencies would be +installed, take a look at the ``[project.optional-dependencies]`` section in +``pyproject.toml``: + +.. literalinclude:: ../../pyproject.toml + :language: toml + :start-at: [project.optional-dependencies] + :end-before: [build-system] + +Development versions +-------------------- +To install the most recent development version, install from github:: + + $ python -m pip install git+https://github.com/pydata/xarray.git + +or from TestPyPI:: + + $ python -m pip install --index-url https://test.pypi.org/simple --extra-index-url https://pypi.org/simple --pre xarray + +Testing +------- + +To run the test suite after installing xarray, install (via pypi or conda) `py.test +`__ and run ``pytest`` in the root directory of the xarray +repository. + + +Performance Monitoring +~~~~~~~~~~~~~~~~~~~~~~ + +.. + TODO: uncomment once we have a working setup + see https://github.com/pydata/xarray/pull/5066 + + A fixed-point performance monitoring of (a part of) our code can be seen on + `this page `__. + +To run these benchmark tests in a local machine, first install + +- `airspeed-velocity `__: a tool for benchmarking + Python packages over their lifetime. + +and run +``asv run # this will install some conda environments in ./.asv/envs`` diff --git a/test/fixtures/whole_applications/xarray/doc/getting-started-guide/quick-overview.rst b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/quick-overview.rst new file mode 100644 index 0000000..ee13fea --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/quick-overview.rst @@ -0,0 +1,230 @@ +############## +Quick overview +############## + +Here are some quick examples of what you can do with :py:class:`xarray.DataArray` +objects. Everything is explained in much more detail in the rest of the +documentation. + +To begin, import numpy, pandas and xarray using their customary abbreviations: + +.. ipython:: python + + import numpy as np + import pandas as pd + import xarray as xr + +Create a DataArray +------------------ + +You can make a DataArray from scratch by supplying data in the form of a numpy +array or list, with optional *dimensions* and *coordinates*: + +.. ipython:: python + + data = xr.DataArray(np.random.randn(2, 3), dims=("x", "y"), coords={"x": [10, 20]}) + data + +In this case, we have generated a 2D array, assigned the names *x* and *y* to the two dimensions respectively and associated two *coordinate labels* '10' and '20' with the two locations along the x dimension. If you supply a pandas :py:class:`~pandas.Series` or :py:class:`~pandas.DataFrame`, metadata is copied directly: + +.. ipython:: python + + xr.DataArray(pd.Series(range(3), index=list("abc"), name="foo")) + +Here are the key properties for a ``DataArray``: + +.. ipython:: python + + # like in pandas, values is a numpy array that you can modify in-place + data.values + data.dims + data.coords + # you can use this dictionary to store arbitrary metadata + data.attrs + + +Indexing +-------- + +Xarray supports four kinds of indexing. Since we have assigned coordinate labels to the x dimension we can use label-based indexing along that dimension just like pandas. The four examples below all yield the same result (the value at `x=10`) but at varying levels of convenience and intuitiveness. + +.. ipython:: python + + # positional and by integer label, like numpy + data[0, :] + + # loc or "location": positional and coordinate label, like pandas + data.loc[10] + + # isel or "integer select": by dimension name and integer label + data.isel(x=0) + + # sel or "select": by dimension name and coordinate label + data.sel(x=10) + + +Unlike positional indexing, label-based indexing frees us from having to know how our array is organized. All we need to know are the dimension name and the label we wish to index i.e. ``data.sel(x=10)`` works regardless of whether ``x`` is the first or second dimension of the array and regardless of whether ``10`` is the first or second element of ``x``. We have already told xarray that x is the first dimension when we created ``data``: xarray keeps track of this so we don't have to. For more, see :ref:`indexing`. + + +Attributes +---------- + +While you're setting up your DataArray, it's often a good idea to set metadata attributes. A useful choice is to set ``data.attrs['long_name']`` and ``data.attrs['units']`` since xarray will use these, if present, to automatically label your plots. These special names were chosen following the `NetCDF Climate and Forecast (CF) Metadata Conventions `_. ``attrs`` is just a Python dictionary, so you can assign anything you wish. + +.. ipython:: python + + data.attrs["long_name"] = "random velocity" + data.attrs["units"] = "metres/sec" + data.attrs["description"] = "A random variable created as an example." + data.attrs["random_attribute"] = 123 + data.attrs + # you can add metadata to coordinates too + data.x.attrs["units"] = "x units" + + +Computation +----------- + +Data arrays work very similarly to numpy ndarrays: + +.. ipython:: python + + data + 10 + np.sin(data) + # transpose + data.T + data.sum() + +However, aggregation operations can use dimension names instead of axis +numbers: + +.. ipython:: python + + data.mean(dim="x") + +Arithmetic operations broadcast based on dimension name. This means you don't +need to insert dummy dimensions for alignment: + +.. ipython:: python + + a = xr.DataArray(np.random.randn(3), [data.coords["y"]]) + b = xr.DataArray(np.random.randn(4), dims="z") + + a + b + + a + b + +It also means that in most cases you do not need to worry about the order of +dimensions: + +.. ipython:: python + + data - data.T + +Operations also align based on index labels: + +.. ipython:: python + + data[:-1] - data[:1] + +For more, see :ref:`comput`. + +GroupBy +------- + +Xarray supports grouped operations using a very similar API to pandas (see :ref:`groupby`): + +.. ipython:: python + + labels = xr.DataArray(["E", "F", "E"], [data.coords["y"]], name="labels") + labels + data.groupby(labels).mean("y") + data.groupby(labels).map(lambda x: x - x.min()) + +Plotting +-------- + +Visualizing your datasets is quick and convenient: + +.. ipython:: python + + @savefig plotting_quick_overview.png + data.plot() + +Note the automatic labeling with names and units. Our effort in adding metadata attributes has paid off! Many aspects of these figures are customizable: see :ref:`plotting`. + +pandas +------ + +Xarray objects can be easily converted to and from pandas objects using the :py:meth:`~xarray.DataArray.to_series`, :py:meth:`~xarray.DataArray.to_dataframe` and :py:meth:`~pandas.DataFrame.to_xarray` methods: + +.. ipython:: python + + series = data.to_series() + series + + # convert back + series.to_xarray() + +Datasets +-------- + +:py:class:`xarray.Dataset` is a dict-like container of aligned ``DataArray`` +objects. You can think of it as a multi-dimensional generalization of the +:py:class:`pandas.DataFrame`: + +.. ipython:: python + + ds = xr.Dataset(dict(foo=data, bar=("x", [1, 2]), baz=np.pi)) + ds + + +This creates a dataset with three DataArrays named ``foo``, ``bar`` and ``baz``. Use dictionary or dot indexing to pull out ``Dataset`` variables as ``DataArray`` objects but note that assignment only works with dictionary indexing: + +.. ipython:: python + + ds["foo"] + ds.foo + + +When creating ``ds``, we specified that ``foo`` is identical to ``data`` created earlier, ``bar`` is one-dimensional with single dimension ``x`` and associated values '1' and '2', and ``baz`` is a scalar not associated with any dimension in ``ds``. Variables in datasets can have different ``dtype`` and even different dimensions, but all dimensions are assumed to refer to points in the same shared coordinate system i.e. if two variables have dimension ``x``, that dimension must be identical in both variables. + +For example, when creating ``ds`` xarray automatically *aligns* ``bar`` with ``DataArray`` ``foo``, i.e., they share the same coordinate system so that ``ds.bar['x'] == ds.foo['x'] == ds['x']``. Consequently, the following works without explicitly specifying the coordinate ``x`` when creating ``ds['bar']``: + +.. ipython:: python + + ds.bar.sel(x=10) + + + +You can do almost everything you can do with ``DataArray`` objects with +``Dataset`` objects (including indexing and arithmetic) if you prefer to work +with multiple variables at once. + +Read & write netCDF files +------------------------- + +NetCDF is the recommended file format for xarray objects. Users +from the geosciences will recognize that the :py:class:`~xarray.Dataset` data +model looks very similar to a netCDF file (which, in fact, inspired it). + +You can directly read and write xarray objects to disk using :py:meth:`~xarray.Dataset.to_netcdf`, :py:func:`~xarray.open_dataset` and +:py:func:`~xarray.open_dataarray`: + +.. ipython:: python + + ds.to_netcdf("example.nc") + reopened = xr.open_dataset("example.nc") + reopened + +.. ipython:: python + :suppress: + + import os + + reopened.close() + os.remove("example.nc") + + +It is common for datasets to be distributed across multiple files (commonly one file per timestep). Xarray supports this use-case by providing the :py:meth:`~xarray.open_mfdataset` and the :py:meth:`~xarray.save_mfdataset` methods. For more, see :ref:`io`. diff --git a/test/fixtures/whole_applications/xarray/doc/getting-started-guide/why-xarray.rst b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/why-xarray.rst new file mode 100644 index 0000000..d795681 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/getting-started-guide/why-xarray.rst @@ -0,0 +1,115 @@ +Overview: Why xarray? +===================== + +Xarray introduces labels in the form of dimensions, coordinates and attributes on top of +raw NumPy-like multidimensional arrays, which allows for a more intuitive, more concise, +and less error-prone developer experience. + +What labels enable +------------------ + +Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called +"tensors") are an essential part of computational science. +They are encountered in a wide range of fields, including physics, astronomy, +geoscience, bioinformatics, engineering, finance, and deep learning. +In Python, NumPy_ provides the fundamental data structure and API for +working with raw ND arrays. +However, real-world datasets are usually more than just raw numbers; +they have labels which encode information about how the array values map +to locations in space, time, etc. + +Xarray doesn't just keep track of labels on arrays -- it uses them to provide a +powerful and concise interface. For example: + +- Apply operations over dimensions by name: ``x.sum('time')``. +- Select values by label (or logical location) instead of integer location: + ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. +- Mathematical operations (e.g., ``x - y``) vectorize across multiple + dimensions (array broadcasting) based on dimension names, not shape. +- Easily use the `split-apply-combine `_ + paradigm with ``groupby``: + ``x.groupby('time.dayofyear').mean()``. +- Database-like alignment based on coordinate labels that smoothly + handles missing values: ``x, y = xr.align(x, y, join='outer')``. +- Keep track of arbitrary metadata in the form of a Python dictionary: + ``x.attrs``. + +The N-dimensional nature of xarray's data structures makes it suitable for dealing +with multi-dimensional scientific data, and its use of dimension names +instead of axis labels (``dim='time'`` instead of ``axis=0``) makes such +arrays much more manageable than the raw numpy ndarray: with xarray, you don't +need to keep track of the order of an array's dimensions or insert dummy dimensions of +size 1 to align arrays (e.g., using ``np.newaxis``). + +The immediate payoff of using xarray is that you'll write less code. The +long-term payoff is that you'll understand what you were thinking when you come +back to look at it weeks or months later. + +Core data structures +-------------------- + +Xarray has two core data structures, which build upon and extend the core +strengths of NumPy_ and pandas_. Both data structures are fundamentally N-dimensional: + +- :py:class:`~xarray.DataArray` is our implementation of a labeled, N-dimensional + array. It is an N-D generalization of a :py:class:`pandas.Series`. The name + ``DataArray`` itself is borrowed from Fernando Perez's datarray_ project, + which prototyped a similar data structure. +- :py:class:`~xarray.Dataset` is a multi-dimensional, in-memory array database. + It is a dict-like container of ``DataArray`` objects aligned along any number of + shared dimensions, and serves a similar purpose in xarray to the + :py:class:`pandas.DataFrame`. + +The value of attaching labels to numpy's :py:class:`numpy.ndarray` may be +fairly obvious, but the dataset may need more motivation. + +The power of the dataset over a plain dictionary is that, in addition to +pulling out arrays by name, it is possible to select or combine data along a +dimension across all arrays simultaneously. Like a +:py:class:`~pandas.DataFrame`, datasets facilitate array operations with +heterogeneous data -- the difference is that the arrays in a dataset can have +not only different data types, but also different numbers of dimensions. + +This data model is borrowed from the netCDF_ file format, which also provides +xarray with a natural and portable serialization format. NetCDF is very popular +in the geosciences, and there are existing libraries for reading and writing +netCDF in many programming languages, including Python. + +Xarray distinguishes itself from many tools for working with netCDF data +in-so-far as it provides data structures for in-memory analytics that both +utilize and preserve labels. You only need to do the tedious work of adding +metadata once, not every time you save a file. + +Goals and aspirations +--------------------- + +Xarray contributes domain-agnostic data-structures and tools for labeled +multi-dimensional arrays to Python's SciPy_ ecosystem for numerical computing. +In particular, xarray builds upon and integrates with NumPy_ and pandas_: + +- Our user-facing interfaces aim to be more explicit versions of those found in + NumPy/pandas. +- Compatibility with the broader ecosystem is a major goal: it should be easy + to get your data in and out. +- We try to keep a tight focus on functionality and interfaces related to + labeled data, and leverage other Python libraries for everything else, e.g., + NumPy/pandas for fast arrays/indexing (xarray itself contains no compiled + code), Dask_ for parallel computing, matplotlib_ for plotting, etc. + +Xarray is a collaborative and community driven project, run entirely on +volunteer effort (see :ref:`contributing`). +Our target audience is anyone who needs N-dimensional labeled arrays in Python. +Originally, development was driven by the data analysis needs of physical +scientists (especially geoscientists who already know and love +netCDF_), but it has become a much more broadly useful tool, and is still +under active development. +See our technical :ref:`roadmap` for more details, and feel free to reach out +with questions about whether xarray is the right tool for your needs. + +.. _datarray: https://github.com/fperez/datarray +.. _Dask: http://dask.org +.. _matplotlib: http://matplotlib.org +.. _netCDF: http://www.unidata.ucar.edu/software/netcdf +.. _NumPy: http://www.numpy.org +.. _pandas: http://pandas.pydata.org +.. _SciPy: http://www.scipy.org diff --git a/test/fixtures/whole_applications/xarray/doc/howdoi.rst b/test/fixtures/whole_applications/xarray/doc/howdoi.rst new file mode 100644 index 0000000..97b0872 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/howdoi.rst @@ -0,0 +1,77 @@ +.. currentmodule:: xarray + +.. _howdoi: + +How do I ... +============ + +.. list-table:: + :header-rows: 1 + :widths: 40 60 + + * - How do I... + - Solution + * - add a DataArray to my dataset as a new variable + - ``my_dataset[varname] = my_dataArray`` or :py:meth:`Dataset.assign` (see also :ref:`dictionary_like_methods`) + * - add variables from other datasets to my dataset + - :py:meth:`Dataset.merge` + * - add a new dimension and/or coordinate + - :py:meth:`DataArray.expand_dims`, :py:meth:`Dataset.expand_dims` + * - add a new coordinate variable + - :py:meth:`DataArray.assign_coords` + * - change a data variable to a coordinate variable + - :py:meth:`Dataset.set_coords` + * - change the order of dimensions + - :py:meth:`DataArray.transpose`, :py:meth:`Dataset.transpose` + * - reshape dimensions + - :py:meth:`DataArray.stack`, :py:meth:`Dataset.stack`, :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` + * - remove a variable from my object + - :py:meth:`Dataset.drop_vars`, :py:meth:`DataArray.drop_vars` + * - remove dimensions of length 1 or 0 + - :py:meth:`DataArray.squeeze`, :py:meth:`Dataset.squeeze` + * - remove all variables with a particular dimension + - :py:meth:`Dataset.drop_dims` + * - convert non-dimension coordinates to data variables or remove them + - :py:meth:`DataArray.reset_coords`, :py:meth:`Dataset.reset_coords` + * - rename a variable, dimension or coordinate + - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename`, :py:meth:`Dataset.rename_vars`, :py:meth:`Dataset.rename_dims`, + * - convert a DataArray to Dataset or vice versa + - :py:meth:`DataArray.to_dataset`, :py:meth:`Dataset.to_dataarray`, :py:meth:`Dataset.to_stacked_array`, :py:meth:`DataArray.to_unstacked_dataset` + * - extract variables that have certain attributes + - :py:meth:`Dataset.filter_by_attrs` + * - extract the underlying array (e.g. NumPy or Dask arrays) + - :py:attr:`DataArray.data` + * - convert to and extract the underlying NumPy array + - :py:attr:`DataArray.to_numpy` + * - convert to a pandas DataFrame + - :py:attr:`Dataset.to_dataframe` + * - sort values + - :py:attr:`Dataset.sortby` + * - find out if my xarray object is wrapping a Dask Array + - :py:func:`dask.is_dask_collection` + * - know how much memory my object requires + - :py:attr:`DataArray.nbytes`, :py:attr:`Dataset.nbytes` + * - Get axis number for a dimension + - :py:meth:`DataArray.get_axis_num` + * - convert a possibly irregularly sampled timeseries to a regularly sampled timeseries + - :py:meth:`DataArray.resample`, :py:meth:`Dataset.resample` (see :ref:`resampling` for more) + * - apply a function on all data variables in a Dataset + - :py:meth:`Dataset.map` + * - write xarray objects with complex values to a netCDF file + - :py:func:`Dataset.to_netcdf`, :py:func:`DataArray.to_netcdf` specifying ``engine="h5netcdf", invalid_netcdf=True`` + * - make xarray objects look like other xarray objects + - :py:func:`~xarray.ones_like`, :py:func:`~xarray.zeros_like`, :py:func:`~xarray.full_like`, :py:meth:`Dataset.reindex_like`, :py:meth:`Dataset.interp_like`, :py:meth:`Dataset.broadcast_like`, :py:meth:`DataArray.reindex_like`, :py:meth:`DataArray.interp_like`, :py:meth:`DataArray.broadcast_like` + * - Make sure my datasets have values at the same coordinate locations + - ``xr.align(dataset_1, dataset_2, join="exact")`` + * - replace NaNs with other values + - :py:meth:`Dataset.fillna`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill`, :py:meth:`Dataset.interpolate_na`, :py:meth:`DataArray.fillna`, :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`DataArray.interpolate_na` + * - extract the year, month, day or similar from a DataArray of time values + - ``obj.dt.month`` for example where ``obj`` is a :py:class:`~xarray.DataArray` containing ``datetime64`` or ``cftime`` values. See :ref:`dt_accessor` for more. + * - round off time values to a specified frequency + - ``obj.dt.ceil``, ``obj.dt.floor``, ``obj.dt.round``. See :ref:`dt_accessor` for more. + * - make a mask that is ``True`` where an object contains any of the values in a array + - :py:meth:`Dataset.isin`, :py:meth:`DataArray.isin` + * - Index using a boolean mask + - :py:meth:`Dataset.query`, :py:meth:`DataArray.query`, :py:meth:`Dataset.where`, :py:meth:`DataArray.where` + * - preserve ``attrs`` during (most) xarray operations + - ``xr.set_options(keep_attrs=True)`` diff --git a/test/fixtures/whole_applications/xarray/doc/index.rst b/test/fixtures/whole_applications/xarray/doc/index.rst new file mode 100644 index 0000000..138e9d9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/index.rst @@ -0,0 +1,89 @@ +.. module:: xarray + +Xarray documentation +==================== + +Xarray makes working with labelled multi-dimensional arrays in Python simple, +efficient, and fun! + +**Useful links**: +`Home `__ | +`Code Repository `__ | +`Issues `__ | +`Discussions `__ | +`Releases `__ | +`Stack Overflow `__ | +`Mailing List `__ | +`Blog `__ + + +.. grid:: 1 1 2 2 + :gutter: 2 + + .. grid-item-card:: Getting started + :img-top: _static/index_getting_started.svg + :link: getting-started-guide/index + :link-type: doc + + New to *xarray*? Check out the getting started guides. They contain an + introduction to *Xarray's* main concepts and links to additional tutorials. + + .. grid-item-card:: User guide + :img-top: _static/index_user_guide.svg + :link: user-guide/index + :link-type: doc + + The user guide provides in-depth information on the + key concepts of Xarray with useful background information and explanation. + + .. grid-item-card:: API reference + :img-top: _static/index_api.svg + :link: api + :link-type: doc + + The reference guide contains a detailed description of the Xarray API. + The reference describes how the methods work and which parameters can + be used. It assumes that you have an understanding of the key concepts. + + .. grid-item-card:: Developer guide + :img-top: _static/index_contribute.svg + :link: contributing + :link-type: doc + + Saw a typo in the documentation? Want to improve existing functionalities? + The contributing guidelines will guide you through the process of improving + Xarray. + +.. toctree:: + :maxdepth: 2 + :hidden: + :caption: For users + + Getting Started + User Guide + Gallery + Tutorials & Videos + API Reference + How do I ... + Ecosystem + +.. toctree:: + :maxdepth: 2 + :hidden: + :caption: For developers/contributors + + Contributing Guide + Xarray Internals + Development Roadmap + Team + Developers Meeting + What’s New + GitHub repository + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Community + + GitHub discussions + StackOverflow diff --git a/test/fixtures/whole_applications/xarray/doc/internals/chunked-arrays.rst b/test/fixtures/whole_applications/xarray/doc/internals/chunked-arrays.rst new file mode 100644 index 0000000..ba7ce72 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/chunked-arrays.rst @@ -0,0 +1,102 @@ +.. currentmodule:: xarray + +.. _internals.chunkedarrays: + +Alternative chunked array types +=============================== + +.. warning:: + + This is a *highly* experimental feature. Please report any bugs or other difficulties on `xarray's issue tracker `_. + In particular see discussion on `xarray issue #6807 `_ + +Xarray can wrap chunked dask arrays (see :ref:`dask`), but can also wrap any other chunked array type that exposes the correct interface. +This allows us to support using other frameworks for distributed and out-of-core processing, with user code still written as xarray commands. +In particular xarray also supports wrapping :py:class:`cubed.Array` objects +(see `Cubed's documentation `_ and the `cubed-xarray package `_). + +The basic idea is that by wrapping an array that has an explicit notion of ``.chunks``, xarray can expose control over +the choice of chunking scheme to users via methods like :py:meth:`DataArray.chunk` whilst the wrapped array actually +implements the handling of processing all of the chunks. + +Chunked array methods and "core operations" +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A chunked array needs to meet all the :ref:`requirements for normal duck arrays `, but must also +implement additional features. + +Chunked arrays have additional attributes and methods, such as ``.chunks`` and ``.rechunk``. +Furthermore, Xarray dispatches chunk-aware computations across one or more chunked arrays using special functions known +as "core operations". Examples include ``map_blocks``, ``blockwise``, and ``apply_gufunc``. + +The core operations are generalizations of functions first implemented in :py:mod:`dask.array`. +The implementation of these functions is specific to the type of arrays passed to them. For example, when applying the +``map_blocks`` core operation, :py:class:`dask.array.Array` objects must be processed by :py:func:`dask.array.map_blocks`, +whereas :py:class:`cubed.Array` objects must be processed by :py:func:`cubed.map_blocks`. + +In order to use the correct implementation of a core operation for the array type encountered, xarray dispatches to the +corresponding subclass of :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint`, +also known as a "Chunk Manager". Therefore **a full list of the operations that need to be defined is set by the +API of the** :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint` **abstract base class**. Note that chunked array +methods are also currently dispatched using this class. + +Chunked array creation is also handled by this class. As chunked array objects have a one-to-one correspondence with +in-memory numpy arrays, it should be possible to create a chunked array from a numpy array by passing the desired +chunking pattern to an implementation of :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint.from_array``. + +.. note:: + + The :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint` abstract base class is mostly just acting as a + namespace for containing the chunked-aware function primitives. Ideally in the future we would have an API standard + for chunked array types which codified this structure, making the entrypoint system unnecessary. + +.. currentmodule:: xarray.namedarray.parallelcompat + +.. autoclass:: xarray.namedarray.parallelcompat.ChunkManagerEntrypoint + :members: + +Registering a new ChunkManagerEntrypoint subclass +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Rather than hard-coding various chunk managers to deal with specific chunked array implementations, xarray uses an +entrypoint system to allow developers of new chunked array implementations to register their corresponding subclass of +:py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint`. + + +To register a new entrypoint you need to add an entry to the ``setup.cfg`` like this:: + + [options.entry_points] + xarray.chunkmanagers = + dask = xarray.namedarray.daskmanager:DaskManager + +See also `cubed-xarray `_ for another example. + +To check that the entrypoint has worked correctly, you may find it useful to display the available chunkmanagers using +the internal function :py:func:`~xarray.namedarray.parallelcompat.list_chunkmanagers`. + +.. autofunction:: list_chunkmanagers + + +User interface +~~~~~~~~~~~~~~ + +Once the chunkmanager subclass has been registered, xarray objects wrapping the desired array type can be created in 3 ways: + +#. By manually passing the array type to the :py:class:`~xarray.DataArray` constructor, see the examples for :ref:`numpy-like arrays `, + +#. Calling :py:meth:`~xarray.DataArray.chunk`, passing the keyword arguments ``chunked_array_type`` and ``from_array_kwargs``, + +#. Calling :py:func:`~xarray.open_dataset`, passing the keyword arguments ``chunked_array_type`` and ``from_array_kwargs``. + +The latter two methods ultimately call the chunkmanager's implementation of ``.from_array``, to which they pass the ``from_array_kwargs`` dict. +The ``chunked_array_type`` kwarg selects which registered chunkmanager subclass to dispatch to. It defaults to ``'dask'`` +if Dask is installed, otherwise it defaults to whichever chunkmanager is registered if only one is registered. +If multiple chunkmanagers are registered it will raise an error by default. + +Parallel processing without chunks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To use a parallel array type that does not expose a concept of chunks explicitly, none of the information on this page +is theoretically required. Such an array type (e.g. `Ramba `_ or +`Arkouda `_) could be wrapped using xarray's existing support for +:ref:`numpy-like "duck" arrays `. diff --git a/test/fixtures/whole_applications/xarray/doc/internals/duck-arrays-integration.rst b/test/fixtures/whole_applications/xarray/doc/internals/duck-arrays-integration.rst new file mode 100644 index 0000000..43b17be --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/duck-arrays-integration.rst @@ -0,0 +1,87 @@ + +.. _internals.duckarrays: + +Integrating with duck arrays +============================= + +.. warning:: + + This is an experimental feature. Please report any bugs or other difficulties on `xarray's issue tracker `_. + +Xarray can wrap custom numpy-like arrays (":term:`duck array`\s") - see the :ref:`user guide documentation `. +This page is intended for developers who are interested in wrapping a new custom array type with xarray. + +.. _internals.duckarrays.requirements: + +Duck array requirements +~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray does not explicitly check that required methods are defined by the underlying duck array object before +attempting to wrap the given array. However, a wrapped array type should at a minimum define these attributes: + +* ``shape`` property, +* ``dtype`` property, +* ``ndim`` property, +* ``__array__`` method, +* ``__array_ufunc__`` method, +* ``__array_function__`` method. + +These need to be defined consistently with :py:class:`numpy.ndarray`, for example the array ``shape`` +property needs to obey `numpy's broadcasting rules `_ +(see also the `Python Array API standard's explanation `_ +of these same rules). + +.. _internals.duckarrays.array_api_standard: + +Python Array API standard support +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As an integration library xarray benefits greatly from the standardization of duck-array libraries' APIs, and so is a +big supporter of the `Python Array API Standard `_. + +We aim to support any array libraries that follow the Array API standard out-of-the-box. However, xarray does occasionally +call some numpy functions which are not (yet) part of the standard (e.g. :py:meth:`xarray.DataArray.pad` calls :py:func:`numpy.pad`). +See `xarray issue #7848 `_ for a list of such functions. We can still support dispatching on these functions through +the array protocols above, it just means that if you exclusively implement the methods in the Python Array API standard +then some features in xarray will not work. + +Custom inline reprs +~~~~~~~~~~~~~~~~~~~ + +In certain situations (e.g. when printing the collapsed preview of +variables of a ``Dataset``), xarray will display the repr of a :term:`duck array` +in a single line, truncating it to a certain number of characters. If that +would drop too much information, the :term:`duck array` may define a +``_repr_inline_`` method that takes ``max_width`` (number of characters) as an +argument + +.. code:: python + + class MyDuckArray: + ... + + def _repr_inline_(self, max_width): + """format to a single line with at most max_width characters""" + ... + + ... + +To avoid duplicated information, this method must omit information about the shape and +:term:`dtype`. For example, the string representation of a ``dask`` array or a +``sparse`` matrix would be: + +.. ipython:: python + + import dask.array as da + import xarray as xr + import sparse + + a = da.linspace(0, 1, 20, chunks=2) + a + + b = np.eye(10) + b[[5, 7, 3, 0], [6, 8, 2, 9]] = 2 + b = sparse.COO.from_numpy(b) + b + + xr.Dataset(dict(a=("x", a), b=(("y", "z"), b))) diff --git a/test/fixtures/whole_applications/xarray/doc/internals/extending-xarray.rst b/test/fixtures/whole_applications/xarray/doc/internals/extending-xarray.rst new file mode 100644 index 0000000..0537ae8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/extending-xarray.rst @@ -0,0 +1,116 @@ + +.. _internals.accessors: + +Extending xarray using accessors +================================ + +.. ipython:: python + :suppress: + + import xarray as xr + + +Xarray is designed as a general purpose library and hence tries to avoid +including overly domain specific functionality. But inevitably, the need for more +domain specific logic arises. + +.. _internals.accessors.composition: + +Composition over Inheritance +---------------------------- + +One potential solution to this problem is to subclass Dataset and/or DataArray to +add domain specific functionality. However, inheritance is not very robust. It's +easy to inadvertently use internal APIs when subclassing, which means that your +code may break when xarray upgrades. Furthermore, many builtin methods will +only return native xarray objects. + +The standard advice is to use :issue:`composition over inheritance <706>`, but +reimplementing an API as large as xarray's on your own objects can be an onerous +task, even if most methods are only forwarding to xarray implementations. +(For an example of a project which took this approach of subclassing see `UXarray `_). + +If you simply want the ability to call a function with the syntax of a +method call, then the builtin :py:meth:`~xarray.DataArray.pipe` method (copied +from pandas) may suffice. + +.. _internals.accessors.writing accessors: + +Writing Custom Accessors +------------------------ + +To resolve this issue for more complex cases, xarray has the +:py:func:`~xarray.register_dataset_accessor` and +:py:func:`~xarray.register_dataarray_accessor` decorators for adding custom +"accessors" on xarray objects, thereby "extending" the functionality of your xarray object. + +Here's how you might use these decorators to +write a custom "geo" accessor implementing a geography specific extension to +xarray: + +.. literalinclude:: ../examples/_code/accessor_example.py + +In general, the only restriction on the accessor class is that the ``__init__`` method +must have a single parameter: the ``Dataset`` or ``DataArray`` object it is supposed +to work on. + +This achieves the same result as if the ``Dataset`` class had a cached property +defined that returns an instance of your class: + +.. code-block:: python + + class Dataset: + ... + + @property + def geo(self): + return GeoAccessor(self) + +However, using the register accessor decorators is preferable to simply adding +your own ad-hoc property (i.e., ``Dataset.geo = property(...)``), for several +reasons: + +1. It ensures that the name of your property does not accidentally conflict with + any other attributes or methods (including other accessors). +2. Instances of accessor object will be cached on the xarray object that creates + them. This means you can save state on them (e.g., to cache computed + properties). +3. Using an accessor provides an implicit namespace for your custom + functionality that clearly identifies it as separate from built-in xarray + methods. + +.. note:: + + Accessors are created once per DataArray and Dataset instance. New + instances, like those created from arithmetic operations or when accessing + a DataArray from a Dataset (ex. ``ds[var_name]``), will have new + accessors created. + +Back in an interactive IPython session, we can use these properties: + +.. ipython:: python + :suppress: + + exec(open("examples/_code/accessor_example.py").read()) + +.. ipython:: python + + ds = xr.Dataset({"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)}) + ds.geo.center + ds.geo.plot() + +The intent here is that libraries that extend xarray could add such an accessor +to implement subclass specific functionality rather than using actual subclasses +or patching in a large number of domain specific methods. For further reading +on ways to write new accessors and the philosophy behind the approach, see +https://github.com/pydata/xarray/issues/1080. + +To help users keep things straight, please `let us know +`_ if you plan to write a new accessor +for an open source library. Existing open source accessors and the libraries +that implement them are available in the list on the :ref:`ecosystem` page. + +To make documenting accessors with ``sphinx`` and ``sphinx.ext.autosummary`` +easier, you can use `sphinx-autosummary-accessors`_. + +.. _sphinx-autosummary-accessors: https://sphinx-autosummary-accessors.readthedocs.io/ diff --git a/test/fixtures/whole_applications/xarray/doc/internals/how-to-add-new-backend.rst b/test/fixtures/whole_applications/xarray/doc/internals/how-to-add-new-backend.rst new file mode 100644 index 0000000..4352dd3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/how-to-add-new-backend.rst @@ -0,0 +1,499 @@ +.. _add_a_backend: + +How to add a new backend +------------------------ + +Adding a new backend for read support to Xarray does not require +to integrate any code in Xarray; all you need to do is: + +- Create a class that inherits from Xarray :py:class:`~xarray.backends.BackendEntrypoint` + and implements the method ``open_dataset`` see :ref:`RST backend_entrypoint` + +- Declare this class as an external plugin in your project configuration, see :ref:`RST + backend_registration` + +If you also want to support lazy loading and dask see :ref:`RST lazy_loading`. + +Note that the new interface for backends is available from Xarray +version >= 0.18 onwards. + +You can see what backends are currently available in your working environment +with :py:class:`~xarray.backends.list_engines()`. + +.. _RST backend_entrypoint: + +BackendEntrypoint subclassing ++++++++++++++++++++++++++++++ + +Your ``BackendEntrypoint`` sub-class is the primary interface with Xarray, and +it should implement the following attributes and methods: + +- the ``open_dataset`` method (mandatory) +- the ``open_dataset_parameters`` attribute (optional) +- the ``guess_can_open`` method (optional) +- the ``description`` attribute (optional) +- the ``url`` attribute (optional). + +This is what a ``BackendEntrypoint`` subclass should look like: + +.. code-block:: python + + from xarray.backends import BackendEntrypoint + + + class MyBackendEntrypoint(BackendEntrypoint): + def open_dataset( + self, + filename_or_obj, + *, + drop_variables=None, + # other backend specific keyword arguments + # `chunks` and `cache` DO NOT go here, they are handled by xarray + ): + return my_open_dataset(filename_or_obj, drop_variables=drop_variables) + + open_dataset_parameters = ["filename_or_obj", "drop_variables"] + + def guess_can_open(self, filename_or_obj): + try: + _, ext = os.path.splitext(filename_or_obj) + except TypeError: + return False + return ext in {".my_format", ".my_fmt"} + + description = "Use .my_format files in Xarray" + + url = "https://link_to/your_backend/documentation" + +``BackendEntrypoint`` subclass methods and attributes are detailed in the following. + +.. _RST open_dataset: + +open_dataset +^^^^^^^^^^^^ + +The backend ``open_dataset`` shall implement reading from file, the variables +decoding and it shall instantiate the output Xarray class :py:class:`~xarray.Dataset`. + +The following is an example of the high level processing steps: + +.. code-block:: python + + def open_dataset( + self, + filename_or_obj, + *, + drop_variables=None, + decode_times=True, + decode_timedelta=True, + decode_coords=True, + my_backend_option=None, + ): + vars, attrs, coords = my_reader( + filename_or_obj, + drop_variables=drop_variables, + my_backend_option=my_backend_option, + ) + vars, attrs, coords = my_decode_variables( + vars, attrs, decode_times, decode_timedelta, decode_coords + ) # see also conventions.decode_cf_variables + + ds = xr.Dataset(vars, attrs=attrs, coords=coords) + ds.set_close(my_close_method) + + return ds + + +The output :py:class:`~xarray.Dataset` shall implement the additional custom method +``close``, used by Xarray to ensure the related files are eventually closed. This +method shall be set by using :py:meth:`~xarray.Dataset.set_close`. + + +The input of ``open_dataset`` method are one argument +(``filename_or_obj``) and one keyword argument (``drop_variables``): + +- ``filename_or_obj``: can be any object but usually it is a string containing a path or an instance of + :py:class:`pathlib.Path`. +- ``drop_variables``: can be `None` or an iterable containing the variable + names to be dropped when reading the data. + +If it makes sense for your backend, your ``open_dataset`` method +should implement in its interface the following boolean keyword arguments, called +**decoders**, which default to ``None``: + +- ``mask_and_scale`` +- ``decode_times`` +- ``decode_timedelta`` +- ``use_cftime`` +- ``concat_characters`` +- ``decode_coords`` + +Note: all the supported decoders shall be declared explicitly +in backend ``open_dataset`` signature and adding a ``**kwargs`` is not allowed. + +These keyword arguments are explicitly defined in Xarray +:py:func:`~xarray.open_dataset` signature. Xarray will pass them to the +backend only if the User explicitly sets a value different from ``None``. +For more details on decoders see :ref:`RST decoders`. + +Your backend can also take as input a set of backend-specific keyword +arguments. All these keyword arguments can be passed to +:py:func:`~xarray.open_dataset` grouped either via the ``backend_kwargs`` +parameter or explicitly using the syntax ``**kwargs``. + + +If you don't want to support the lazy loading, then the +:py:class:`~xarray.Dataset` shall contain values as a :py:class:`numpy.ndarray` +and your work is almost done. + +.. _RST open_dataset_parameters: + +open_dataset_parameters +^^^^^^^^^^^^^^^^^^^^^^^ + +``open_dataset_parameters`` is the list of backend ``open_dataset`` parameters. +It is not a mandatory parameter, and if the backend does not provide it +explicitly, Xarray creates a list of them automatically by inspecting the +backend signature. + +If ``open_dataset_parameters`` is not defined, but ``**kwargs`` and ``*args`` +are in the backend ``open_dataset`` signature, Xarray raises an error. +On the other hand, if the backend provides the ``open_dataset_parameters``, +then ``**kwargs`` and ``*args`` can be used in the signature. +However, this practice is discouraged unless there is a good reasons for using +``**kwargs`` or ``*args``. + +.. _RST guess_can_open: + +guess_can_open +^^^^^^^^^^^^^^ + +``guess_can_open`` is used to identify the proper engine to open your data +file automatically in case the engine is not specified explicitly. If you are +not interested in supporting this feature, you can skip this step since +:py:class:`~xarray.backends.BackendEntrypoint` already provides a +default :py:meth:`~xarray.backends.BackendEntrypoint.guess_can_open` +that always returns ``False``. + +Backend ``guess_can_open`` takes as input the ``filename_or_obj`` parameter of +Xarray :py:meth:`~xarray.open_dataset`, and returns a boolean. + +.. _RST properties: + +description and url +^^^^^^^^^^^^^^^^^^^^ + +``description`` is used to provide a short text description of the backend. +``url`` is used to include a link to the backend's documentation or code. + +These attributes are surfaced when a user prints :py:class:`~xarray.backends.BackendEntrypoint`. +If ``description`` or ``url`` are not defined, an empty string is returned. + +.. _RST decoders: + +Decoders +^^^^^^^^ + +The decoders implement specific operations to transform data from on-disk +representation to Xarray representation. + +A classic example is the “time” variable decoding operation. In NetCDF, the +elements of the “time” variable are stored as integers, and the unit contains +an origin (for example: "seconds since 1970-1-1"). In this case, Xarray +transforms the pair integer-unit in a :py:class:`numpy.datetime64`. + +The standard coders implemented in Xarray are: + +- :py:class:`xarray.coding.strings.CharacterArrayCoder()` +- :py:class:`xarray.coding.strings.EncodedStringCoder()` +- :py:class:`xarray.coding.variables.UnsignedIntegerCoder()` +- :py:class:`xarray.coding.variables.CFMaskCoder()` +- :py:class:`xarray.coding.variables.CFScaleOffsetCoder()` +- :py:class:`xarray.coding.times.CFTimedeltaCoder()` +- :py:class:`xarray.coding.times.CFDatetimeCoder()` + +Xarray coders all have the same interface. They have two methods: ``decode`` +and ``encode``. The method ``decode`` takes a ``Variable`` in on-disk +format and returns a ``Variable`` in Xarray format. Variable +attributes no more applicable after the decoding, are dropped and stored in the +``Variable.encoding`` to make them available to the ``encode`` method, which +performs the inverse transformation. + +In the following an example on how to use the coders ``decode`` method: + +.. ipython:: python + :suppress: + + import xarray as xr + +.. ipython:: python + + var = xr.Variable( + dims=("x",), data=np.arange(10.0), attrs={"scale_factor": 10, "add_offset": 2} + ) + var + + coder = xr.coding.variables.CFScaleOffsetCoder() + decoded_var = coder.decode(var) + decoded_var + decoded_var.encoding + +Some of the transformations can be common to more backends, so before +implementing a new decoder, be sure Xarray does not already implement that one. + +The backends can reuse Xarray’s decoders, either instantiating the coders +and using the method ``decode`` directly or using the higher-level function +:py:func:`~xarray.conventions.decode_cf_variables` that groups Xarray decoders. + +In some cases, the transformation to apply strongly depends on the on-disk +data format. Therefore, you may need to implement your own decoder. + +An example of such a case is when you have to deal with the time format of a +grib file. grib format is very different from the NetCDF one: in grib, the +time is stored in two attributes dataDate and dataTime as strings. Therefore, +it is not possible to reuse the Xarray time decoder, and implementing a new +one is mandatory. + +Decoders can be activated or deactivated using the boolean keywords of +Xarray :py:meth:`~xarray.open_dataset` signature: ``mask_and_scale``, +``decode_times``, ``decode_timedelta``, ``use_cftime``, +``concat_characters``, ``decode_coords``. +Such keywords are passed to the backend only if the User sets a value +different from ``None``. Note that the backend does not necessarily have to +implement all the decoders, but it shall declare in its ``open_dataset`` +interface only the boolean keywords related to the supported decoders. + +.. _RST backend_registration: + +How to register a backend ++++++++++++++++++++++++++ + +Define a new entrypoint in your ``pyproject.toml`` (or ``setup.cfg/setup.py`` for older +configurations), with: + +- group: ``xarray.backends`` +- name: the name to be passed to :py:meth:`~xarray.open_dataset` as ``engine`` +- object reference: the reference of the class that you have implemented. + +You can declare the entrypoint in your project configuration like so: + +.. tab:: pyproject.toml + + .. code:: toml + + [project.entry-points."xarray.backends"] + my_engine = "my_package.my_module:MyBackendEntrypoint" + +.. tab:: pyproject.toml [Poetry] + + .. code-block:: toml + + [tool.poetry.plugins."xarray.backends"] + my_engine = "my_package.my_module:MyBackendEntrypoint" + +.. tab:: setup.cfg + + .. code-block:: cfg + + [options.entry_points] + xarray.backends = + my_engine = my_package.my_module:MyBackendEntrypoint + +.. tab:: setup.py + + .. code-block:: + + setuptools.setup( + entry_points={ + "xarray.backends": [ + "my_engine=my_package.my_module:MyBackendEntrypoint" + ], + }, + ) + + +See the `Python Packaging User Guide +`_ for more +information on entrypoints and details of the syntax. + +If you're using Poetry, note that table name in ``pyproject.toml`` is slightly different. +See `the Poetry docs `_ for more +information on plugins. + +.. _RST lazy_loading: + +How to support lazy loading ++++++++++++++++++++++++++++ + +If you want to make your backend effective with big datasets, then you should +support lazy loading. +Basically, you shall replace the :py:class:`numpy.ndarray` inside the +variables with a custom class that supports lazy loading indexing. +See the example below: + +.. code-block:: python + + backend_array = MyBackendArray() + data = indexing.LazilyIndexedArray(backend_array) + var = xr.Variable(dims, data, attrs=attrs, encoding=encoding) + +Where: + +- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class + provided by Xarray that manages the lazy loading. +- ``MyBackendArray`` shall be implemented by the backend and shall inherit + from :py:class:`~xarray.backends.BackendArray`. + +BackendArray subclassing +^^^^^^^^^^^^^^^^^^^^^^^^ + +The BackendArray subclass shall implement the following method and attributes: + +- the ``__getitem__`` method that takes in input an index and returns a + `NumPy `__ array +- the ``shape`` attribute +- the ``dtype`` attribute. + +Xarray supports different type of :doc:`/user-guide/indexing`, that can be +grouped in three types of indexes +:py:class:`~xarray.core.indexing.BasicIndexer`, +:py:class:`~xarray.core.indexing.OuterIndexer` and +:py:class:`~xarray.core.indexing.VectorizedIndexer`. +This implies that the implementation of the method ``__getitem__`` can be tricky. +In order to simplify this task, Xarray provides a helper function, +:py:func:`~xarray.core.indexing.explicit_indexing_adapter`, that transforms +all the input ``indexer`` types (`basic`, `outer`, `vectorized`) in a tuple +which is interpreted correctly by your backend. + +This is an example ``BackendArray`` subclass implementation: + +.. code-block:: python + + from xarray.backends import BackendArray + + + class MyBackendArray(BackendArray): + def __init__( + self, + shape, + dtype, + lock, + # other backend specific keyword arguments + ): + self.shape = shape + self.dtype = dtype + self.lock = lock + + def __getitem__( + self, key: xarray.core.indexing.ExplicitIndexer + ) -> np.typing.ArrayLike: + return indexing.explicit_indexing_adapter( + key, + self.shape, + indexing.IndexingSupport.BASIC, + self._raw_indexing_method, + ) + + def _raw_indexing_method(self, key: tuple) -> np.typing.ArrayLike: + # thread safe method that access to data on disk + with self.lock: + ... + return item + +Note that ``BackendArray.__getitem__`` must be thread safe to support +multi-thread processing. + +The :py:func:`~xarray.core.indexing.explicit_indexing_adapter` method takes in +input the ``key``, the array ``shape`` and the following parameters: + +- ``indexing_support``: the type of index supported by ``raw_indexing_method`` +- ``raw_indexing_method``: a method that shall take in input a key in the form + of a tuple and return an indexed :py:class:`numpy.ndarray`. + +For more details see +:py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`. + +In order to support `Dask Distributed `__ and +:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable +either with :ref:`io.pickle` or +`cloudpickle `__. +That implies that all the reference to open files should be dropped. For +opening files, we therefore suggest to use the helper class provided by Xarray +:py:class:`~xarray.backends.CachingFileManager`. + +.. _RST indexing: + +Indexing examples +^^^^^^^^^^^^^^^^^ + +**BASIC** + +In the ``BASIC`` indexing support, numbers and slices are supported. + +Example: + +.. ipython:: + :verbatim: + + In [1]: # () shall return the full array + ...: backend_array._raw_indexing_method(()) + Out[1]: array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + + In [2]: # shall support integers + ...: backend_array._raw_indexing_method(1, 1) + Out[2]: 5 + + In [3]: # shall support slices + ...: backend_array._raw_indexing_method(slice(0, 3), slice(2, 4)) + Out[3]: array([[2, 3], [6, 7], [10, 11]]) + +**OUTER** + +The ``OUTER`` indexing shall support number, slices and in addition it shall +support also lists of integers. The the outer indexing is equivalent to +combining multiple input list with ``itertools.product()``: + +.. ipython:: + :verbatim: + + In [1]: backend_array._raw_indexing_method([0, 1], [0, 1, 2]) + Out[1]: array([[0, 1, 2], [4, 5, 6]]) + + # shall support integers + In [2]: backend_array._raw_indexing_method(1, 1) + Out[2]: 5 + + +**OUTER_1VECTOR** + +The ``OUTER_1VECTOR`` indexing shall supports number, slices and at most one +list. The behaviour with the list shall be the same of ``OUTER`` indexing. + +If you support more complex indexing as `explicit indexing` or +`numpy indexing`, you can have a look to the implementation of Zarr backend and Scipy backend, +currently available in :py:mod:`~xarray.backends` module. + +.. _RST preferred_chunks: + +Preferred chunk sizes +^^^^^^^^^^^^^^^^^^^^^ + +To potentially improve performance with lazy loading, the backend may define for each +variable the chunk sizes that it prefers---that is, sizes that align with how the +variable is stored. (Note that the backend is not directly involved in `Dask +`__ chunking, because Xarray internally manages chunking.) To define +the preferred chunk sizes, store a mapping within the variable's encoding under the key +``"preferred_chunks"`` (that is, ``var.encoding["preferred_chunks"]``). The mapping's +keys shall be the names of dimensions with preferred chunk sizes, and each value shall +be the corresponding dimension's preferred chunk sizes expressed as either an integer +(such as ``{"dim1": 1000, "dim2": 2000}``) or a tuple of integers (such as ``{"dim1": +(1000, 100), "dim2": (2000, 2000, 2000)}``). + +Xarray uses the preferred chunk sizes in some special cases of the ``chunks`` argument +of the :py:func:`~xarray.open_dataset` and :py:func:`~xarray.open_mfdataset` functions. +If ``chunks`` is a ``dict``, then for any dimensions missing from the keys or whose +value is ``None``, Xarray sets the chunk sizes to the preferred sizes. If ``chunks`` +equals ``"auto"``, then Xarray seeks ideal chunk sizes informed by the preferred chunk +sizes. Specifically, it determines the chunk sizes using +:py:func:`dask.array.core.normalize_chunks` with the ``previous_chunks`` argument set +according to the preferred chunk sizes. diff --git a/test/fixtures/whole_applications/xarray/doc/internals/how-to-create-custom-index.rst b/test/fixtures/whole_applications/xarray/doc/internals/how-to-create-custom-index.rst new file mode 100644 index 0000000..90b3412 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/how-to-create-custom-index.rst @@ -0,0 +1,235 @@ +.. currentmodule:: xarray + +.. _internals.custom indexes: + +How to create a custom index +============================ + +.. warning:: + + This feature is highly experimental. Support for custom indexes has been + introduced in v2022.06.0 and is still incomplete. API is subject to change + without deprecation notice. However we encourage you to experiment and report issues that arise. + +Xarray's built-in support for label-based indexing (e.g. `ds.sel(latitude=40, method="nearest")`) and alignment operations +relies on :py:class:`pandas.Index` objects. Pandas Indexes are powerful and suitable for many +applications but also have some limitations: + +- it only works with 1-dimensional coordinates where explicit labels + are fully loaded in memory +- it is hard to reuse it with irregular data for which there exist more + efficient, tree-based structures to perform data selection +- it doesn't support extra metadata that may be required for indexing and + alignment (e.g., a coordinate reference system) + +Fortunately, Xarray now allows extending this functionality with custom indexes, +which can be implemented in 3rd-party libraries. + +The Index base class +-------------------- + +Every Xarray index must inherit from the :py:class:`Index` base class. It is for +example the case of Xarray built-in ``PandasIndex`` and ``PandasMultiIndex`` +subclasses, which wrap :py:class:`pandas.Index` and +:py:class:`pandas.MultiIndex` respectively. + +The ``Index`` API closely follows the :py:class:`Dataset` and +:py:class:`DataArray` API, e.g., for an index to support :py:meth:`DataArray.sel` it needs to +implement :py:meth:`Index.sel`, to support :py:meth:`DataArray.stack` and :py:meth:`DataArray.unstack` it +needs to implement :py:meth:`Index.stack` and :py:meth:`Index.unstack`, etc. + +Some guidelines and examples are given below. More details can be found in the +documented :py:class:`Index` API. + +Minimal requirements +-------------------- + +Every index must at least implement the :py:meth:`Index.from_variables` class +method, which is used by Xarray to build a new index instance from one or more +existing coordinates in a Dataset or DataArray. + +Since any collection of coordinates can be passed to that method (i.e., the +number, order and dimensions of the coordinates are all arbitrary), it is the +responsibility of the index to check the consistency and validity of those input +coordinates. + +For example, :py:class:`~xarray.core.indexes.PandasIndex` accepts only one coordinate and +:py:class:`~xarray.core.indexes.PandasMultiIndex` accepts one or more 1-dimensional coordinates that must all +share the same dimension. Other, custom indexes need not have the same +constraints, e.g., + +- a georeferenced raster index which only accepts two 1-d coordinates with + distinct dimensions +- a staggered grid index which takes coordinates with different dimension name + suffixes (e.g., "_c" and "_l" for center and left) + +Optional requirements +--------------------- + +Pretty much everything else is optional. Depending on the method, in the absence +of a (re)implementation, an index will either raise a `NotImplementedError` +or won't do anything specific (just drop, pass or copy itself +from/to the resulting Dataset or DataArray). + +For example, you can just skip re-implementing :py:meth:`Index.rename` if there +is no internal attribute or object to rename according to the new desired +coordinate or dimension names. In the case of ``PandasIndex``, we rename the +underlying ``pandas.Index`` object and/or update the ``PandasIndex.dim`` +attribute since the associated dimension name has been changed. + +Wrap index data as coordinate data +---------------------------------- + +In some cases it is possible to reuse the index's underlying object or structure +as coordinate data and hence avoid data duplication. + +For ``PandasIndex`` and ``PandasMultiIndex``, we +leverage the fact that ``pandas.Index`` objects expose some array-like API. In +Xarray we use some wrappers around those underlying objects as a thin +compatibility layer to preserve dtypes, handle explicit and n-dimensional +indexing, etc. + +Other structures like tree-based indexes (e.g., kd-tree) may differ too much +from arrays to reuse it as coordinate data. + +If the index data can be reused as coordinate data, the ``Index`` subclass +should implement :py:meth:`Index.create_variables`. This method accepts a +dictionary of variable names as keys and :py:class:`Variable` objects as values (used for propagating +variable metadata) and should return a dictionary of new :py:class:`Variable` or +:py:class:`IndexVariable` objects. + +Data selection +-------------- + +For an index to support label-based selection, it needs to at least implement +:py:meth:`Index.sel`. This method accepts a dictionary of labels where the keys +are coordinate names (already filtered for the current index) and the values can +be pretty much anything (e.g., a slice, a tuple, a list, a numpy array, a +:py:class:`Variable` or a :py:class:`DataArray`). It is the responsibility of +the index to properly handle those input labels. + +:py:meth:`Index.sel` must return an instance of :py:class:`IndexSelResult`. The +latter is a small data class that holds positional indexers (indices) and that +may also hold new variables, new indexes, names of variables or indexes to drop, +names of dimensions to rename, etc. For example, this is useful in the case of +``PandasMultiIndex`` as it allows Xarray to convert it into a single ``PandasIndex`` +when only one level remains after the selection. + +The :py:class:`IndexSelResult` class is also used to merge results from label-based +selection performed by different indexes. Note that it is now possible to have +two distinct indexes for two 1-d coordinates sharing the same dimension, but it +is not currently possible to use those two indexes in the same call to +:py:meth:`Dataset.sel`. + +Optionally, the index may also implement :py:meth:`Index.isel`. In the case of +``PandasIndex`` we use it to create a new index object by just indexing the +underlying ``pandas.Index`` object. In other cases this may not be possible, +e.g., a kd-tree object may not be easily indexed. If ``Index.isel()`` is not +implemented, the index in just dropped in the DataArray or Dataset resulting +from the selection. + +Alignment +--------- + +For an index to support alignment, it needs to implement: + +- :py:meth:`Index.equals`, which compares the index with another index and + returns either ``True`` or ``False`` +- :py:meth:`Index.join`, which combines the index with another index and returns + a new Index object +- :py:meth:`Index.reindex_like`, which queries the index with another index and + returns positional indexers that are used to re-index Dataset or DataArray + variables along one or more dimensions + +Xarray ensures that those three methods are called with an index of the same +type as argument. + +Meta-indexes +------------ + +Nothing prevents writing a custom Xarray index that itself encapsulates other +Xarray index(es). We call such index a "meta-index". + +Here is a small example of a meta-index for geospatial, raster datasets (i.e., +regularly spaced 2-dimensional data) that internally relies on two +``PandasIndex`` instances for the x and y dimensions respectively: + +.. code-block:: python + + from xarray import Index + from xarray.core.indexes import PandasIndex + from xarray.core.indexing import merge_sel_results + + + class RasterIndex(Index): + def __init__(self, xy_indexes): + assert len(xy_indexes) == 2 + + # must have two distinct dimensions + dim = [idx.dim for idx in xy_indexes.values()] + assert dim[0] != dim[1] + + self._xy_indexes = xy_indexes + + @classmethod + def from_variables(cls, variables): + assert len(variables) == 2 + + xy_indexes = { + k: PandasIndex.from_variables({k: v}) for k, v in variables.items() + } + + return cls(xy_indexes) + + def create_variables(self, variables): + idx_variables = {} + + for index in self._xy_indexes.values(): + idx_variables.update(index.create_variables(variables)) + + return idx_variables + + def sel(self, labels): + results = [] + + for k, index in self._xy_indexes.items(): + if k in labels: + results.append(index.sel({k: labels[k]})) + + return merge_sel_results(results) + + +This basic index only supports label-based selection. Providing a full-featured +index by implementing the other ``Index`` methods should be pretty +straightforward for this example, though. + +This example is also not very useful unless we add some extra functionality on +top of the two encapsulated ``PandasIndex`` objects, such as a coordinate +reference system. + +How to use a custom index +------------------------- + +You can use :py:meth:`Dataset.set_xindex` or :py:meth:`DataArray.set_xindex` to assign a +custom index to a Dataset or DataArray, e.g., using the ``RasterIndex`` above: + +.. code-block:: python + + import numpy as np + import xarray as xr + + da = xr.DataArray( + np.random.uniform(size=(100, 50)), + coords={"x": ("x", np.arange(50)), "y": ("y", np.arange(100))}, + dims=("y", "x"), + ) + + # Xarray create default indexes for the 'x' and 'y' coordinates + # we first need to explicitly drop it + da = da.drop_indexes(["x", "y"]) + + # Build a RasterIndex from the 'x' and 'y' coordinates + da_raster = da.set_xindex(["x", "y"], RasterIndex) + + # RasterIndex now takes care of label-based selection + selected = da_raster.sel(x=10, y=slice(20, 50)) diff --git a/test/fixtures/whole_applications/xarray/doc/internals/index.rst b/test/fixtures/whole_applications/xarray/doc/internals/index.rst new file mode 100644 index 0000000..b2a3790 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/index.rst @@ -0,0 +1,28 @@ +.. _internals: + +Xarray Internals +================ + +Xarray builds upon two of the foundational libraries of the scientific Python +stack, NumPy and pandas. It is written in pure Python (no C or Cython +extensions), which makes it easy to develop and extend. Instead, we push +compiled code to :ref:`optional dependencies`. + +The pages in this section are intended for: + +* Contributors to xarray who wish to better understand some of the internals, +* Developers from other fields who wish to extend xarray with domain-specific logic, perhaps to support a new scientific community of users, +* Developers of other packages who wish to interface xarray with their existing tools, e.g. by creating a backend for reading a new file format, or wrapping a custom array type. + +.. toctree:: + :maxdepth: 2 + :hidden: + + internal-design + interoperability + duck-arrays-integration + chunked-arrays + extending-xarray + how-to-add-new-backend + how-to-create-custom-index + zarr-encoding-spec diff --git a/test/fixtures/whole_applications/xarray/doc/internals/internal-design.rst b/test/fixtures/whole_applications/xarray/doc/internals/internal-design.rst new file mode 100644 index 0000000..55ab2d7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/internal-design.rst @@ -0,0 +1,224 @@ +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + np.set_printoptions(threshold=20) + +.. _internal design: + +Internal Design +=============== + +This page gives an overview of the internal design of xarray. + +In totality, the Xarray project defines 4 key data structures. +In order of increasing complexity, they are: + +- :py:class:`xarray.Variable`, +- :py:class:`xarray.DataArray`, +- :py:class:`xarray.Dataset`, +- :py:class:`datatree.DataTree`. + +The user guide lists only :py:class:`xarray.DataArray` and :py:class:`xarray.Dataset`, +but :py:class:`~xarray.Variable` is the fundamental object internally, +and :py:class:`~datatree.DataTree` is a natural generalisation of :py:class:`xarray.Dataset`. + +.. note:: + + Our :ref:`roadmap` includes plans both to document :py:class:`~xarray.Variable` as fully public API, + and to merge the `xarray-datatree `_ package into xarray's main repository. + +Internally private :ref:`lazy indexing classes ` are used to avoid loading more data than necessary, +and flexible indexes classes (derived from :py:class:`~xarray.indexes.Index`) provide performant label-based lookups. + + +.. _internal design.data structures: + +Data Structures +--------------- + +The :ref:`data structures` page in the user guide explains the basics and concentrates on user-facing behavior, +whereas this section explains how xarray's data structure classes actually work internally. + + +.. _internal design.data structures.variable: + +Variable Objects +~~~~~~~~~~~~~~~~ + +The core internal data structure in xarray is the :py:class:`~xarray.Variable`, +which is used as the basic building block behind xarray's +:py:class:`~xarray.Dataset`, :py:class:`~xarray.DataArray` types. A +:py:class:`~xarray.Variable` consists of: + +- ``dims``: A tuple of dimension names. +- ``data``: The N-dimensional array (typically a NumPy or Dask array) storing + the Variable's data. It must have the same number of dimensions as the length + of ``dims``. +- ``attrs``: A dictionary of metadata associated with this array. By + convention, xarray's built-in operations never use this metadata. +- ``encoding``: Another dictionary used to store information about how + these variable's data is represented on disk. See :ref:`io.encoding` for more + details. + +:py:class:`~xarray.Variable` has an interface similar to NumPy arrays, but extended to make use +of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` +argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. + +However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not +include coordinate labels along each axis. + +:py:class:`~xarray.Variable` is public API, but because of its incomplete support for labeled +data, it is mostly intended for advanced uses, such as in xarray itself, for +writing new backends, or when creating custom indexes. +You can access the variable objects that correspond to xarray objects via the (readonly) +:py:attr:`Dataset.variables ` and +:py:attr:`DataArray.variable ` attributes. + + +.. _internal design.dataarray: + +DataArray Objects +~~~~~~~~~~~~~~~~~ + +The simplest data structure used by most users is :py:class:`~xarray.DataArray`. +A :py:class:`~xarray.DataArray` is a composite object consisting of multiple +:py:class:`~xarray.core.variable.Variable` objects which store related data. + +A single :py:class:`~xarray.core.Variable` is referred to as the "data variable", and stored under the :py:attr:`~xarray.DataArray.variable`` attribute. +A :py:class:`~xarray.DataArray` inherits all of the properties of this data variable, i.e. ``dims``, ``data``, ``attrs`` and ``encoding``, +all of which are implemented by forwarding on to the underlying ``Variable`` object. + +In addition, a :py:class:`~xarray.DataArray` stores additional ``Variable`` objects stored in a dict under the private ``_coords`` attribute, +each of which is referred to as a "Coordinate Variable". These coordinate variable objects are only allowed to have ``dims`` that are a subset of the data variable's ``dims``, +and each dim has a specific length. This means that the full :py:attr:`~xarray.DataArray.size` of the dataarray can be represented by a dictionary mapping dimension names to integer sizes. +The underlying data variable has this exact same size, and the attached coordinate variables have sizes which are some subset of the size of the data variable. +Another way of saying this is that all coordinate variables must be "alignable" with the data variable. + +When a coordinate is accessed by the user (e.g. via the dict-like :py:class:`~xarray.DataArray.__getitem__` syntax), +then a new ``DataArray`` is constructed by finding all coordinate variables that have compatible dimensions and re-attaching them before the result is returned. +This is why most users never see the ``Variable`` class underlying each coordinate variable - it is always promoted to a ``DataArray`` before returning. + +Lookups are performed by special :py:class:`~xarray.indexes.Index` objects, which are stored in a dict under the private ``_indexes`` attribute. +Indexes must be associated with one or more coordinates, and essentially act by translating a query given in physical coordinate space +(typically via the :py:meth:`~xarray.DataArray.sel` method) into a set of integer indices in array index space that can be used to index the underlying n-dimensional array-like ``data``. +Indexing in array index space (typically performed via the :py:meth:`~xarray.DataArray.isel` method) does not require consulting an ``Index`` object. + +Finally a :py:class:`~xarray.DataArray` defines a :py:attr:`~xarray.DataArray.name` attribute, which refers to its data +variable but is stored on the wrapping ``DataArray`` class. +The ``name`` attribute is primarily used when one or more :py:class:`~xarray.DataArray` objects are promoted into a :py:class:`~xarray.Dataset` +(e.g. via :py:meth:`~xarray.DataArray.to_dataset`). +Note that the underlying :py:class:`~xarray.core.Variable` objects are all unnamed, so they can always be referred to uniquely via a +dict-like mapping. + +.. _internal design.dataset: + +Dataset Objects +~~~~~~~~~~~~~~~ + +The :py:class:`~xarray.Dataset` class is a generalization of the :py:class:`~xarray.DataArray` class that can hold multiple data variables. +Internally all data variables and coordinate variables are stored under a single ``variables`` dict, and coordinates are +specified by storing their names in a private ``_coord_names`` dict. + +The dataset's ``dims`` are the set of all dims present across any variable, but (similar to in dataarrays) coordinate +variables cannot have a dimension that is not present on any data variable. + +When a data variable or coordinate variable is accessed, a new ``DataArray`` is again constructed from all compatible +coordinates before returning. + +.. _internal design.subclassing: + +.. note:: + + The way that selecting a variable from a ``DataArray`` or ``Dataset`` actually involves internally wrapping the + ``Variable`` object back up into a ``DataArray``/``Dataset`` is the primary reason :ref:`we recommend against subclassing ` + Xarray objects. The main problem it creates is that we currently cannot easily guarantee that for example selecting + a coordinate variable from your ``SubclassedDataArray`` would return an instance of ``SubclassedDataArray`` instead + of just an :py:class:`xarray.DataArray`. See `GH issue `_ for more details. + +.. _internal design.lazy indexing: + +Lazy Indexing Classes +--------------------- + +Lazy Loading +~~~~~~~~~~~~ + +If we open a ``Variable`` object from disk using :py:func:`~xarray.open_dataset` we can see that the actual values of +the array wrapped by the data variable are not displayed. + +.. ipython:: python + + da = xr.tutorial.open_dataset("air_temperature")["air"] + var = da.variable + var + +We can see the size, and the dtype of the underlying array, but not the actual values. +This is because the values have not yet been loaded. + +If we look at the private attribute :py:meth:`~xarray.Variable._data` containing the underlying array object, we see +something interesting: + +.. ipython:: python + + var._data + +You're looking at one of xarray's internal `Lazy Indexing Classes`. These powerful classes are hidden from the user, +but provide important functionality. + +Calling the public :py:attr:`~xarray.Variable.data` property loads the underlying array into memory. + +.. ipython:: python + + var.data + +This array is now cached, which we can see by accessing the private attribute again: + +.. ipython:: python + + var._data + +Lazy Indexing +~~~~~~~~~~~~~ + +The purpose of these lazy indexing classes is to prevent more data being loaded into memory than is necessary for the +subsequent analysis, by deferring loading data until after indexing is performed. + +Let's open the data from disk again. + +.. ipython:: python + + da = xr.tutorial.open_dataset("air_temperature")["air"] + var = da.variable + +Now, notice how even after subsetting the data has does not get loaded: + +.. ipython:: python + + var.isel(time=0) + +The shape has changed, but the values are still not shown. + +Looking at the private attribute again shows how this indexing information was propagated via the hidden lazy indexing classes: + +.. ipython:: python + + var.isel(time=0)._data + +.. note:: + + Currently only certain indexing operations are lazy, not all array operations. For discussion of making all array + operations lazy see `GH issue #5081 `_. + + +Lazy Dask Arrays +~~~~~~~~~~~~~~~~ + +Note that xarray's implementation of Lazy Indexing classes is completely separate from how :py:class:`dask.array.Array` +objects evaluate lazily. Dask-backed xarray objects delay almost all operations until :py:meth:`~xarray.DataArray.compute` +is called (either explicitly or implicitly via :py:meth:`~xarray.DataArray.plot` for example). The exceptions to this +laziness are operations whose output shape is data-dependent, such as when calling :py:meth:`~xarray.DataArray.where`. diff --git a/test/fixtures/whole_applications/xarray/doc/internals/interoperability.rst b/test/fixtures/whole_applications/xarray/doc/internals/interoperability.rst new file mode 100644 index 0000000..a45363b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/interoperability.rst @@ -0,0 +1,45 @@ +.. _interoperability: + +Interoperability of Xarray +========================== + +Xarray is designed to be extremely interoperable, in many orthogonal ways. +Making xarray as flexible as possible is the common theme of most of the goals on our :ref:`roadmap`. + +This interoperability comes via a set of flexible abstractions into which the user can plug in. The current full list is: + +- :ref:`Custom file backends ` via the :py:class:`~xarray.backends.BackendEntrypoint` system, +- Numpy-like :ref:`"duck" array wrapping `, which supports the `Python Array API Standard `_, +- :ref:`Chunked distributed array computation ` via the :py:class:`~xarray.core.parallelcompat.ChunkManagerEntrypoint` system, +- Custom :py:class:`~xarray.Index` objects for :ref:`flexible label-based lookups `, +- Extending xarray objects with domain-specific methods via :ref:`custom accessors `. + +.. warning:: + + One obvious way in which xarray could be more flexible is that whilst subclassing xarray objects is possible, we + currently don't support it in most transformations, instead recommending composition over inheritance. See the + :ref:`internal design page ` for the rationale and look at the corresponding `GH issue `_ + if you're interested in improving support for subclassing! + +.. note:: + + If you think there is another way in which xarray could become more generically flexible then please + tell us your ideas by `raising an issue to request the feature `_! + + +Whilst xarray was originally designed specifically to open ``netCDF4`` files as :py:class:`numpy.ndarray` objects labelled by :py:class:`pandas.Index` objects, +it is entirely possible today to: + +- lazily open an xarray object directly from a custom binary file format (e.g. using ``xarray.open_dataset(path, engine='my_custom_format')``, +- handle the data as any API-compliant numpy-like array type (e.g. sparse or GPU-backed), +- distribute out-of-core computation across that array type in parallel (e.g. via :ref:`dask`), +- track the physical units of the data through computations (e.g via `pint-xarray `_), +- query the data via custom index logic optimized for specific applications (e.g. an :py:class:`~xarray.Index` object backed by a KDTree structure), +- attach domain-specific logic via accessor methods (e.g. to understand geographic Coordinate Reference System metadata), +- organize hierarchical groups of xarray data in a :py:class:`~datatree.DataTree` (e.g. to treat heterogeneous simulation and observational data together during analysis). + +All of these features can be provided simultaneously, using libraries compatible with the rest of the scientific python ecosystem. +In this situation xarray would be essentially a thin wrapper acting as pure-python framework, providing a common interface and +separation of concerns via various domain-agnostic abstractions. + +Most of the remaining pages in the documentation of xarray's internals describe these various types of interoperability in more detail. diff --git a/test/fixtures/whole_applications/xarray/doc/internals/zarr-encoding-spec.rst b/test/fixtures/whole_applications/xarray/doc/internals/zarr-encoding-spec.rst new file mode 100644 index 0000000..7f468b8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/internals/zarr-encoding-spec.rst @@ -0,0 +1,74 @@ +.. currentmodule:: xarray + +.. _zarr_encoding: + +Zarr Encoding Specification +============================ + +In implementing support for the `Zarr `_ storage +format, Xarray developers made some *ad hoc* choices about how to store +NetCDF data in Zarr. +Future versions of the Zarr spec will likely include a more formal convention +for the storage of the NetCDF data model in Zarr; see +`Zarr spec repo `_ for ongoing +discussion. + +First, Xarray can only read and write Zarr groups. There is currently no support +for reading / writing individual Zarr arrays. Zarr groups are mapped to +Xarray ``Dataset`` objects. + +Second, from Xarray's point of view, the key difference between +NetCDF and Zarr is that all NetCDF arrays have *dimension names* while Zarr +arrays do not. Therefore, in order to store NetCDF data in Zarr, Xarray must +somehow encode and decode the name of each array's dimensions. + +To accomplish this, Xarray developers decided to define a special Zarr array +attribute: ``_ARRAY_DIMENSIONS``. The value of this attribute is a list of +dimension names (strings), for example ``["time", "lon", "lat"]``. When writing +data to Zarr, Xarray sets this attribute on all variables based on the variable +dimensions. When reading a Zarr group, Xarray looks for this attribute on all +arrays, raising an error if it can't be found. The attribute is used to define +the variable dimension names and then removed from the attributes dictionary +returned to the user. + +Because of these choices, Xarray cannot read arbitrary array data, but only +Zarr data with valid ``_ARRAY_DIMENSIONS`` or +`NCZarr `_ attributes +on each array (NCZarr dimension names are defined in the ``.zarray`` file). + +After decoding the ``_ARRAY_DIMENSIONS`` or NCZarr attribute and assigning the variable +dimensions, Xarray proceeds to [optionally] decode each variable using its +standard CF decoding machinery used for NetCDF data (see :py:func:`decode_cf`). + +Finally, it's worth noting that Xarray writes (and attempts to read) +"consolidated metadata" by default (the ``.zmetadata`` file), which is another +non-standard Zarr extension, albeit one implemented upstream in Zarr-Python. +You do not need to write consolidated metadata to make Zarr stores readable in +Xarray, but because Xarray can open these stores much faster, users will see a +warning about poor performance when reading non-consolidated stores unless they +explicitly set ``consolidated=False``. See :ref:`io.zarr.consolidated_metadata` +for more details. + +As a concrete example, here we write a tutorial dataset to Zarr and then +re-open it directly with Zarr: + +.. ipython:: python + + import os + import xarray as xr + import zarr + + ds = xr.tutorial.load_dataset("rasm") + ds.to_zarr("rasm.zarr", mode="w") + + zgroup = zarr.open("rasm.zarr") + print(os.listdir("rasm.zarr")) + print(zgroup.tree()) + dict(zgroup["Tair"].attrs) + +.. ipython:: python + :suppress: + + import shutil + + shutil.rmtree("rasm.zarr") diff --git a/test/fixtures/whole_applications/xarray/doc/roadmap.rst b/test/fixtures/whole_applications/xarray/doc/roadmap.rst new file mode 100644 index 0000000..820ff82 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/roadmap.rst @@ -0,0 +1,284 @@ +.. _roadmap: + +Development roadmap +=================== + +Authors: Xarray developers + +Date: September 7, 2021 + +Xarray is an open source Python library for labeled multidimensional +arrays and datasets. + +Our philosophy +-------------- + +Why has xarray been successful? In our opinion: + +- Xarray does a great job of solving **specific use-cases** for + multidimensional data analysis: + + - The dominant use-case for xarray is for analysis of gridded + dataset in the geosciences, e.g., as part of the + `Pangeo `__ project. + - Xarray is also used more broadly in the physical sciences, where + we've found the needs for analyzing multidimensional datasets are + remarkably consistent (e.g., see + `SunPy `__ and + `PlasmaPy `__). + - Finally, xarray is used in a variety of other domains, including + finance, `probabilistic + programming `__ and + genomics. + +- Xarray is also a **domain agnostic** solution: + + - We focus on providing a flexible set of functionality related + labeled multidimensional arrays, rather than solving particular + problems. + - This facilitates collaboration between users with different needs, + and helps us attract a broad community of contributors. + - Importantly, this retains flexibility, for use cases that don't + fit particularly well into existing frameworks. + +- Xarray **integrates well** with other libraries in the scientific + Python stack. + + - We leverage first-class external libraries for core features of + xarray (e.g., NumPy for ndarrays, pandas for indexing, dask for + parallel computing) + - We expose our internal abstractions to users (e.g., + ``apply_ufunc()``), which facilitates extending xarray in various + ways. + +Together, these features have made xarray a first-class choice for +labeled multidimensional arrays in Python. + +We want to double-down on xarray's strengths by making it an even more +flexible and powerful tool for multidimensional data analysis. We want +to continue to engage xarray's core geoscience users, and to also reach +out to new domains to learn from other successful data models like those +of `yt `__ or the `OLAP +cube `__. + +Specific needs +-------------- + +The user community has voiced a number specific needs related to how +xarray interfaces with domain specific problems. Xarray may not solve +all of these issues directly, but these areas provide opportunities for +xarray to provide better, more extensible, interfaces. Some examples of +these common needs are: + +- Non-regular grids (e.g., staggered and unstructured meshes). +- Physical units. +- Lazily computed arrays (e.g., for coordinate systems). +- New file-formats. + +Technical vision +---------------- + +We think the right approach to extending xarray's user community and the +usefulness of the project is to focus on improving key interfaces that +can be used externally to meet domain-specific needs. + +We can generalize the community's needs into three main categories: + +- More flexible grids/indexing. +- More flexible arrays/computing. +- More flexible storage backends. +- More flexible data structures. + +Each of these are detailed further in the subsections below. + +Flexible indexes +~~~~~~~~~~~~~~~~ + +.. note:: + Work on flexible grids and indexes is currently underway. See + `GH Project #1 `__ for more detail. + +Xarray currently keeps track of indexes associated with coordinates by +storing them in the form of a ``pandas.Index`` in special +``xarray.IndexVariable`` objects. + +The limitations of this model became clear with the addition of +``pandas.MultiIndex`` support in xarray 0.9, where a single index +corresponds to multiple xarray variables. MultiIndex support is highly +useful, but xarray now has numerous special cases to check for +MultiIndex levels. + +A cleaner model would be to elevate ``indexes`` to an explicit part of +xarray's data model, e.g., as attributes on the ``Dataset`` and +``DataArray`` classes. Indexes would need to be propagated along with +coordinates in xarray operations, but will no longer would need to have +a one-to-one correspondence with coordinate variables. Instead, an index +should be able to refer to multiple (possibly multidimensional) +coordinates that define it. See :issue:`1603` for full details. + +Specific tasks: + +- Add an ``indexes`` attribute to ``xarray.Dataset`` and + ``xarray.Dataset``, as dictionaries that map from coordinate names to + xarray index objects. +- Use the new index interface to write wrappers for ``pandas.Index``, + ``pandas.MultiIndex`` and ``scipy.spatial.KDTree``. +- Expose the interface externally to allow third-party libraries to + implement custom indexing routines, e.g., for geospatial look-ups on + the surface of the Earth. + +In addition to the new features it directly enables, this clean up will +allow xarray to more easily implement some long-awaited features that +build upon indexing, such as groupby operations with multiple variables. + +Flexible arrays +~~~~~~~~~~~~~~~ + +.. note:: + Work on flexible arrays is currently underway. See + `GH Project #2 `__ for more detail. + +Xarray currently supports wrapping multidimensional arrays defined by +NumPy, dask and to a limited-extent pandas. It would be nice to have +interfaces that allow xarray to wrap alternative N-D array +implementations, e.g.: + +- Arrays holding physical units. +- Lazily computed arrays. +- Other ndarray objects, e.g., sparse, xnd, xtensor. + +Our strategy has been to pursue upstream improvements in NumPy (see +`NEP-22 `__) +for supporting a complete duck-typing interface using with NumPy's +higher level array API. Improvements in NumPy's support for custom data +types would also be highly useful for xarray users. + +By pursuing these improvements in NumPy we hope to extend the benefits +to the full scientific Python community, and avoid tight coupling +between xarray and specific third-party libraries (e.g., for +implementing units). This will allow xarray to maintain its domain +agnostic strengths. + +We expect that we may eventually add some minimal interfaces in xarray +for features that we delegate to external array libraries (e.g., for +getting units and changing units). If we do add these features, we +expect them to be thin wrappers, with core functionality implemented by +third-party libraries. + +Flexible storage +~~~~~~~~~~~~~~~~ + +.. note:: + Work on flexible storage backends is currently underway. See + `GH Project #3 `__ for more detail. + +The xarray backends module has grown in size and complexity. Much of +this growth has been "organic" and mostly to support incremental +additions to the supported backends. This has left us with a fragile +internal API that is difficult for even experienced xarray developers to +use. Moreover, the lack of a public facing API for building xarray +backends means that users can not easily build backend interface for +xarray in third-party libraries. + +The idea of refactoring the backends API and exposing it to users was +originally proposed in :issue:`1970`. The idea would be to develop a +well tested and generic backend base class and associated utilities +for external use. Specific tasks for this development would include: + +- Exposing an abstract backend for writing new storage systems. +- Exposing utilities for features like automatic closing of files, + LRU-caching and explicit/lazy indexing. +- Possibly moving some infrequently used backends to third-party + packages. + +Flexible data structures +~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray provides two primary data structures, the ``xarray.DataArray`` and +the ``xarray.Dataset``. This section describes two possible data model +extensions. + +Tree-like data structure +++++++++++++++++++++++++ + +.. note:: + Work on developing a hierarchical data structure in xarray is just + beginning. See `Datatree `__ + for an early prototype. + +Xarray’s highest-level object is currently an ``xarray.Dataset``, whose data +model echoes that of a single netCDF group. However real-world datasets are +often better represented by a collection of related Datasets. Particular common +examples include: + +- Multi-resolution datasets, +- Collections of time series datasets with differing lengths, +- Heterogeneous datasets comprising multiple different types of related + observational or simulation data, +- Bayesian workflows involving various statistical distributions over multiple + variables, +- Whole netCDF files containing multiple groups. +- Comparison of output from many similar models (such as in the IPCC's Coupled Model Intercomparison Projects) + +A new tree-like data structure which is essentially a structured hierarchical +collection of Datasets could represent these cases, and would instead map to +multiple netCDF groups (see :issue:`4118`). + +Currently there are several libraries which have wrapped xarray in order to build +domain-specific data structures (e.g. `xarray-multiscale `__.), +but a general ``xarray.DataTree`` object would obviate the need for these and] +consolidate effort in a single domain-agnostic tool, much as xarray has already achieved. + +Labeled array without coordinates ++++++++++++++++++++++++++++++++++ + +There is a need for a lightweight array structure with named dimensions for +convenient indexing and broadcasting. Xarray includes such a structure internally +(``xarray.Variable``). We want to factor out xarray's “Variable” object into a +standalone package with minimal dependencies for integration with libraries that +don't want to inherit xarray's dependency on pandas (e.g. scikit-learn). +The new “Variable” class will follow established array protocols and the new +data-apis standard. It will be capable of wrapping multiple array-like objects +(e.g. NumPy, Dask, Sparse, Pint, CuPy, Pytorch). While “DataArray” fits some of +these requirements, it offers a more complex data model than is desired for +many applications and depends on pandas. + +Engaging more users +------------------- + +.. note:: + Work on improving xarray’s documentation and user engagement is + currently underway. See `GH Project #4 `__ + for more detail. + +Like many open-source projects, the documentation of xarray has grown +together with the library's features. While we think that the xarray +documentation is comprehensive already, we acknowledge that the adoption +of xarray might be slowed down because of the substantial time +investment required to learn its working principles. In particular, +non-computer scientists or users less familiar with the pydata ecosystem +might find it difficult to learn xarray and realize how xarray can help +them in their daily work. + +In order to lower this adoption barrier, we propose to: + +- Develop entry-level tutorials for users with different backgrounds. For + example, we would like to develop tutorials for users with or without + previous knowledge of pandas, NumPy, netCDF, etc. These tutorials may be + built as part of xarray's documentation or included in a separate repository + to enable interactive use (e.g. mybinder.org). +- Document typical user workflows in a dedicated website, following the example + of `dask-stories + `__. +- Write a basic glossary that defines terms that might not be familiar to all + (e.g. "lazy", "labeled", "serialization", "indexing", "backend"). + + +Administrative +-------------- + +NumFOCUS +~~~~~~~~ + +On July 16, 2018, Joe and Stephan submitted xarray's fiscal sponsorship +application to NumFOCUS. diff --git a/test/fixtures/whole_applications/xarray/doc/tutorials-and-videos.rst b/test/fixtures/whole_applications/xarray/doc/tutorials-and-videos.rst new file mode 100644 index 0000000..7a9e524 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/tutorials-and-videos.rst @@ -0,0 +1,32 @@ + +Tutorials and Videos +==================== + + +Tutorials +---------- + +- `Xarray's Tutorials`_ repository +- The `UW eScience Institute's Geohackweek`_ tutorial on xarray for geospatial data scientists. +- `Nicolas Fauchereau's 2015 tutorial`_ on xarray for netCDF users. + + + +Videos +------- + +.. include:: videos-gallery.txt + + +Books, Chapters and Articles +----------------------------- + +- Stephan Hoyer and Joe Hamman's `Journal of Open Research Software paper`_ describing the xarray project. + + +.. _Xarray's Tutorials: https://xarray-contrib.github.io/xarray-tutorial/ +.. _Journal of Open Research Software paper: https://doi.org/10.5334/jors.148 +.. _UW eScience Institute's Geohackweek : https://geohackweek.github.io/nDarrays/ +.. _tutorial: https://github.com/Unidata/unidata-users-workshop/blob/master/notebooks/xray-tutorial.ipynb +.. _with answers: https://github.com/Unidata/unidata-users-workshop/blob/master/notebooks/xray-tutorial-with-answers.ipynb +.. _Nicolas Fauchereau's 2015 tutorial: https://nbviewer.iPython.org/github/nicolasfauchereau/metocean/blob/master/notebooks/xray.ipynb diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/combining.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/combining.rst new file mode 100644 index 0000000..1dad200 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/combining.rst @@ -0,0 +1,308 @@ +.. _combining data: + +Combining data +-------------- + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +* For combining datasets or data arrays along a single dimension, see concatenate_. +* For combining datasets with different variables, see merge_. +* For combining datasets or data arrays with different indexes or missing values, see combine_. +* For combining datasets or data arrays along multiple dimensions see combining.multi_. + +.. _concatenate: + +Concatenate +~~~~~~~~~~~ + +To combine :py:class:`~xarray.Dataset`s / :py:class:`~xarray.DataArray`s along an existing or new dimension +into a larger object, you can use :py:func:`~xarray.concat`. ``concat`` +takes an iterable of ``DataArray`` or ``Dataset`` objects, as well as a +dimension name, and concatenates along that dimension: + +.. ipython:: python + + da = xr.DataArray( + np.arange(6).reshape(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] + ) + da.isel(y=slice(0, 1)) # same as da[:, :1] + # This resembles how you would use np.concatenate: + xr.concat([da[:, :1], da[:, 1:]], dim="y") + # For more friendly pandas-like indexing you can use: + xr.concat([da.isel(y=slice(0, 1)), da.isel(y=slice(1, None))], dim="y") + +In addition to combining along an existing dimension, ``concat`` can create a +new dimension by stacking lower dimensional arrays together: + +.. ipython:: python + + da.sel(x="a") + xr.concat([da.isel(x=0), da.isel(x=1)], "x") + +If the second argument to ``concat`` is a new dimension name, the arrays will +be concatenated along that new dimension, which is always inserted as the first +dimension: + +.. ipython:: python + + xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim") + +The second argument to ``concat`` can also be an :py:class:`~pandas.Index` or +:py:class:`~xarray.DataArray` object as well as a string, in which case it is +used to label the values along the new dimension: + +.. ipython:: python + + xr.concat([da.isel(x=0), da.isel(x=1)], pd.Index([-90, -100], name="new_dim")) + +Of course, ``concat`` also works on ``Dataset`` objects: + +.. ipython:: python + + ds = da.to_dataset(name="foo") + xr.concat([ds.sel(x="a"), ds.sel(x="b")], "x") + +:py:func:`~xarray.concat` has a number of options which provide deeper control +over which variables are concatenated and how it handles conflicting variables +between datasets. With the default parameters, xarray will load some coordinate +variables into memory to compare them between datasets. This may be prohibitively +expensive if you are manipulating your dataset lazily using :ref:`dask`. + +.. _merge: + +Merge +~~~~~ + +To combine variables and coordinates between multiple ``DataArray`` and/or +``Dataset`` objects, use :py:func:`~xarray.merge`. It can merge a list of +``Dataset``, ``DataArray`` or dictionaries of objects convertible to +``DataArray`` objects: + +.. ipython:: python + + xr.merge([ds, ds.rename({"foo": "bar"})]) + xr.merge([xr.DataArray(n, name="var%d" % n) for n in range(5)]) + +If you merge another dataset (or a dictionary including data array objects), by +default the resulting dataset will be aligned on the **union** of all index +coordinates: + +.. ipython:: python + + other = xr.Dataset({"bar": ("x", [1, 2, 3, 4]), "x": list("abcd")}) + xr.merge([ds, other]) + +This ensures that ``merge`` is non-destructive. ``xarray.MergeError`` is raised +if you attempt to merge two variables with the same name but different values: + +.. ipython:: + + @verbatim + In [1]: xr.merge([ds, ds + 1]) + MergeError: conflicting values for variable 'foo' on objects to be combined: + first value: + array([[ 0.4691123 , -0.28286334, -1.5090585 ], + [-1.13563237, 1.21211203, -0.17321465]]) + second value: + array([[ 1.4691123 , 0.71713666, -0.5090585 ], + [-0.13563237, 2.21211203, 0.82678535]]) + +The same non-destructive merging between ``DataArray`` index coordinates is +used in the :py:class:`~xarray.Dataset` constructor: + +.. ipython:: python + + xr.Dataset({"a": da.isel(x=slice(0, 1)), "b": da.isel(x=slice(1, 2))}) + +.. _combine: + +Combine +~~~~~~~ + +The instance method :py:meth:`~xarray.DataArray.combine_first` combines two +datasets/data arrays and defaults to non-null values in the calling object, +using values from the called object to fill holes. The resulting coordinates +are the union of coordinate labels. Vacant cells as a result of the outer-join +are filled with ``NaN``. For example: + +.. ipython:: python + + ar0 = xr.DataArray([[0, 0], [0, 0]], [("x", ["a", "b"]), ("y", [-1, 0])]) + ar1 = xr.DataArray([[1, 1], [1, 1]], [("x", ["b", "c"]), ("y", [0, 1])]) + ar0.combine_first(ar1) + ar1.combine_first(ar0) + +For datasets, ``ds0.combine_first(ds1)`` works similarly to +``xr.merge([ds0, ds1])``, except that ``xr.merge`` raises ``MergeError`` when +there are conflicting values in variables to be merged, whereas +``.combine_first`` defaults to the calling object's values. + +.. _update: + +Update +~~~~~~ + +In contrast to ``merge``, :py:meth:`~xarray.Dataset.update` modifies a dataset +in-place without checking for conflicts, and will overwrite any existing +variables with new values: + +.. ipython:: python + + ds.update({"space": ("space", [10.2, 9.4, 3.9])}) + +However, dimensions are still required to be consistent between different +Dataset variables, so you cannot change the size of a dimension unless you +replace all dataset variables that use it. + +``update`` also performs automatic alignment if necessary. Unlike ``merge``, it +maintains the alignment of the original array instead of merging indexes: + +.. ipython:: python + + ds.update(other) + +The exact same alignment logic when setting a variable with ``__setitem__`` +syntax: + +.. ipython:: python + + ds["baz"] = xr.DataArray([9, 9, 9, 9, 9], coords=[("x", list("abcde"))]) + ds.baz + +Equals and identical +~~~~~~~~~~~~~~~~~~~~ + +Xarray objects can be compared by using the :py:meth:`~xarray.Dataset.equals`, +:py:meth:`~xarray.Dataset.identical` and +:py:meth:`~xarray.Dataset.broadcast_equals` methods. These methods are used by +the optional ``compat`` argument on ``concat`` and ``merge``. + +:py:attr:`~xarray.Dataset.equals` checks dimension names, indexes and array +values: + +.. ipython:: python + + da.equals(da.copy()) + +:py:attr:`~xarray.Dataset.identical` also checks attributes, and the name of each +object: + +.. ipython:: python + + da.identical(da.rename("bar")) + +:py:attr:`~xarray.Dataset.broadcast_equals` does a more relaxed form of equality +check that allows variables to have different dimensions, as long as values +are constant along those new dimensions: + +.. ipython:: python + + left = xr.Dataset(coords={"x": 0}) + right = xr.Dataset({"x": [0, 0, 0]}) + left.broadcast_equals(right) + +Like pandas objects, two xarray objects are still equal or identical if they have +missing values marked by ``NaN`` in the same locations. + +In contrast, the ``==`` operation performs element-wise comparison (like +numpy): + +.. ipython:: python + + da == da.copy() + +Note that ``NaN`` does not compare equal to ``NaN`` in element-wise comparison; +you may need to deal with missing values explicitly. + +.. _combining.no_conflicts: + +Merging with 'no_conflicts' +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``compat`` argument ``'no_conflicts'`` is only available when +combining xarray objects with ``merge``. In addition to the above comparison +methods it allows the merging of xarray objects with locations where *either* +have ``NaN`` values. This can be used to combine data with overlapping +coordinates as long as any non-missing values agree or are disjoint: + +.. ipython:: python + + ds1 = xr.Dataset({"a": ("x", [10, 20, 30, np.nan])}, {"x": [1, 2, 3, 4]}) + ds2 = xr.Dataset({"a": ("x", [np.nan, 30, 40, 50])}, {"x": [2, 3, 4, 5]}) + xr.merge([ds1, ds2], compat="no_conflicts") + +Note that due to the underlying representation of missing values as floating +point numbers (``NaN``), variable data type is not always preserved when merging +in this manner. + +.. _combining.multi: + +Combining along multiple dimensions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For combining many objects along multiple dimensions xarray provides +:py:func:`~xarray.combine_nested` and :py:func:`~xarray.combine_by_coords`. These +functions use a combination of ``concat`` and ``merge`` across different +variables to combine many objects into one. + +:py:func:`~xarray.combine_nested` requires specifying the order in which the +objects should be combined, while :py:func:`~xarray.combine_by_coords` attempts to +infer this ordering automatically from the coordinates in the data. + +:py:func:`~xarray.combine_nested` is useful when you know the spatial +relationship between each object in advance. The datasets must be provided in +the form of a nested list, which specifies their relative position and +ordering. A common task is collecting data from a parallelized simulation where +each processor wrote out data to a separate file. A domain which was decomposed +into 4 parts, 2 each along both the x and y axes, requires organising the +datasets into a doubly-nested list, e.g: + +.. ipython:: python + + arr = xr.DataArray( + name="temperature", data=np.random.randint(5, size=(2, 2)), dims=["x", "y"] + ) + arr + ds_grid = [[arr, arr], [arr, arr]] + xr.combine_nested(ds_grid, concat_dim=["x", "y"]) + +:py:func:`~xarray.combine_nested` can also be used to explicitly merge datasets +with different variables. For example if we have 4 datasets, which are divided +along two times, and contain two different variables, we can pass ``None`` +to ``'concat_dim'`` to specify the dimension of the nested list over which +we wish to use ``merge`` instead of ``concat``: + +.. ipython:: python + + temp = xr.DataArray(name="temperature", data=np.random.randn(2), dims=["t"]) + precip = xr.DataArray(name="precipitation", data=np.random.randn(2), dims=["t"]) + ds_grid = [[temp, precip], [temp, precip]] + xr.combine_nested(ds_grid, concat_dim=["t", None]) + +:py:func:`~xarray.combine_by_coords` is for combining objects which have dimension +coordinates which specify their relationship to and order relative to one +another, for example a linearly-increasing 'time' dimension coordinate. + +Here we combine two datasets using their common dimension coordinates. Notice +they are concatenated in order based on the values in their dimension +coordinates, not on their position in the list passed to ``combine_by_coords``. + +.. ipython:: python + :okwarning: + + x1 = xr.DataArray(name="foo", data=np.random.randn(3), coords=[("x", [0, 1, 2])]) + x2 = xr.DataArray(name="foo", data=np.random.randn(3), coords=[("x", [3, 4, 5])]) + xr.combine_by_coords([x2, x1]) + +These functions can be used by :py:func:`~xarray.open_mfdataset` to open many +files as one dataset. The particular function used is specified by setting the +argument ``'combine'`` to ``'by_coords'`` or ``'nested'``. This is useful for +situations where your data is split across many files in multiple locations, +which have some known relationship between one another. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/computation.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/computation.rst new file mode 100644 index 0000000..f99d41b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/computation.rst @@ -0,0 +1,910 @@ +.. currentmodule:: xarray + +.. _comput: + +########### +Computation +########### + + +The labels associated with :py:class:`~xarray.DataArray` and +:py:class:`~xarray.Dataset` objects enables some powerful shortcuts for +computation, notably including aggregation and broadcasting by dimension +names. + +Basic array math +================ + +Arithmetic operations with a single DataArray automatically vectorize (like +numpy) over all array values: + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +.. ipython:: python + + arr = xr.DataArray( + np.random.RandomState(0).randn(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] + ) + arr - 3 + abs(arr) + +You can also use any of numpy's or scipy's many `ufunc`__ functions directly on +a DataArray: + +__ https://numpy.org/doc/stable/reference/ufuncs.html + +.. ipython:: python + + np.sin(arr) + +Use :py:func:`~xarray.where` to conditionally switch between values: + +.. ipython:: python + + xr.where(arr > 0, "positive", "negative") + +Use `@` to compute the :py:func:`~xarray.dot` product: + +.. ipython:: python + + arr @ arr + +Data arrays also implement many :py:class:`numpy.ndarray` methods: + +.. ipython:: python + + arr.round(2) + arr.T + + intarr = xr.DataArray([0, 1, 2, 3, 4, 5]) + intarr << 2 # only supported for int types + intarr >> 1 + +.. _missing_values: + +Missing values +============== + +Xarray represents missing values using the "NaN" (Not a Number) value from NumPy, which is a +special floating-point value that indicates a value that is undefined or unrepresentable. +There are several methods for handling missing values in xarray: + +Xarray objects borrow the :py:meth:`~xarray.DataArray.isnull`, +:py:meth:`~xarray.DataArray.notnull`, :py:meth:`~xarray.DataArray.count`, +:py:meth:`~xarray.DataArray.dropna`, :py:meth:`~xarray.DataArray.fillna`, +:py:meth:`~xarray.DataArray.ffill`, and :py:meth:`~xarray.DataArray.bfill` +methods for working with missing data from pandas: + +:py:meth:`~xarray.DataArray.isnull` is a method in xarray that can be used to check for missing or null values in an xarray object. +It returns a new xarray object with the same dimensions as the original object, but with boolean values +indicating where **missing values** are present. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.isnull() + +In this example, the third and fourth elements of 'x' are NaN, so the resulting :py:class:`~xarray.DataArray` +object has 'True' values in the third and fourth positions and 'False' values in the other positions. + +:py:meth:`~xarray.DataArray.notnull` is a method in xarray that can be used to check for non-missing or non-null values in an xarray +object. It returns a new xarray object with the same dimensions as the original object, but with boolean +values indicating where **non-missing values** are present. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.notnull() + +In this example, the first two and the last elements of x are not NaN, so the resulting +:py:class:`~xarray.DataArray` object has 'True' values in these positions, and 'False' values in the +third and fourth positions where NaN is located. + +:py:meth:`~xarray.DataArray.count` is a method in xarray that can be used to count the number of +non-missing values along one or more dimensions of an xarray object. It returns a new xarray object with +the same dimensions as the original object, but with each element replaced by the count of non-missing +values along the specified dimensions. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.count() + +In this example, 'x' has five elements, but two of them are NaN, so the resulting +:py:class:`~xarray.DataArray` object having a single element containing the value '3', which represents +the number of non-null elements in x. + +:py:meth:`~xarray.DataArray.dropna` is a method in xarray that can be used to remove missing or null values from an xarray object. +It returns a new xarray object with the same dimensions as the original object, but with missing values +removed. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.dropna(dim="x") + +In this example, on calling x.dropna(dim="x") removes any missing values and returns a new +:py:class:`~xarray.DataArray` object with only the non-null elements [0, 1, 2] of 'x', in the +original order. + +:py:meth:`~xarray.DataArray.fillna` is a method in xarray that can be used to fill missing or null values in an xarray object with a +specified value or method. It returns a new xarray object with the same dimensions as the original object, but with missing values filled. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.fillna(-1) + +In this example, there are two NaN values in 'x', so calling x.fillna(-1) replaces these values with -1 and +returns a new :py:class:`~xarray.DataArray` object with five elements, containing the values +[0, 1, -1, -1, 2] in the original order. + +:py:meth:`~xarray.DataArray.ffill` is a method in xarray that can be used to forward fill (or fill forward) missing values in an +xarray object along one or more dimensions. It returns a new xarray object with the same dimensions as the +original object, but with missing values replaced by the last non-missing value along the specified dimensions. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.ffill("x") + +In this example, there are two NaN values in 'x', so calling x.ffill("x") fills these values with the last +non-null value in the same dimension, which are 0 and 1, respectively. The resulting :py:class:`~xarray.DataArray` object has +five elements, containing the values [0, 1, 1, 1, 2] in the original order. + +:py:meth:`~xarray.DataArray.bfill` is a method in xarray that can be used to backward fill (or fill backward) missing values in an +xarray object along one or more dimensions. It returns a new xarray object with the same dimensions as the original object, but +with missing values replaced by the next non-missing value along the specified dimensions. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) + x.bfill("x") + +In this example, there are two NaN values in 'x', so calling x.bfill("x") fills these values with the next +non-null value in the same dimension, which are 2 and 2, respectively. The resulting :py:class:`~xarray.DataArray` object has +five elements, containing the values [0, 1, 2, 2, 2] in the original order. + +Like pandas, xarray uses the float value ``np.nan`` (not-a-number) to represent +missing values. + +Xarray objects also have an :py:meth:`~xarray.DataArray.interpolate_na` method +for filling missing values via 1D interpolation. It returns a new xarray object with the same dimensions +as the original object, but with missing values interpolated. + +.. ipython:: python + + x = xr.DataArray( + [0, 1, np.nan, np.nan, 2], + dims=["x"], + coords={"xx": xr.Variable("x", [0, 1, 1.1, 1.9, 3])}, + ) + x.interpolate_na(dim="x", method="linear", use_coordinate="xx") + +In this example, there are two NaN values in 'x', so calling x.interpolate_na(dim="x", method="linear", +use_coordinate="xx") fills these values with interpolated values along the "x" dimension using linear +interpolation based on the values of the xx coordinate. The resulting :py:class:`~xarray.DataArray` object has five elements, +containing the values [0., 1., 1.05, 1.45, 2.] in the original order. Note that the interpolated values +are calculated based on the values of the 'xx' coordinate, which has non-integer values, resulting in +non-integer interpolated values. + +Note that xarray slightly diverges from the pandas ``interpolate`` syntax by +providing the ``use_coordinate`` keyword which facilitates a clear specification +of which values to use as the index in the interpolation. +Xarray also provides the ``max_gap`` keyword argument to limit the interpolation to +data gaps of length ``max_gap`` or smaller. See :py:meth:`~xarray.DataArray.interpolate_na` +for more. + +.. _agg: + +Aggregation +=========== + +Aggregation methods have been updated to take a `dim` argument instead of +`axis`. This allows for very intuitive syntax for aggregation methods that are +applied along particular dimension(s): + +.. ipython:: python + + arr.sum(dim="x") + arr.std(["x", "y"]) + arr.min() + + +If you need to figure out the axis number for a dimension yourself (say, +for wrapping code designed to work with numpy arrays), you can use the +:py:meth:`~xarray.DataArray.get_axis_num` method: + +.. ipython:: python + + arr.get_axis_num("y") + +These operations automatically skip missing values, like in pandas: + +.. ipython:: python + + xr.DataArray([1, 2, np.nan, 3]).mean() + +If desired, you can disable this behavior by invoking the aggregation method +with ``skipna=False``. + +.. _comput.rolling: + +Rolling window operations +========================= + +``DataArray`` objects include a :py:meth:`~xarray.DataArray.rolling` method. This +method supports rolling window aggregation: + +.. ipython:: python + + arr = xr.DataArray(np.arange(0, 7.5, 0.5).reshape(3, 5), dims=("x", "y")) + arr + +:py:meth:`~xarray.DataArray.rolling` is applied along one dimension using the +name of the dimension as a key (e.g. ``y``) and the window size as the value +(e.g. ``3``). We get back a ``Rolling`` object: + +.. ipython:: python + + arr.rolling(y=3) + +Aggregation and summary methods can be applied directly to the ``Rolling`` +object: + +.. ipython:: python + + r = arr.rolling(y=3) + r.reduce(np.std) + r.mean() + +Aggregation results are assigned the coordinate at the end of each window by +default, but can be centered by passing ``center=True`` when constructing the +``Rolling`` object: + +.. ipython:: python + + r = arr.rolling(y=3, center=True) + r.mean() + +As can be seen above, aggregations of windows which overlap the border of the +array produce ``nan``\s. Setting ``min_periods`` in the call to ``rolling`` +changes the minimum number of observations within the window required to have +a value when aggregating: + +.. ipython:: python + + r = arr.rolling(y=3, min_periods=2) + r.mean() + r = arr.rolling(y=3, center=True, min_periods=2) + r.mean() + +From version 0.17, xarray supports multidimensional rolling, + +.. ipython:: python + + r = arr.rolling(x=2, y=3, min_periods=2) + r.mean() + +.. tip:: + + Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling. + +.. _bottleneck: https://github.com/pydata/bottleneck + +We can also manually iterate through ``Rolling`` objects: + +.. code:: python + + for label, arr_window in r: + # arr_window is a view of x + ... + +.. _comput.rolling_exp: + +While ``rolling`` provides a simple moving average, ``DataArray`` also supports +an exponential moving average with :py:meth:`~xarray.DataArray.rolling_exp`. +This is similar to pandas' ``ewm`` method. numbagg_ is required. + +.. _numbagg: https://github.com/numbagg/numbagg + +.. code:: python + + arr.rolling_exp(y=3).mean() + +The ``rolling_exp`` method takes a ``window_type`` kwarg, which can be ``'alpha'``, +``'com'`` (for ``center-of-mass``), ``'span'``, and ``'halflife'``. The default is +``span``. + +Finally, the rolling object has a ``construct`` method which returns a +view of the original ``DataArray`` with the windowed dimension in +the last position. +You can use this for more advanced rolling operations such as strided rolling, +windowed rolling, convolution, short-time FFT etc. + +.. ipython:: python + + # rolling with 2-point stride + rolling_da = r.construct(x="x_win", y="y_win", stride=2) + rolling_da + rolling_da.mean(["x_win", "y_win"], skipna=False) + +Because the ``DataArray`` given by ``r.construct('window_dim')`` is a view +of the original array, it is memory efficient. +You can also use ``construct`` to compute a weighted rolling sum: + +.. ipython:: python + + weight = xr.DataArray([0.25, 0.5, 0.25], dims=["window"]) + arr.rolling(y=3).construct(y="window").dot(weight) + +.. note:: + numpy's Nan-aggregation functions such as ``nansum`` copy the original array. + In xarray, we internally use these functions in our aggregation methods + (such as ``.sum()``) if ``skipna`` argument is not specified or set to True. + This means ``rolling_da.mean('window_dim')`` is memory inefficient. + To avoid this, use ``skipna=False`` as the above example. + + +.. _comput.weighted: + +Weighted array reductions +========================= + +:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted` +and :py:meth:`Dataset.weighted` array reduction methods. They currently +support weighted ``sum``, ``mean``, ``std``, ``var`` and ``quantile``. + +.. ipython:: python + + coords = dict(month=("month", [1, 2, 3])) + + prec = xr.DataArray([1.1, 1.0, 0.9], dims=("month",), coords=coords) + weights = xr.DataArray([31, 28, 31], dims=("month",), coords=coords) + +Create a weighted object: + +.. ipython:: python + + weighted_prec = prec.weighted(weights) + weighted_prec + +Calculate the weighted sum: + +.. ipython:: python + + weighted_prec.sum() + +Calculate the weighted mean: + +.. ipython:: python + + weighted_prec.mean(dim="month") + +Calculate the weighted quantile: + +.. ipython:: python + + weighted_prec.quantile(q=0.5, dim="month") + +The weighted sum corresponds to: + +.. ipython:: python + + weighted_sum = (prec * weights).sum() + weighted_sum + +the weighted mean to: + +.. ipython:: python + + weighted_mean = weighted_sum / weights.sum() + weighted_mean + +the weighted variance to: + +.. ipython:: python + + weighted_var = weighted_prec.sum_of_squares() / weights.sum() + weighted_var + +and the weighted standard deviation to: + +.. ipython:: python + + weighted_std = np.sqrt(weighted_var) + weighted_std + +However, the functions also take missing values in the data into account: + +.. ipython:: python + + data = xr.DataArray([np.NaN, 2, 4]) + weights = xr.DataArray([8, 1, 1]) + + data.weighted(weights).mean() + +Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result +in 0.6. + + +If the weights add up to to 0, ``sum`` returns 0: + +.. ipython:: python + + data = xr.DataArray([1.0, 1.0]) + weights = xr.DataArray([-1.0, 1.0]) + + data.weighted(weights).sum() + +and ``mean``, ``std`` and ``var`` return ``NaN``: + +.. ipython:: python + + data.weighted(weights).mean() + + +.. note:: + ``weights`` must be a :py:class:`DataArray` and cannot contain missing values. + Missing values can be replaced manually by ``weights.fillna(0)``. + +.. _compute.coarsen: + +Coarsen large arrays +==================== + +:py:class:`DataArray` and :py:class:`Dataset` objects include a +:py:meth:`~xarray.DataArray.coarsen` and :py:meth:`~xarray.Dataset.coarsen` +methods. This supports block aggregation along multiple dimensions, + +.. ipython:: python + + x = np.linspace(0, 10, 300) + t = pd.date_range("1999-12-15", periods=364) + da = xr.DataArray( + np.sin(x) * np.cos(np.linspace(0, 1, 364)[:, np.newaxis]), + dims=["time", "x"], + coords={"time": t, "x": x}, + ) + da + +In order to take a block mean for every 7 days along ``time`` dimension and +every 2 points along ``x`` dimension, + +.. ipython:: python + + da.coarsen(time=7, x=2).mean() + +:py:meth:`~xarray.DataArray.coarsen` raises an ``ValueError`` if the data +length is not a multiple of the corresponding window size. +You can choose ``boundary='trim'`` or ``boundary='pad'`` options for trimming +the excess entries or padding ``nan`` to insufficient entries, + +.. ipython:: python + + da.coarsen(time=30, x=2, boundary="trim").mean() + +If you want to apply a specific function to coordinate, you can pass the +function or method name to ``coord_func`` option, + +.. ipython:: python + + da.coarsen(time=7, x=2, coord_func={"time": "min"}).mean() + +You can also :ref:`use coarsen to reshape` without applying a computation. + +.. _compute.using_coordinates: + +Computation using Coordinates +============================= + +Xarray objects have some handy methods for the computation with their +coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by +central finite differences using their coordinates, + +.. ipython:: python + + a = xr.DataArray([0, 1, 2, 3], dims=["x"], coords=[[0.1, 0.11, 0.2, 0.3]]) + a + a.differentiate("x") + +This method can be used also for multidimensional arrays, + +.. ipython:: python + + a = xr.DataArray( + np.arange(8).reshape(4, 2), dims=["x", "y"], coords={"x": [0.1, 0.11, 0.2, 0.3]} + ) + a.differentiate("x") + +:py:meth:`~xarray.DataArray.integrate` computes integration based on +trapezoidal rule using their coordinates, + +.. ipython:: python + + a.integrate("x") + +.. note:: + These methods are limited to simple cartesian geometry. Differentiation + and integration along multidimensional coordinate are not supported. + + +.. _compute.polyfit: + +Fitting polynomials +=================== + +Xarray objects provide an interface for performing linear or polynomial regressions +using the least-squares method. :py:meth:`~xarray.DataArray.polyfit` computes the +best fitting coefficients along a given dimension and for a given order, + +.. ipython:: python + + x = xr.DataArray(np.arange(10), dims=["x"], name="x") + a = xr.DataArray(3 + 4 * x, dims=["x"], coords={"x": x}) + out = a.polyfit(dim="x", deg=1, full=True) + out + +The method outputs a dataset containing the coefficients (and more if `full=True`). +The inverse operation is done with :py:meth:`~xarray.polyval`, + +.. ipython:: python + + xr.polyval(coord=x, coeffs=out.polyfit_coefficients) + +.. note:: + These methods replicate the behaviour of :py:func:`numpy.polyfit` and :py:func:`numpy.polyval`. + + +.. _compute.curvefit: + +Fitting arbitrary functions +=========================== + +Xarray objects also provide an interface for fitting more complex functions using +:py:func:`scipy.optimize.curve_fit`. :py:meth:`~xarray.DataArray.curvefit` accepts +user-defined functions and can fit along multiple coordinates. + +For example, we can fit a relationship between two ``DataArray`` objects, maintaining +a unique fit at each spatial coordinate but aggregating over the time dimension: + +.. ipython:: python + + def exponential(x, a, xc): + return np.exp((x - xc) / a) + + + x = np.arange(-5, 5, 0.1) + t = np.arange(-5, 5, 0.1) + X, T = np.meshgrid(x, t) + Z1 = np.random.uniform(low=-5, high=5, size=X.shape) + Z2 = exponential(Z1, 3, X) + Z3 = exponential(Z1, 1, -X) + + ds = xr.Dataset( + data_vars=dict( + var1=(["t", "x"], Z1), var2=(["t", "x"], Z2), var3=(["t", "x"], Z3) + ), + coords={"t": t, "x": x}, + ) + ds[["var2", "var3"]].curvefit( + coords=ds.var1, + func=exponential, + reduce_dims="t", + bounds={"a": (0.5, 5), "xc": (-5, 5)}, + ) + +We can also fit multi-dimensional functions, and even use a wrapper function to +simultaneously fit a summation of several functions, such as this field containing +two gaussian peaks: + +.. ipython:: python + + def gaussian_2d(coords, a, xc, yc, xalpha, yalpha): + x, y = coords + z = a * np.exp( + -np.square(x - xc) / 2 / np.square(xalpha) + - np.square(y - yc) / 2 / np.square(yalpha) + ) + return z + + + def multi_peak(coords, *args): + z = np.zeros(coords[0].shape) + for i in range(len(args) // 5): + z += gaussian_2d(coords, *args[i * 5 : i * 5 + 5]) + return z + + + x = np.arange(-5, 5, 0.1) + y = np.arange(-5, 5, 0.1) + X, Y = np.meshgrid(x, y) + + n_peaks = 2 + names = ["a", "xc", "yc", "xalpha", "yalpha"] + names = [f"{name}{i}" for i in range(n_peaks) for name in names] + Z = gaussian_2d((X, Y), 3, 1, 1, 2, 1) + gaussian_2d((X, Y), 2, -1, -2, 1, 1) + Z += np.random.normal(scale=0.1, size=Z.shape) + + da = xr.DataArray(Z, dims=["y", "x"], coords={"y": y, "x": x}) + da.curvefit( + coords=["x", "y"], + func=multi_peak, + param_names=names, + kwargs={"maxfev": 10000}, + ) + +.. note:: + This method replicates the behavior of :py:func:`scipy.optimize.curve_fit`. + + +.. _compute.broadcasting: + +Broadcasting by dimension name +============================== + +``DataArray`` objects automatically align themselves ("broadcasting" in +the numpy parlance) by dimension name instead of axis order. With xarray, you +do not need to transpose arrays or insert dimensions of length 1 to get array +operations to work, as commonly done in numpy with :py:func:`numpy.reshape` or +:py:data:`numpy.newaxis`. + +This is best illustrated by a few examples. Consider two one-dimensional +arrays with different sizes aligned along different dimensions: + +.. ipython:: python + + a = xr.DataArray([1, 2], [("x", ["a", "b"])]) + a + b = xr.DataArray([-1, -2, -3], [("y", [10, 20, 30])]) + b + +With xarray, we can apply binary mathematical operations to these arrays, and +their dimensions are expanded automatically: + +.. ipython:: python + + a * b + +Moreover, dimensions are always reordered to the order in which they first +appeared: + +.. ipython:: python + + c = xr.DataArray(np.arange(6).reshape(3, 2), [b["y"], a["x"]]) + c + a + c + +This means, for example, that you always subtract an array from its transpose: + +.. ipython:: python + + c - c.T + +You can explicitly broadcast xarray data structures by using the +:py:func:`~xarray.broadcast` function: + +.. ipython:: python + + a2, b2 = xr.broadcast(a, b) + a2 + b2 + +.. _math automatic alignment: + +Automatic alignment +=================== + +Xarray enforces alignment between *index* :ref:`coordinates` (that is, +coordinates with the same name as a dimension, marked by ``*``) on objects used +in binary operations. + +Similarly to pandas, this alignment is automatic for arithmetic on binary +operations. The default result of a binary operation is by the *intersection* +(not the union) of coordinate labels: + +.. ipython:: python + + arr = xr.DataArray(np.arange(3), [("x", range(3))]) + arr + arr[:-1] + +If coordinate values for a dimension are missing on either argument, all +matching dimensions must have the same size: + +.. ipython:: + :verbatim: + + In [1]: arr + xr.DataArray([1, 2], dims="x") + ValueError: arguments without labels along dimension 'x' cannot be aligned because they have different dimension size(s) {2} than the size of the aligned dimension labels: 3 + + +However, one can explicitly change this default automatic alignment type ("inner") +via :py:func:`~xarray.set_options()` in context manager: + +.. ipython:: python + + with xr.set_options(arithmetic_join="outer"): + arr + arr[:1] + arr + arr[:1] + +Before loops or performance critical code, it's a good idea to align arrays +explicitly (e.g., by putting them in the same Dataset or using +:py:func:`~xarray.align`) to avoid the overhead of repeated alignment with each +operation. See :ref:`align and reindex` for more details. + +.. note:: + + There is no automatic alignment between arguments when performing in-place + arithmetic operations such as ``+=``. You will need to use + :ref:`manual alignment`. This ensures in-place + arithmetic never needs to modify data types. + +.. _coordinates math: + +Coordinates +=========== + +Although index coordinates are aligned, other coordinates are not, and if their +values conflict, they will be dropped. This is necessary, for example, because +indexing turns 1D coordinates into scalar coordinates: + +.. ipython:: python + + arr[0] + arr[1] + # notice that the scalar coordinate 'x' is silently dropped + arr[1] - arr[0] + +Still, xarray will persist other coordinates in arithmetic, as long as there +are no conflicting values: + +.. ipython:: python + + # only one argument has the 'x' coordinate + arr[0] + 1 + # both arguments have the same 'x' coordinate + arr[0] - arr[0] + +Math with datasets +================== + +Datasets support arithmetic operations by automatically looping over all data +variables: + +.. ipython:: python + + ds = xr.Dataset( + { + "x_and_y": (("x", "y"), np.random.randn(3, 5)), + "x_only": ("x", np.random.randn(3)), + }, + coords=arr.coords, + ) + ds > 0 + +Datasets support most of the same methods found on data arrays: + +.. ipython:: python + + ds.mean(dim="x") + abs(ds) + +Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or +alternatively you can use :py:meth:`~xarray.Dataset.map` to map a function +to each variable in a dataset: + +.. ipython:: python + + np.sin(ds) + ds.map(np.sin) + +Datasets also use looping over variables for *broadcasting* in binary +arithmetic. You can do arithmetic between any ``DataArray`` and a dataset: + +.. ipython:: python + + ds + arr + +Arithmetic between two datasets matches data variables of the same name: + +.. ipython:: python + + ds2 = xr.Dataset({"x_and_y": 0, "x_only": 100}) + ds - ds2 + +Similarly to index based alignment, the result has the intersection of all +matching data variables. + +.. _comput.wrapping-custom: + +Wrapping custom computation +=========================== + +It doesn't always make sense to do computation directly with xarray objects: + + - In the inner loop of performance limited code, using xarray can add + considerable overhead compared to using NumPy or native Python types. + This is particularly true when working with scalars or small arrays (less + than ~1e6 elements). Keeping track of labels and ensuring their consistency + adds overhead, and xarray's core itself is not especially fast, because it's + written in Python rather than a compiled language like C. Also, xarray's + high level label-based APIs removes low-level control over how operations + are implemented. + - Even if speed doesn't matter, it can be important to wrap existing code, or + to support alternative interfaces that don't use xarray objects. + +For these reasons, it is often well-advised to write low-level routines that +work with NumPy arrays, and to wrap these routines to work with xarray objects. +However, adding support for labels on both :py:class:`~xarray.Dataset` and +:py:class:`~xarray.DataArray` can be a bit of a chore. + +To make this easier, xarray supplies the :py:func:`~xarray.apply_ufunc` helper +function, designed for wrapping functions that support broadcasting and +vectorization on unlabeled arrays in the style of a NumPy +`universal function `_ ("ufunc" for short). +``apply_ufunc`` takes care of everything needed for an idiomatic xarray wrapper, +including alignment, broadcasting, looping over ``Dataset`` variables (if +needed), and merging of coordinates. In fact, many internal xarray +functions/methods are written using ``apply_ufunc``. + +Simple functions that act independently on each value should work without +any additional arguments: + +.. ipython:: python + + squared_error = lambda x, y: (x - y) ** 2 + arr1 = xr.DataArray([0, 1, 2, 3], dims="x") + xr.apply_ufunc(squared_error, arr1, 1) + +For using more complex operations that consider some array values collectively, +it's important to understand the idea of "core dimensions" from NumPy's +`generalized ufuncs `_. Core dimensions are defined as dimensions +that should *not* be broadcast over. Usually, they correspond to the fundamental +dimensions over which an operation is defined, e.g., the summed axis in +``np.sum``. A good clue that core dimensions are needed is the presence of an +``axis`` argument on the corresponding NumPy function. + +With ``apply_ufunc``, core dimensions are recognized by name, and then moved to +the last dimension of any input arguments before applying the given function. +This means that for functions that accept an ``axis`` argument, you usually need +to set ``axis=-1``. As an example, here is how we would wrap +:py:func:`numpy.linalg.norm` to calculate the vector norm: + +.. code-block:: python + + def vector_norm(x, dim, ord=None): + return xr.apply_ufunc( + np.linalg.norm, x, input_core_dims=[[dim]], kwargs={"ord": ord, "axis": -1} + ) + +.. ipython:: python + :suppress: + + def vector_norm(x, dim, ord=None): + return xr.apply_ufunc( + np.linalg.norm, x, input_core_dims=[[dim]], kwargs={"ord": ord, "axis": -1} + ) + +.. ipython:: python + + vector_norm(arr1, dim="x") + +Because ``apply_ufunc`` follows a standard convention for ufuncs, it plays +nicely with tools for building vectorized functions, like +:py:func:`numpy.broadcast_arrays` and :py:class:`numpy.vectorize`. For high performance +needs, consider using :doc:`Numba's vectorize and guvectorize `. + +In addition to wrapping functions, ``apply_ufunc`` can automatically parallelize +many functions when using dask by setting ``dask='parallelized'``. See +:ref:`dask.automatic-parallelization` for details. + +:py:func:`~xarray.apply_ufunc` also supports some advanced options for +controlling alignment of variables and the form of the result. See the +docstring for full details and more examples. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/dask.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/dask.rst new file mode 100644 index 0000000..27e7449 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/dask.rst @@ -0,0 +1,577 @@ +.. currentmodule:: xarray + +.. _dask: + +Parallel computing with Dask +============================ + +Xarray integrates with `Dask `__ to support parallel +computations and streaming computation on datasets that don't fit into memory. +Currently, Dask is an entirely optional feature for xarray. However, the +benefits of using Dask are sufficiently strong that Dask may become a required +dependency in a future version of xarray. + +For a full example of how to use xarray's Dask integration, read the +`blog post introducing xarray and Dask`_. More up-to-date examples +may be found at the `Pangeo project's gallery `_ +and at the `Dask examples website `_. + +.. _blog post introducing xarray and Dask: https://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ + +What is a Dask array? +--------------------- + +.. image:: ../_static/dask_array.png + :width: 40 % + :align: right + :alt: A Dask array + +Dask divides arrays into many small pieces, called *chunks*, each of which is +presumed to be small enough to fit into memory. + +Unlike NumPy, which has eager evaluation, operations on Dask arrays are lazy. +Operations queue up a series of tasks mapped over blocks, and no computation is +performed until you actually ask values to be computed (e.g., to print results +to your screen or write to disk). At that point, data is loaded into memory +and computation proceeds in a streaming fashion, block-by-block. + +The actual computation is controlled by a multi-processing or thread pool, +which allows Dask to take full advantage of multiple processors available on +most modern computers. + +For more details, read the `Dask documentation `__. +Note that xarray only makes use of ``dask.array`` and ``dask.delayed``. + +.. _dask.io: + +Reading and writing data +------------------------ + +The usual way to create a ``Dataset`` filled with Dask arrays is to load the +data from a netCDF file or files. You can do this by supplying a ``chunks`` +argument to :py:func:`~xarray.open_dataset` or using the +:py:func:`~xarray.open_mfdataset` function. + +.. ipython:: python + :suppress: + + import os + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + np.set_printoptions(precision=3, linewidth=100, threshold=100, edgeitems=3) + + ds = xr.Dataset( + { + "temperature": ( + ("time", "latitude", "longitude"), + np.random.randn(30, 180, 180), + ), + "time": pd.date_range("2015-01-01", periods=30), + "longitude": np.arange(180), + "latitude": np.arange(89.5, -90.5, -1), + } + ) + ds.to_netcdf("example-data.nc") + +.. ipython:: python + + ds = xr.open_dataset("example-data.nc", chunks={"time": 10}) + ds + +In this example ``latitude`` and ``longitude`` do not appear in the ``chunks`` +dict, so only one chunk will be used along those dimensions. It is also +entirely equivalent to opening a dataset using :py:func:`~xarray.open_dataset` +and then chunking the data using the ``chunk`` method, e.g., +``xr.open_dataset('example-data.nc').chunk({'time': 10})``. + +To open multiple files simultaneously in parallel using Dask delayed, +use :py:func:`~xarray.open_mfdataset`:: + + xr.open_mfdataset('my/files/*.nc', parallel=True) + +This function will automatically concatenate and merge datasets into one in +the simple cases that it understands (see :py:func:`~xarray.combine_by_coords` +for the full disclaimer). By default, :py:func:`~xarray.open_mfdataset` will chunk each +netCDF file into a single Dask array; again, supply the ``chunks`` argument to +control the size of the resulting Dask arrays. In more complex cases, you can +open each file individually using :py:func:`~xarray.open_dataset` and merge the result, as +described in :ref:`combining data`. Passing the keyword argument ``parallel=True`` to +:py:func:`~xarray.open_mfdataset` will speed up the reading of large multi-file datasets by +executing those read tasks in parallel using ``dask.delayed``. + +.. warning:: + + :py:func:`~xarray.open_mfdataset` called without ``chunks`` argument will return + dask arrays with chunk sizes equal to the individual files. Re-chunking + the dataset after creation with ``ds.chunk()`` will lead to an ineffective use of + memory and is not recommended. + +You'll notice that printing a dataset still shows a preview of array values, +even if they are actually Dask arrays. We can do this quickly with Dask because +we only need to compute the first few values (typically from the first block). +To reveal the true nature of an array, print a DataArray: + +.. ipython:: python + + ds.temperature + +Once you've manipulated a Dask array, you can still write a dataset too big to +fit into memory back to disk by using :py:meth:`~xarray.Dataset.to_netcdf` in the +usual way. + +.. ipython:: python + + ds.to_netcdf("manipulated-example-data.nc") + +By setting the ``compute`` argument to ``False``, :py:meth:`~xarray.Dataset.to_netcdf` +will return a ``dask.delayed`` object that can be computed later. + +.. ipython:: python + + from dask.diagnostics import ProgressBar + + # or distributed.progress when using the distributed scheduler + delayed_obj = ds.to_netcdf("manipulated-example-data.nc", compute=False) + with ProgressBar(): + results = delayed_obj.compute() + +.. ipython:: python + :suppress: + + os.remove("manipulated-example-data.nc") # Was not opened. + +.. note:: + + When using Dask's distributed scheduler to write NETCDF4 files, + it may be necessary to set the environment variable `HDF5_USE_FILE_LOCKING=FALSE` + to avoid competing locks within the HDF5 SWMR file locking scheme. Note that + writing netCDF files with Dask's distributed scheduler is only supported for + the `netcdf4` backend. + +A dataset can also be converted to a Dask DataFrame using :py:meth:`~xarray.Dataset.to_dask_dataframe`. + +.. ipython:: python + :okwarning: + + df = ds.to_dask_dataframe() + df + +Dask DataFrames do not support multi-indexes so the coordinate variables from the dataset are included as columns in the Dask DataFrame. + + +Using Dask with xarray +---------------------- + +Nearly all existing xarray methods (including those for indexing, computation, +concatenating and grouped operations) have been extended to work automatically +with Dask arrays. When you load data as a Dask array in an xarray data +structure, almost all xarray operations will keep it as a Dask array; when this +is not possible, they will raise an exception rather than unexpectedly loading +data into memory. Converting a Dask array into memory generally requires an +explicit conversion step. One notable exception is indexing operations: to +enable label based indexing, xarray will automatically load coordinate labels +into memory. + +.. tip:: + + By default, dask uses its multi-threaded scheduler, which distributes work across + multiple cores and allows for processing some datasets that do not fit into memory. + For running across a cluster, `setup the distributed scheduler `_. + +The easiest way to convert an xarray data structure from lazy Dask arrays into +*eager*, in-memory NumPy arrays is to use the :py:meth:`~xarray.Dataset.load` method: + +.. ipython:: python + + ds.load() + +You can also access :py:attr:`~xarray.DataArray.values`, which will always be a +NumPy array: + +.. ipython:: + :verbatim: + + In [5]: ds.temperature.values + Out[5]: + array([[[ 4.691e-01, -2.829e-01, ..., -5.577e-01, 3.814e-01], + [ 1.337e+00, -1.531e+00, ..., 8.726e-01, -1.538e+00], + ... + # truncated for brevity + +Explicit conversion by wrapping a DataArray with ``np.asarray`` also works: + +.. ipython:: + :verbatim: + + In [5]: np.asarray(ds.temperature) + Out[5]: + array([[[ 4.691e-01, -2.829e-01, ..., -5.577e-01, 3.814e-01], + [ 1.337e+00, -1.531e+00, ..., 8.726e-01, -1.538e+00], + ... + +Alternatively you can load the data into memory but keep the arrays as +Dask arrays using the :py:meth:`~xarray.Dataset.persist` method: + +.. ipython:: python + + persisted = ds.persist() + +:py:meth:`~xarray.Dataset.persist` is particularly useful when using a +distributed cluster because the data will be loaded into distributed memory +across your machines and be much faster to use than reading repeatedly from +disk. + +.. warning:: + + On a single machine :py:meth:`~xarray.Dataset.persist` will try to load all of + your data into memory. You should make sure that your dataset is not larger than + available memory. + +.. note:: + + For more on the differences between :py:meth:`~xarray.Dataset.persist` and + :py:meth:`~xarray.Dataset.compute` see this `Stack Overflow answer on the differences between client persist and client compute `_ and the `Dask documentation `_. + +For performance you may wish to consider chunk sizes. The correct choice of +chunk size depends both on your data and on the operations you want to perform. +With xarray, both converting data to a Dask arrays and converting the chunk +sizes of Dask arrays is done with the :py:meth:`~xarray.Dataset.chunk` method: + +.. ipython:: python + + rechunked = ds.chunk({"latitude": 100, "longitude": 100}) + +.. warning:: + + Rechunking an existing dask array created with :py:func:`~xarray.open_mfdataset` + is not recommended (see above). + +You can view the size of existing chunks on an array by viewing the +:py:attr:`~xarray.Dataset.chunks` attribute: + +.. ipython:: python + + rechunked.chunks + +If there are not consistent chunksizes between all the arrays in a dataset +along a particular dimension, an exception is raised when you try to access +``.chunks``. + +.. note:: + + In the future, we would like to enable automatic alignment of Dask + chunksizes (but not the other way around). We might also require that all + arrays in a dataset share the same chunking alignment. Neither of these + are currently done. + +NumPy ufuncs like ``np.sin`` transparently work on all xarray objects, including those +that store lazy Dask arrays: + +.. ipython:: python + + import numpy as np + + np.sin(rechunked) + +To access Dask arrays directly, use the +:py:attr:`DataArray.data ` attribute. This attribute exposes +array data either as a Dask array or as a NumPy array, depending on whether it has been +loaded into Dask or not: + +.. ipython:: python + + ds.temperature.data + +.. note:: + + ``.data`` is also used to expose other "computable" array backends beyond Dask and + NumPy (e.g. sparse and pint arrays). + +.. _dask.automatic-parallelization: + +Automatic parallelization with ``apply_ufunc`` and ``map_blocks`` +----------------------------------------------------------------- + +Almost all of xarray's built-in operations work on Dask arrays. If you want to +use a function that isn't wrapped by xarray, and have it applied in parallel on +each block of your xarray object, you have three options: + +1. Extract Dask arrays from xarray objects (``.data``) and use Dask directly. +2. Use :py:func:`~xarray.apply_ufunc` to apply functions that consume and return NumPy arrays. +3. Use :py:func:`~xarray.map_blocks`, :py:meth:`Dataset.map_blocks` or :py:meth:`DataArray.map_blocks` + to apply functions that consume and return xarray objects. + + +``apply_ufunc`` +~~~~~~~~~~~~~~~ + +:py:func:`~xarray.apply_ufunc` automates `embarrassingly parallel +`__ "map" type operations +where a function written for processing NumPy arrays should be repeatedly +applied to xarray objects containing Dask arrays. It works similarly to +:py:func:`dask.array.map_blocks` and :py:func:`dask.array.blockwise`, but without +requiring an intermediate layer of abstraction. + +For the best performance when using Dask's multi-threaded scheduler, wrap a +function that already releases the global interpreter lock, which fortunately +already includes most NumPy and Scipy functions. Here we show an example +using NumPy operations and a fast function from +`bottleneck `__, which +we use to calculate `Spearman's rank-correlation coefficient `__: + +.. code-block:: python + + import numpy as np + import xarray as xr + import bottleneck + + + def covariance_gufunc(x, y): + return ( + (x - x.mean(axis=-1, keepdims=True)) * (y - y.mean(axis=-1, keepdims=True)) + ).mean(axis=-1) + + + def pearson_correlation_gufunc(x, y): + return covariance_gufunc(x, y) / (x.std(axis=-1) * y.std(axis=-1)) + + + def spearman_correlation_gufunc(x, y): + x_ranks = bottleneck.rankdata(x, axis=-1) + y_ranks = bottleneck.rankdata(y, axis=-1) + return pearson_correlation_gufunc(x_ranks, y_ranks) + + + def spearman_correlation(x, y, dim): + return xr.apply_ufunc( + spearman_correlation_gufunc, + x, + y, + input_core_dims=[[dim], [dim]], + dask="parallelized", + output_dtypes=[float], + ) + +The only aspect of this example that is different from standard usage of +``apply_ufunc()`` is that we needed to supply the ``output_dtypes`` arguments. +(Read up on :ref:`comput.wrapping-custom` for an explanation of the +"core dimensions" listed in ``input_core_dims``.) + +Our new ``spearman_correlation()`` function achieves near linear speedup +when run on large arrays across the four cores on my laptop. It would also +work as a streaming operation, when run on arrays loaded from disk: + +.. ipython:: + :verbatim: + + In [56]: rs = np.random.RandomState(0) + + In [57]: array1 = xr.DataArray(rs.randn(1000, 100000), dims=["place", "time"]) # 800MB + + In [58]: array2 = array1 + 0.5 * rs.randn(1000, 100000) + + # using one core, on NumPy arrays + In [61]: %time _ = spearman_correlation(array1, array2, 'time') + CPU times: user 21.6 s, sys: 2.84 s, total: 24.5 s + Wall time: 24.9 s + + In [8]: chunked1 = array1.chunk({"place": 10}) + + In [9]: chunked2 = array2.chunk({"place": 10}) + + # using all my laptop's cores, with Dask + In [63]: r = spearman_correlation(chunked1, chunked2, "time").compute() + + In [64]: %time _ = r.compute() + CPU times: user 30.9 s, sys: 1.74 s, total: 32.6 s + Wall time: 4.59 s + +One limitation of ``apply_ufunc()`` is that it cannot be applied to arrays with +multiple chunks along a core dimension: + +.. ipython:: + :verbatim: + + In [63]: spearman_correlation(chunked1, chunked2, "place") + ValueError: dimension 'place' on 0th function argument to apply_ufunc with + dask='parallelized' consists of multiple chunks, but is also a core + dimension. To fix, rechunk into a single Dask array chunk along this + dimension, i.e., ``.rechunk({'place': -1})``, but beware that this may + significantly increase memory usage. + +This reflects the nature of core dimensions, in contrast to broadcast (non-core) +dimensions that allow operations to be split into arbitrary chunks for +application. + +.. tip:: + + For the majority of NumPy functions that are already wrapped by Dask, it's + usually a better idea to use the pre-existing ``dask.array`` function, by + using either a pre-existing xarray methods or + :py:func:`~xarray.apply_ufunc()` with ``dask='allowed'``. Dask can often + have a more efficient implementation that makes use of the specialized + structure of a problem, unlike the generic speedups offered by + ``dask='parallelized'``. + + +``map_blocks`` +~~~~~~~~~~~~~~ + +Functions that consume and return xarray objects can be easily applied in parallel using :py:func:`map_blocks`. +Your function will receive an xarray Dataset or DataArray subset to one chunk +along each chunked dimension. + +.. ipython:: python + + ds.temperature + +This DataArray has 3 chunks each with length 10 along the time dimension. +At compute time, a function applied with :py:func:`map_blocks` will receive a DataArray corresponding to a single block of shape 10x180x180 +(time x latitude x longitude) with values loaded. The following snippet illustrates how to check the shape of the object +received by the applied function. + +.. ipython:: python + + def func(da): + print(da.sizes) + return da.time + + + mapped = xr.map_blocks(func, ds.temperature) + mapped + +Notice that the :py:meth:`map_blocks` call printed +``Frozen({'time': 0, 'latitude': 0, 'longitude': 0})`` to screen. +``func`` is received 0-sized blocks! :py:meth:`map_blocks` needs to know what the final result +looks like in terms of dimensions, shapes etc. It does so by running the provided function on 0-shaped +inputs (*automated inference*). This works in many cases, but not all. If automatic inference does not +work for your function, provide the ``template`` kwarg (see below). + +In this case, automatic inference has worked so let's check that the result is as expected. + +.. ipython:: python + + mapped.load(scheduler="single-threaded") + mapped.identical(ds.time) + +Note that we use ``.load(scheduler="single-threaded")`` to execute the computation. +This executes the Dask graph in `serial` using a for loop, but allows for printing to screen and other +debugging techniques. We can easily see that our function is receiving blocks of shape 10x180x180 and +the returned result is identical to ``ds.time`` as expected. + + +Here is a common example where automated inference will not work. + +.. ipython:: python + :okexcept: + + def func(da): + print(da.sizes) + return da.isel(time=[1]) + + + mapped = xr.map_blocks(func, ds.temperature) + +``func`` cannot be run on 0-shaped inputs because it is not possible to extract element 1 along a +dimension of size 0. In this case we need to tell :py:func:`map_blocks` what the returned result looks +like using the ``template`` kwarg. ``template`` must be an xarray Dataset or DataArray (depending on +what the function returns) with dimensions, shapes, chunk sizes, attributes, coordinate variables *and* data +variables that look exactly like the expected result. The variables should be dask-backed and hence not +incur much memory cost. + +.. note:: + + Note that when ``template`` is provided, ``attrs`` from ``template`` are copied over to the result. Any + ``attrs`` set in ``func`` will be ignored. + + +.. ipython:: python + + template = ds.temperature.isel(time=[1, 11, 21]) + mapped = xr.map_blocks(func, ds.temperature, template=template) + + +Notice that the 0-shaped sizes were not printed to screen. Since ``template`` has been provided +:py:func:`map_blocks` does not need to infer it by running ``func`` on 0-shaped inputs. + +.. ipython:: python + + mapped.identical(template) + + +:py:func:`map_blocks` also allows passing ``args`` and ``kwargs`` down to the user function ``func``. +``func`` will be executed as ``func(block_xarray, *args, **kwargs)`` so ``args`` must be a list and ``kwargs`` must be a dictionary. + +.. ipython:: python + + def func(obj, a, b=0): + return obj + a + b + + + mapped = ds.map_blocks(func, args=[10], kwargs={"b": 10}) + expected = ds + 10 + 10 + mapped.identical(expected) + +.. ipython:: python + :suppress: + + ds.close() # Closes "example-data.nc". + os.remove("example-data.nc") + +.. tip:: + + As :py:func:`map_blocks` loads each block into memory, reduce as much as possible objects consumed by user functions. + For example, drop useless variables before calling ``func`` with :py:func:`map_blocks`. + + + +Chunking and performance +------------------------ + +The ``chunks`` parameter has critical performance implications when using Dask +arrays. If your chunks are too small, queueing up operations will be extremely +slow, because Dask will translate each operation into a huge number of +operations mapped across chunks. Computation on Dask arrays with small chunks +can also be slow, because each operation on a chunk has some fixed overhead from +the Python interpreter and the Dask task executor. + +Conversely, if your chunks are too big, some of your computation may be wasted, +because Dask only computes results one chunk at a time. + +A good rule of thumb is to create arrays with a minimum chunksize of at least +one million elements (e.g., a 1000x1000 matrix). With large arrays (10+ GB), the +cost of queueing up Dask operations can be noticeable, and you may need even +larger chunksizes. + +.. tip:: + + Check out the `dask documentation on chunks `_. + + +Optimization Tips +----------------- + +With analysis pipelines involving both spatial subsetting and temporal resampling, Dask performance +can become very slow or memory hungry in certain cases. Here are some optimization tips we have found +through experience: + +1. Do your spatial and temporal indexing (e.g. ``.sel()`` or ``.isel()``) early in the pipeline, especially before calling ``resample()`` or ``groupby()``. Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn't been implemented in Dask yet. (See `Dask issue #746 `_). + +2. More generally, ``groupby()`` is a costly operation and will perform a lot better if the ``flox`` package is installed. + See the `flox documentation `_ for more. By default Xarray will use ``flox`` if installed. + +3. Save intermediate results to disk as a netCDF files (using ``to_netcdf()``) and then load them again with ``open_dataset()`` for further computations. For example, if subtracting temporal mean from a dataset, save the temporal mean to disk before subtracting. Again, in theory, Dask should be able to do the computation in a streaming fashion, but in practice this is a fail case for the Dask scheduler, because it tries to keep every chunk of an array that it computes in memory. (See `Dask issue #874 `_) + +4. Specify smaller chunks across space when using :py:meth:`~xarray.open_mfdataset` (e.g., ``chunks={'latitude': 10, 'longitude': 10}``). This makes spatial subsetting easier, because there's no risk you will load subsets of data which span multiple chunks. On individual files, prefer to subset before chunking (suggestion 1). + +5. Chunk as early as possible, and avoid rechunking as much as possible. Always pass the ``chunks={}`` argument to :py:func:`~xarray.open_mfdataset` to avoid redundant file reads. + +6. Using the h5netcdf package by passing ``engine='h5netcdf'`` to :py:meth:`~xarray.open_mfdataset` can be quicker than the default ``engine='netcdf4'`` that uses the netCDF4 package. + +7. Find `best practices specific to Dask arrays in the documentation `_. + +8. The `dask diagnostics `_ can be useful in identifying performance bottlenecks. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/data-structures.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/data-structures.rst new file mode 100644 index 0000000..a1794f4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/data-structures.rst @@ -0,0 +1,646 @@ +.. _data structures: + +Data Structures +=============== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + np.set_printoptions(threshold=10) + +DataArray +--------- + +:py:class:`xarray.DataArray` is xarray's implementation of a labeled, +multi-dimensional array. It has several key properties: + +- ``values``: a :py:class:`numpy.ndarray` or + :ref:`numpy-like array ` holding the array's values +- ``dims``: dimension names for each axis (e.g., ``('x', 'y', 'z')``) +- ``coords``: a dict-like container of arrays (*coordinates*) that label each + point (e.g., 1-dimensional arrays of numbers, datetime objects or + strings) +- ``attrs``: :py:class:`dict` to hold arbitrary metadata (*attributes*) + +Xarray uses ``dims`` and ``coords`` to enable its core metadata aware operations. +Dimensions provide names that xarray uses instead of the ``axis`` argument found +in many numpy functions. Coordinates enable fast label based indexing and +alignment, building on the functionality of the ``index`` found on a pandas +:py:class:`~pandas.DataFrame` or :py:class:`~pandas.Series`. + +DataArray objects also can have a ``name`` and can hold arbitrary metadata in +the form of their ``attrs`` property. Names and attributes are strictly for +users and user-written code: xarray makes no attempt to interpret them, and +propagates them only in unambiguous cases +(see FAQ, :ref:`approach to metadata`). + +.. _creating a dataarray: + +Creating a DataArray +~~~~~~~~~~~~~~~~~~~~ + +The :py:class:`~xarray.DataArray` constructor takes: + +- ``data``: a multi-dimensional array of values (e.g., a numpy ndarray, + a :ref:`numpy-like array `, :py:class:`~pandas.Series`, + :py:class:`~pandas.DataFrame` or ``pandas.Panel``) +- ``coords``: a list or dictionary of coordinates. If a list, it should be a + list of tuples where the first element is the dimension name and the second + element is the corresponding coordinate array_like object. +- ``dims``: a list of dimension names. If omitted and ``coords`` is a list of + tuples, dimension names are taken from ``coords``. +- ``attrs``: a dictionary of attributes to add to the instance +- ``name``: a string that names the instance + +.. ipython:: python + + data = np.random.rand(4, 3) + locs = ["IA", "IL", "IN"] + times = pd.date_range("2000-01-01", periods=4) + foo = xr.DataArray(data, coords=[times, locs], dims=["time", "space"]) + foo + +Only ``data`` is required; all of other arguments will be filled +in with default values: + +.. ipython:: python + + xr.DataArray(data) + +As you can see, dimension names are always present in the xarray data model: if +you do not provide them, defaults of the form ``dim_N`` will be created. +However, coordinates are always optional, and dimensions do not have automatic +coordinate labels. + +.. note:: + + This is different from pandas, where axes always have tick labels, which + default to the integers ``[0, ..., n-1]``. + + Prior to xarray v0.9, xarray copied this behavior: default coordinates for + each dimension would be created if coordinates were not supplied explicitly. + This is no longer the case. + +Coordinates can be specified in the following ways: + +- A list of values with length equal to the number of dimensions, providing + coordinate labels for each dimension. Each value must be of one of the + following forms: + + * A :py:class:`~xarray.DataArray` or :py:class:`~xarray.Variable` + * A tuple of the form ``(dims, data[, attrs])``, which is converted into + arguments for :py:class:`~xarray.Variable` + * A pandas object or scalar value, which is converted into a ``DataArray`` + * A 1D array or list, which is interpreted as values for a one dimensional + coordinate variable along the same dimension as it's name + +- A dictionary of ``{coord_name: coord}`` where values are of the same form + as the list. Supplying coordinates as a dictionary allows other coordinates + than those corresponding to dimensions (more on these later). If you supply + ``coords`` as a dictionary, you must explicitly provide ``dims``. + +As a list of tuples: + +.. ipython:: python + + xr.DataArray(data, coords=[("time", times), ("space", locs)]) + +As a dictionary: + +.. ipython:: python + + xr.DataArray( + data, + coords={ + "time": times, + "space": locs, + "const": 42, + "ranking": ("space", [1, 2, 3]), + }, + dims=["time", "space"], + ) + +As a dictionary with coords across multiple dimensions: + +.. ipython:: python + + xr.DataArray( + data, + coords={ + "time": times, + "space": locs, + "const": 42, + "ranking": (("time", "space"), np.arange(12).reshape(4, 3)), + }, + dims=["time", "space"], + ) + +If you create a ``DataArray`` by supplying a pandas +:py:class:`~pandas.Series`, :py:class:`~pandas.DataFrame` or +``pandas.Panel``, any non-specified arguments in the +``DataArray`` constructor will be filled in from the pandas object: + +.. ipython:: python + + df = pd.DataFrame({"x": [0, 1], "y": [2, 3]}, index=["a", "b"]) + df.index.name = "abc" + df.columns.name = "xyz" + df + xr.DataArray(df) + +DataArray properties +~~~~~~~~~~~~~~~~~~~~ + +Let's take a look at the important properties on our array: + +.. ipython:: python + + foo.values + foo.dims + foo.coords + foo.attrs + print(foo.name) + +You can modify ``values`` inplace: + +.. ipython:: python + + foo.values = 1.0 * foo.values + +.. note:: + + The array values in a :py:class:`~xarray.DataArray` have a single + (homogeneous) data type. To work with heterogeneous or structured data + types in xarray, use coordinates, or put separate ``DataArray`` objects + in a single :py:class:`~xarray.Dataset` (see below). + +Now fill in some of that missing metadata: + +.. ipython:: python + + foo.name = "foo" + foo.attrs["units"] = "meters" + foo + +The :py:meth:`~xarray.DataArray.rename` method is another option, returning a +new data array: + +.. ipython:: python + + foo.rename("bar") + +DataArray Coordinates +~~~~~~~~~~~~~~~~~~~~~ + +The ``coords`` property is ``dict`` like. Individual coordinates can be +accessed from the coordinates by name, or even by indexing the data array +itself: + +.. ipython:: python + + foo.coords["time"] + foo["time"] + +These are also :py:class:`~xarray.DataArray` objects, which contain tick-labels +for each dimension. + +Coordinates can also be set or removed by using the dictionary like syntax: + +.. ipython:: python + + foo["ranking"] = ("space", [1, 2, 3]) + foo.coords + del foo["ranking"] + foo.coords + +For more details, see :ref:`coordinates` below. + +Dataset +------- + +:py:class:`xarray.Dataset` is xarray's multi-dimensional equivalent of a +:py:class:`~pandas.DataFrame`. It is a dict-like +container of labeled arrays (:py:class:`~xarray.DataArray` objects) with aligned +dimensions. It is designed as an in-memory representation of the data model +from the `netCDF`__ file format. + +__ https://www.unidata.ucar.edu/software/netcdf/ + +In addition to the dict-like interface of the dataset itself, which can be used +to access any variable in a dataset, datasets have four key properties: + +- ``dims``: a dictionary mapping from dimension names to the fixed length of + each dimension (e.g., ``{'x': 6, 'y': 6, 'time': 8}``) +- ``data_vars``: a dict-like container of DataArrays corresponding to variables +- ``coords``: another dict-like container of DataArrays intended to label points + used in ``data_vars`` (e.g., arrays of numbers, datetime objects or strings) +- ``attrs``: :py:class:`dict` to hold arbitrary metadata + +The distinction between whether a variable falls in data or coordinates +(borrowed from `CF conventions`_) is mostly semantic, and you can probably get +away with ignoring it if you like: dictionary like access on a dataset will +supply variables found in either category. However, xarray does make use of the +distinction for indexing and computations. Coordinates indicate +constant/fixed/independent quantities, unlike the varying/measured/dependent +quantities that belong in data. + +.. _CF conventions: https://cfconventions.org/ + +Here is an example of how we might structure a dataset for a weather forecast: + +.. image:: ../_static/dataset-diagram.png + +In this example, it would be natural to call ``temperature`` and +``precipitation`` "data variables" and all the other arrays "coordinate +variables" because they label the points along the dimensions. (see [1]_ for +more background on this example). + +.. _dataarray constructor: + +Creating a Dataset +~~~~~~~~~~~~~~~~~~ + +To make an :py:class:`~xarray.Dataset` from scratch, supply dictionaries for any +variables (``data_vars``), coordinates (``coords``) and attributes (``attrs``). + +- ``data_vars`` should be a dictionary with each key as the name of the variable + and each value as one of: + + * A :py:class:`~xarray.DataArray` or :py:class:`~xarray.Variable` + * A tuple of the form ``(dims, data[, attrs])``, which is converted into + arguments for :py:class:`~xarray.Variable` + * A pandas object, which is converted into a ``DataArray`` + * A 1D array or list, which is interpreted as values for a one dimensional + coordinate variable along the same dimension as it's name + +- ``coords`` should be a dictionary of the same form as ``data_vars``. + +- ``attrs`` should be a dictionary. + +Let's create some fake data for the example we show above. In this +example dataset, we will represent measurements of the temperature and +pressure that were made under various conditions: + +* the measurements were made on four different days; +* they were made at two separate locations, which we will represent using + their latitude and longitude; and +* they were made using instruments by three different manufacutrers, which we + will refer to as `'manufac1'`, `'manufac2'`, and `'manufac3'`. + +.. ipython:: python + + np.random.seed(0) + temperature = 15 + 8 * np.random.randn(2, 3, 4) + precipitation = 10 * np.random.rand(2, 3, 4) + lon = [-99.83, -99.32] + lat = [42.25, 42.21] + instruments = ["manufac1", "manufac2", "manufac3"] + time = pd.date_range("2014-09-06", periods=4) + reference_time = pd.Timestamp("2014-09-05") + + # for real use cases, its good practice to supply array attributes such as + # units, but we won't bother here for the sake of brevity + ds = xr.Dataset( + { + "temperature": (["loc", "instrument", "time"], temperature), + "precipitation": (["loc", "instrument", "time"], precipitation), + }, + coords={ + "lon": (["loc"], lon), + "lat": (["loc"], lat), + "instrument": instruments, + "time": time, + "reference_time": reference_time, + }, + ) + ds + +Here we pass :py:class:`xarray.DataArray` objects or a pandas object as values +in the dictionary: + +.. ipython:: python + + xr.Dataset(dict(bar=foo)) + + +.. ipython:: python + + xr.Dataset(dict(bar=foo.to_pandas())) + +Where a pandas object is supplied as a value, the names of its indexes are used as dimension +names, and its data is aligned to any existing dimensions. + +You can also create an dataset from: + +- A :py:class:`pandas.DataFrame` or ``pandas.Panel`` along its columns and items + respectively, by passing it into the :py:class:`~xarray.Dataset` directly +- A :py:class:`pandas.DataFrame` with :py:meth:`Dataset.from_dataframe `, + which will additionally handle MultiIndexes See :ref:`pandas` +- A netCDF file on disk with :py:func:`~xarray.open_dataset`. See :ref:`io`. + +Dataset contents +~~~~~~~~~~~~~~~~ + +:py:class:`~xarray.Dataset` implements the Python mapping interface, with +values given by :py:class:`xarray.DataArray` objects: + +.. ipython:: python + + "temperature" in ds + ds["temperature"] + +Valid keys include each listed coordinate and data variable. + +Data and coordinate variables are also contained separately in the +:py:attr:`~xarray.Dataset.data_vars` and :py:attr:`~xarray.Dataset.coords` +dictionary-like attributes: + +.. ipython:: python + + ds.data_vars + ds.coords + +Finally, like data arrays, datasets also store arbitrary metadata in the form +of `attributes`: + +.. ipython:: python + + ds.attrs + + ds.attrs["title"] = "example attribute" + ds + +Xarray does not enforce any restrictions on attributes, but serialization to +some file formats may fail if you use objects that are not strings, numbers +or :py:class:`numpy.ndarray` objects. + +As a useful shortcut, you can use attribute style access for reading (but not +setting) variables and attributes: + +.. ipython:: python + + ds.temperature + +This is particularly useful in an exploratory context, because you can +tab-complete these variable names with tools like IPython. + +.. _dictionary_like_methods: + +Dictionary like methods +~~~~~~~~~~~~~~~~~~~~~~~ + +We can update a dataset in-place using Python's standard dictionary syntax. For +example, to create this example dataset from scratch, we could have written: + +.. ipython:: python + + ds = xr.Dataset() + ds["temperature"] = (("loc", "instrument", "time"), temperature) + ds["temperature_double"] = (("loc", "instrument", "time"), temperature * 2) + ds["precipitation"] = (("loc", "instrument", "time"), precipitation) + ds.coords["lat"] = (("loc",), lat) + ds.coords["lon"] = (("loc",), lon) + ds.coords["time"] = pd.date_range("2014-09-06", periods=4) + ds.coords["reference_time"] = pd.Timestamp("2014-09-05") + +To change the variables in a ``Dataset``, you can use all the standard dictionary +methods, including ``values``, ``items``, ``__delitem__``, ``get`` and +:py:meth:`~xarray.Dataset.update`. Note that assigning a ``DataArray`` or pandas +object to a ``Dataset`` variable using ``__setitem__`` or ``update`` will +:ref:`automatically align` the array(s) to the original +dataset's indexes. + +You can copy a ``Dataset`` by calling the :py:meth:`~xarray.Dataset.copy` +method. By default, the copy is shallow, so only the container will be copied: +the arrays in the ``Dataset`` will still be stored in the same underlying +:py:class:`numpy.ndarray` objects. You can copy all data by calling +``ds.copy(deep=True)``. + +.. _transforming datasets: + +Transforming datasets +~~~~~~~~~~~~~~~~~~~~~ + +In addition to dictionary-like methods (described above), xarray has additional +methods (like pandas) for transforming datasets into new objects. + +For removing variables, you can select and drop an explicit list of +variables by indexing with a list of names or using the +:py:meth:`~xarray.Dataset.drop_vars` methods to return a new ``Dataset``. These +operations keep around coordinates: + +.. ipython:: python + + ds[["temperature"]] + ds[["temperature", "temperature_double"]] + ds.drop_vars("temperature") + +To remove a dimension, you can use :py:meth:`~xarray.Dataset.drop_dims` method. +Any variables using that dimension are dropped: + +.. ipython:: python + + ds.drop_dims("time") + +As an alternate to dictionary-like modifications, you can use +:py:meth:`~xarray.Dataset.assign` and :py:meth:`~xarray.Dataset.assign_coords`. +These methods return a new dataset with additional (or replaced) values: + +.. ipython:: python + + ds.assign(temperature2=2 * ds.temperature) + +There is also the :py:meth:`~xarray.Dataset.pipe` method that allows you to use +a method call with an external function (e.g., ``ds.pipe(func)``) instead of +simply calling it (e.g., ``func(ds)``). This allows you to write pipelines for +transforming your data (using "method chaining") instead of writing hard to +follow nested function calls: + +.. ipython:: python + + # these lines are equivalent, but with pipe we can make the logic flow + # entirely from left to right + plt.plot((2 * ds.temperature.sel(loc=0)).mean("instrument")) + (ds.temperature.sel(loc=0).pipe(lambda x: 2 * x).mean("instrument").pipe(plt.plot)) + +Both ``pipe`` and ``assign`` replicate the pandas methods of the same names +(:py:meth:`DataFrame.pipe ` and +:py:meth:`DataFrame.assign `). + +With xarray, there is no performance penalty for creating new datasets, even if +variables are lazily loaded from a file on disk. Creating new objects instead +of mutating existing objects often results in easier to understand code, so we +encourage using this approach. + +Renaming variables +~~~~~~~~~~~~~~~~~~ + +Another useful option is the :py:meth:`~xarray.Dataset.rename` method to rename +dataset variables: + +.. ipython:: python + + ds.rename({"temperature": "temp", "precipitation": "precip"}) + +The related :py:meth:`~xarray.Dataset.swap_dims` method allows you do to swap +dimension and non-dimension variables: + +.. ipython:: python + + ds.coords["day"] = ("time", [6, 7, 8, 9]) + ds.swap_dims({"time": "day"}) + +.. _coordinates: + +Coordinates +----------- + +Coordinates are ancillary variables stored for ``DataArray`` and ``Dataset`` +objects in the ``coords`` attribute: + +.. ipython:: python + + ds.coords + +Unlike attributes, xarray *does* interpret and persist coordinates in +operations that transform xarray objects. There are two types of coordinates +in xarray: + +- **dimension coordinates** are one dimensional coordinates with a name equal + to their sole dimension (marked by ``*`` when printing a dataset or data + array). They are used for label based indexing and alignment, + like the ``index`` found on a pandas :py:class:`~pandas.DataFrame` or + :py:class:`~pandas.Series`. Indeed, these "dimension" coordinates use a + :py:class:`pandas.Index` internally to store their values. + +- **non-dimension coordinates** are variables that contain coordinate + data, but are not a dimension coordinate. They can be multidimensional (see + :ref:`/examples/multidimensional-coords.ipynb`), and there is no + relationship between the name of a non-dimension coordinate and the + name(s) of its dimension(s). Non-dimension coordinates can be + useful for indexing or plotting; otherwise, xarray does not make any + direct use of the values associated with them. They are not used + for alignment or automatic indexing, nor are they required to match + when doing arithmetic (see :ref:`coordinates math`). + +.. note:: + + Xarray's terminology differs from the `CF terminology`_, where the + "dimension coordinates" are called "coordinate variables", and the + "non-dimension coordinates" are called "auxiliary coordinate variables" + (see :issue:`1295` for more details). + +.. _CF terminology: https://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#terminology + + +Modifying coordinates +~~~~~~~~~~~~~~~~~~~~~ + +To entirely add or remove coordinate arrays, you can use dictionary like +syntax, as shown above. + +To convert back and forth between data and coordinates, you can use the +:py:meth:`~xarray.Dataset.set_coords` and +:py:meth:`~xarray.Dataset.reset_coords` methods: + +.. ipython:: python + + ds.reset_coords() + ds.set_coords(["temperature", "precipitation"]) + ds["temperature"].reset_coords(drop=True) + +Notice that these operations skip coordinates with names given by dimensions, +as used for indexing. This mostly because we are not entirely sure how to +design the interface around the fact that xarray cannot store a coordinate and +variable with the name but different values in the same dictionary. But we do +recognize that supporting something like this would be useful. + +Coordinates methods +~~~~~~~~~~~~~~~~~~~ + +``Coordinates`` objects also have a few useful methods, mostly for converting +them into dataset objects: + +.. ipython:: python + + ds.coords.to_dataset() + +The merge method is particularly interesting, because it implements the same +logic used for merging coordinates in arithmetic operations +(see :ref:`comput`): + +.. ipython:: python + + alt = xr.Dataset(coords={"z": [10], "lat": 0, "lon": 0}) + ds.coords.merge(alt.coords) + +The ``coords.merge`` method may be useful if you want to implement your own +binary operations that act on xarray objects. In the future, we hope to write +more helper functions so that you can easily make your functions act like +xarray's built-in arithmetic. + +Indexes +~~~~~~~ + +To convert a coordinate (or any ``DataArray``) into an actual +:py:class:`pandas.Index`, use the :py:meth:`~xarray.DataArray.to_index` method: + +.. ipython:: python + + ds["time"].to_index() + +A useful shortcut is the ``indexes`` property (on both ``DataArray`` and +``Dataset``), which lazily constructs a dictionary whose keys are given by each +dimension and whose the values are ``Index`` objects: + +.. ipython:: python + + ds.indexes + +MultiIndex coordinates +~~~~~~~~~~~~~~~~~~~~~~ + +Xarray supports labeling coordinate values with a :py:class:`pandas.MultiIndex`: + +.. ipython:: python + + midx = pd.MultiIndex.from_arrays( + [["R", "R", "V", "V"], [0.1, 0.2, 0.7, 0.9]], names=("band", "wn") + ) + mda = xr.DataArray(np.random.rand(4), coords={"spec": midx}, dims="spec") + mda + +For convenience multi-index levels are directly accessible as "virtual" or +"derived" coordinates (marked by ``-`` when printing a dataset or data array): + +.. ipython:: python + + mda["band"] + mda.wn + +Indexing with multi-index levels is also possible using the ``sel`` method +(see :ref:`multi-level indexing`). + +Unlike other coordinates, "virtual" level coordinates are not stored in +the ``coords`` attribute of ``DataArray`` and ``Dataset`` objects +(although they are shown when printing the ``coords`` attribute). +Consequently, most of the coordinates related methods don't apply for them. +It also can't be used to replace one particular level. + +Because in a ``DataArray`` or ``Dataset`` object each multi-index level is +accessible as a "virtual" coordinate, its name must not conflict with the names +of the other levels, coordinates and data variables of the same object. +Even though xarray sets default names for multi-indexes with unnamed levels, +it is recommended that you explicitly set the names of the levels. + +.. [1] Latitude and longitude are 2D arrays because the dataset uses + `projected coordinates`__. ``reference_time`` refers to the reference time + at which the forecast was made, rather than ``time`` which is the valid time + for which the forecast applies. + +__ https://en.wikipedia.org/wiki/Map_projection diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/duckarrays.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/duckarrays.rst new file mode 100644 index 0000000..f0650ac --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/duckarrays.rst @@ -0,0 +1,226 @@ +.. currentmodule:: xarray + +.. _userguide.duckarrays: + +Working with numpy-like arrays +============================== + +NumPy-like arrays (often known as :term:`duck array`\s) are drop-in replacements for the :py:class:`numpy.ndarray` +class but with different features, such as propagating physical units or a different layout in memory. +Xarray can often wrap these array types, allowing you to use labelled dimensions and indexes whilst benefiting from the +additional features of these array libraries. + +Some numpy-like array types that xarray already has some support for: + +* `Cupy `_ - GPU support (see `cupy-xarray `_), +* `Sparse `_ - for performant arrays with many zero elements, +* `Pint `_ - for tracking the physical units of your data (see `pint-xarray `_), +* `Dask `_ - parallel computing on larger-than-memory arrays (see :ref:`using dask with xarray `), +* `Cubed `_ - another parallel computing framework that emphasises reliability (see `cubed-xarray `_). + +.. warning:: + + This feature should be considered somewhat experimental. Please report any bugs you find on + `xarray’s issue tracker `_. + +.. note:: + + For information on wrapping dask arrays see :ref:`dask`. Whilst xarray wraps dask arrays in a similar way to that + described on this page, chunked array types like :py:class:`dask.array.Array` implement additional methods that require + slightly different user code (e.g. calling ``.chunk`` or ``.compute``). See the docs on :ref:`wrapping chunked arrays `. + +Why "duck"? +----------- + +Why is it also called a "duck" array? This comes from a common statement of object-oriented programming - +"If it walks like a duck, and quacks like a duck, treat it like a duck". In other words, a library like xarray that +is capable of using multiple different types of arrays does not have to explicitly check that each one it encounters is +permitted (e.g. ``if dask``, ``if numpy``, ``if sparse`` etc.). Instead xarray can take the more permissive approach of simply +treating the wrapped array as valid, attempting to call the relevant methods (e.g. ``.mean()``) and only raising an +error if a problem occurs (e.g. the method is not found on the wrapped class). This is much more flexible, and allows +objects and classes from different libraries to work together more easily. + +What is a numpy-like array? +--------------------------- + +A "numpy-like array" (also known as a "duck array") is a class that contains array-like data, and implements key +numpy-like functionality such as indexing, broadcasting, and computation methods. + +For example, the `sparse `_ library provides a sparse array type which is useful for representing nD array objects like sparse matrices +in a memory-efficient manner. We can create a sparse array object (of the :py:class:`sparse.COO` type) from a numpy array like this: + +.. ipython:: python + + from sparse import COO + + x = np.eye(4, dtype=np.uint8) # create diagonal identity matrix + s = COO.from_numpy(x) + s + +This sparse object does not attempt to explicitly store every element in the array, only the non-zero elements. +This approach is much more efficient for large arrays with only a few non-zero elements (such as tri-diagonal matrices). +Sparse array objects can be converted back to a "dense" numpy array by calling :py:meth:`sparse.COO.todense`. + +Just like :py:class:`numpy.ndarray` objects, :py:class:`sparse.COO` arrays support indexing + +.. ipython:: python + + s[1, 1] # diagonal elements should be ones + s[2, 3] # off-diagonal elements should be zero + +broadcasting, + +.. ipython:: python + + x2 = np.zeros( + (4, 1), dtype=np.uint8 + ) # create second sparse array of different shape + s2 = COO.from_numpy(x2) + (s * s2) # multiplication requires broadcasting + +and various computation methods + +.. ipython:: python + + s.sum(axis=1) + +This numpy-like array also supports calling so-called `numpy ufuncs `_ +("universal functions") on it directly: + +.. ipython:: python + + np.sum(s, axis=1) + + +Notice that in each case the API for calling the operation on the sparse array is identical to that of calling it on the +equivalent numpy array - this is the sense in which the sparse array is "numpy-like". + +.. note:: + + For discussion on exactly which methods a class needs to implement to be considered "numpy-like", see :ref:`internals.duckarrays`. + +Wrapping numpy-like arrays in xarray +------------------------------------ + +:py:class:`DataArray`, :py:class:`Dataset`, and :py:class:`Variable` objects can wrap these numpy-like arrays. + +Constructing xarray objects which wrap numpy-like arrays +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The primary way to create an xarray object which wraps a numpy-like array is to pass that numpy-like array instance directly +to the constructor of the xarray class. The :ref:`page on xarray data structures ` shows how :py:class:`DataArray` and :py:class:`Dataset` +both accept data in various forms through their ``data`` argument, but in fact this data can also be any wrappable numpy-like array. + +For example, we can wrap the sparse array we created earlier inside a new DataArray object: + +.. ipython:: python + + s_da = xr.DataArray(s, dims=["i", "j"]) + s_da + +We can see what's inside - the printable representation of our xarray object (the repr) automatically uses the printable +representation of the underlying wrapped array. + +Of course our sparse array object is still there underneath - it's stored under the ``.data`` attribute of the dataarray: + +.. ipython:: python + + s_da.data + +Array methods +~~~~~~~~~~~~~ + +We saw above that numpy-like arrays provide numpy methods. Xarray automatically uses these when you call the corresponding xarray method: + +.. ipython:: python + + s_da.sum(dim="j") + +Converting wrapped types +~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to change the type inside your xarray object you can use :py:meth:`DataArray.as_numpy`: + +.. ipython:: python + + s_da.as_numpy() + +This returns a new :py:class:`DataArray` object, but now wrapping a normal numpy array. + +If instead you want to convert to numpy and return that numpy array you can use either :py:meth:`DataArray.to_numpy` or +:py:meth:`DataArray.values`, where the former is strongly preferred. The difference is in the way they coerce to numpy - :py:meth:`~DataArray.values` +always uses :py:func:`numpy.asarray` which will fail for some array types (e.g. ``cupy``), whereas :py:meth:`~DataArray.to_numpy` +uses the correct method depending on the array type. + +.. ipython:: python + + s_da.to_numpy() + +.. ipython:: python + :okexcept: + + s_da.values + +This illustrates the difference between :py:meth:`~DataArray.data` and :py:meth:`~DataArray.values`, +which is sometimes a point of confusion for new xarray users. +Explicitly: :py:meth:`DataArray.data` returns the underlying numpy-like array, regardless of type, whereas +:py:meth:`DataArray.values` converts the underlying array to a numpy array before returning it. +(This is another reason to use :py:meth:`~DataArray.to_numpy` over :py:meth:`~DataArray.values` - the intention is clearer.) + +Conversion to numpy as a fallback +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If a wrapped array does not implement the corresponding array method then xarray will often attempt to convert the +underlying array to a numpy array so that the operation can be performed. You may want to watch out for this behavior, +and report any instances in which it causes problems. + +Most of xarray's API does support using :term:`duck array` objects, but there are a few areas where +the code will still convert to ``numpy`` arrays: + +- Dimension coordinates, and thus all indexing operations: + + * :py:meth:`Dataset.sel` and :py:meth:`DataArray.sel` + * :py:meth:`Dataset.loc` and :py:meth:`DataArray.loc` + * :py:meth:`Dataset.drop_sel` and :py:meth:`DataArray.drop_sel` + * :py:meth:`Dataset.reindex`, :py:meth:`Dataset.reindex_like`, + :py:meth:`DataArray.reindex` and :py:meth:`DataArray.reindex_like`: duck arrays in + data variables and non-dimension coordinates won't be casted + +- Functions and methods that depend on external libraries or features of ``numpy`` not + covered by ``__array_function__`` / ``__array_ufunc__``: + + * :py:meth:`Dataset.ffill` and :py:meth:`DataArray.ffill` (uses ``bottleneck``) + * :py:meth:`Dataset.bfill` and :py:meth:`DataArray.bfill` (uses ``bottleneck``) + * :py:meth:`Dataset.interp`, :py:meth:`Dataset.interp_like`, + :py:meth:`DataArray.interp` and :py:meth:`DataArray.interp_like` (uses ``scipy``): + duck arrays in data variables and non-dimension coordinates will be casted in + addition to not supporting duck arrays in dimension coordinates + * :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (requires ``numpy>=1.20``) + * :py:meth:`Dataset.rolling_exp` and :py:meth:`DataArray.rolling_exp` (uses + ``numbagg``) + * :py:meth:`Dataset.interpolate_na` and :py:meth:`DataArray.interpolate_na` (uses + :py:class:`numpy.vectorize`) + * :py:func:`apply_ufunc` with ``vectorize=True`` (uses :py:class:`numpy.vectorize`) + +- Incompatibilities between different :term:`duck array` libraries: + + * :py:meth:`Dataset.chunk` and :py:meth:`DataArray.chunk`: this fails if the data was + not already chunked and the :term:`duck array` (e.g. a ``pint`` quantity) should + wrap the new ``dask`` array; changing the chunk sizes works however. + +Extensions using duck arrays +---------------------------- + +Whilst the features above allow many numpy-like array libraries to be used pretty seamlessly with xarray, it often also +makes sense to use an interfacing package to make certain tasks easier. + +For example the `pint-xarray package `_ offers a custom ``.pint`` accessor (see :ref:`internals.accessors`) which provides +convenient access to information stored within the wrapped array (e.g. ``.units`` and ``.magnitude``), and makes makes +creating wrapped pint arrays (and especially xarray-wrapping-pint-wrapping-dask arrays) simpler for the user. + +We maintain a list of libraries extending ``xarray`` to make working with particular wrapped duck arrays +easier. If you know of more that aren't on this list please raise an issue to add them! + +- `pint-xarray `_ +- `cupy-xarray `_ +- `cubed-xarray `_ diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/groupby.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/groupby.rst new file mode 100644 index 0000000..1ad2d52 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/groupby.rst @@ -0,0 +1,234 @@ +.. _groupby: + +GroupBy: Group and Bin Data +--------------------------- + +Often we want to bin or group data, produce statistics (mean, variance) on +the groups, and then return a reduced data set. To do this, Xarray supports +`"group by"`__ operations with the same API as pandas to implement the +`split-apply-combine`__ strategy: + +__ https://pandas.pydata.org/pandas-docs/stable/groupby.html +__ https://www.jstatsoft.org/v40/i01/paper + +- Split your data into multiple independent groups. +- Apply some function to each group. +- Combine your groups back into a single data object. + +Group by operations work on both :py:class:`~xarray.Dataset` and +:py:class:`~xarray.DataArray` objects. Most of the examples focus on grouping by +a single one-dimensional variable, although support for grouping +over a multi-dimensional variable has recently been implemented. Note that for +one-dimensional data, it is usually faster to rely on pandas' implementation of +the same pipeline. + +.. tip:: + + To substantially improve the performance of GroupBy operations, particularly + with dask `install the flox package `_. flox + `extends Xarray's in-built GroupBy capabilities `_ + by allowing grouping by multiple variables, and lazy grouping by dask arrays. If installed, Xarray will automatically use flox by default. + +Split +~~~~~ + +Let's create a simple example dataset: + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 3))}, + coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, + ) + arr = ds["foo"] + ds + +If we groupby the name of a variable or coordinate in a dataset (we can also +use a DataArray directly), we get back a ``GroupBy`` object: + +.. ipython:: python + + ds.groupby("letters") + +This object works very similarly to a pandas GroupBy object. You can view +the group indices with the ``groups`` attribute: + +.. ipython:: python + + ds.groupby("letters").groups + +You can also iterate over groups in ``(label, group)`` pairs: + +.. ipython:: python + + list(ds.groupby("letters")) + +You can index out a particular group: + +.. ipython:: python + + ds.groupby("letters")["b"] + +Just like in pandas, creating a GroupBy object is cheap: it does not actually +split the data until you access particular values. + +Binning +~~~~~~~ + +Sometimes you don't want to use all the unique values to determine the groups +but instead want to "bin" the data into coarser groups. You could always create +a customized coordinate, but xarray facilitates this via the +:py:meth:`~xarray.Dataset.groupby_bins` method. + +.. ipython:: python + + x_bins = [0, 25, 50] + ds.groupby_bins("x", x_bins).groups + +The binning is implemented via :func:`pandas.cut`, whose documentation details how +the bins are assigned. As seen in the example above, by default, the bins are +labeled with strings using set notation to precisely identify the bin limits. To +override this behavior, you can specify the bin labels explicitly. Here we +choose `float` labels which identify the bin centers: + +.. ipython:: python + + x_bin_labels = [12.5, 37.5] + ds.groupby_bins("x", x_bins, labels=x_bin_labels).groups + + +Apply +~~~~~ + +To apply a function to each group, you can use the flexible +:py:meth:`~xarray.core.groupby.DatasetGroupBy.map` method. The resulting objects are automatically +concatenated back together along the group axis: + +.. ipython:: python + + def standardize(x): + return (x - x.mean()) / x.std() + + + arr.groupby("letters").map(standardize) + +GroupBy objects also have a :py:meth:`~xarray.core.groupby.DatasetGroupBy.reduce` method and +methods like :py:meth:`~xarray.core.groupby.DatasetGroupBy.mean` as shortcuts for applying an +aggregation function: + +.. ipython:: python + + arr.groupby("letters").mean(dim="x") + +Using a groupby is thus also a convenient shortcut for aggregating over all +dimensions *other than* the provided one: + +.. ipython:: python + + ds.groupby("x").std(...) + +.. note:: + + We use an ellipsis (`...`) here to indicate we want to reduce over all + other dimensions + + +First and last +~~~~~~~~~~~~~~ + +There are two special aggregation operations that are currently only found on +groupby objects: first and last. These provide the first or last example of +values for group along the grouped dimension: + +.. ipython:: python + + ds.groupby("letters").first(...) + +By default, they skip missing values (control this with ``skipna``). + +Grouped arithmetic +~~~~~~~~~~~~~~~~~~ + +GroupBy objects also support a limited set of binary arithmetic operations, as +a shortcut for mapping over all unique labels. Binary arithmetic is supported +for ``(GroupBy, Dataset)`` and ``(GroupBy, DataArray)`` pairs, as long as the +dataset or data array uses the unique grouped values as one of its index +coordinates. For example: + +.. ipython:: python + + alt = arr.groupby("letters").mean(...) + alt + ds.groupby("letters") - alt + +This last line is roughly equivalent to the following:: + + results = [] + for label, group in ds.groupby('letters'): + results.append(group - alt.sel(letters=label)) + xr.concat(results, dim='x') + +Iterating and Squeezing +~~~~~~~~~~~~~~~~~~~~~~~ + +Previously, Xarray defaulted to squeezing out dimensions of size one when iterating over +a GroupBy object. This behaviour is being removed. +You can always squeeze explicitly later with the Dataset or DataArray +:py:meth:`~xarray.DataArray.squeeze` methods. + +.. ipython:: python + + next(iter(arr.groupby("x", squeeze=False))) + + +.. _groupby.multidim: + +Multidimensional Grouping +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Many datasets have a multidimensional coordinate variable (e.g. longitude) +which is different from the logical grid dimensions (e.g. nx, ny). Such +variables are valid under the `CF conventions`__. Xarray supports groupby +operations over multidimensional coordinate variables: + +__ https://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#_two_dimensional_latitude_longitude_coordinate_variables + +.. ipython:: python + + da = xr.DataArray( + [[0, 1], [2, 3]], + coords={ + "lon": (["ny", "nx"], [[30, 40], [40, 50]]), + "lat": (["ny", "nx"], [[10, 10], [20, 20]]), + }, + dims=["ny", "nx"], + ) + da + da.groupby("lon").sum(...) + da.groupby("lon").map(lambda x: x - x.mean(), shortcut=False) + +Because multidimensional groups have the ability to generate a very large +number of bins, coarse-binning via :py:meth:`~xarray.Dataset.groupby_bins` +may be desirable: + +.. ipython:: python + + da.groupby_bins("lon", [0, 45, 50]).sum() + +These methods group by `lon` values. It is also possible to groupby each +cell in a grid, regardless of value, by stacking multiple dimensions, +applying your function, and then unstacking the result: + +.. ipython:: python + + stacked = da.stack(gridcell=["ny", "nx"]) + stacked.groupby("gridcell").sum(...).unstack("gridcell") diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/index.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/index.rst new file mode 100644 index 0000000..45f0ce3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/index.rst @@ -0,0 +1,29 @@ +########### +User Guide +########### + +In this user guide, you will find detailed descriptions and +examples that describe many common tasks that you can accomplish with xarray. + + +.. toctree:: + :maxdepth: 2 + :hidden: + + terminology + data-structures + indexing + interpolation + computation + groupby + reshaping + combining + time-series + weather-climate + pandas + io + dask + plotting + options + testing + duckarrays diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/indexing.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/indexing.rst new file mode 100644 index 0000000..0f57516 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/indexing.rst @@ -0,0 +1,867 @@ +.. _indexing: + +Indexing and selecting data +=========================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +Xarray offers extremely flexible indexing routines that combine the best +features of NumPy and pandas for data selection. + +The most basic way to access elements of a :py:class:`~xarray.DataArray` +object is to use Python's ``[]`` syntax, such as ``array[i, j]``, where +``i`` and ``j`` are both integers. +As xarray objects can store coordinates corresponding to each dimension of an +array, label-based indexing similar to ``pandas.DataFrame.loc`` is also possible. +In label-based indexing, the element position ``i`` is automatically +looked-up from the coordinate values. + +Dimensions of xarray objects have names, so you can also lookup the dimensions +by name, instead of remembering their positional order. + +Quick overview +-------------- + +In total, xarray supports four different kinds of indexing, as described +below and summarized in this table: + +.. |br| raw:: html + +
+ ++------------------+--------------+---------------------------------+--------------------------------+ +| Dimension lookup | Index lookup | ``DataArray`` syntax | ``Dataset`` syntax | ++==================+==============+=================================+================================+ +| Positional | By integer | ``da[:, 0]`` | *not available* | ++------------------+--------------+---------------------------------+--------------------------------+ +| Positional | By label | ``da.loc[:, 'IA']`` | *not available* | ++------------------+--------------+---------------------------------+--------------------------------+ +| By name | By integer | ``da.isel(space=0)`` or |br| | ``ds.isel(space=0)`` or |br| | +| | | ``da[dict(space=0)]`` | ``ds[dict(space=0)]`` | ++------------------+--------------+---------------------------------+--------------------------------+ +| By name | By label | ``da.sel(space='IA')`` or |br| | ``ds.sel(space='IA')`` or |br| | +| | | ``da.loc[dict(space='IA')]`` | ``ds.loc[dict(space='IA')]`` | ++------------------+--------------+---------------------------------+--------------------------------+ + +More advanced indexing is also possible for all the methods by +supplying :py:class:`~xarray.DataArray` objects as indexer. +See :ref:`vectorized_indexing` for the details. + + +Positional indexing +------------------- + +Indexing a :py:class:`~xarray.DataArray` directly works (mostly) just like it +does for numpy arrays, except that the returned object is always another +DataArray: + +.. ipython:: python + + da = xr.DataArray( + np.random.rand(4, 3), + [ + ("time", pd.date_range("2000-01-01", periods=4)), + ("space", ["IA", "IL", "IN"]), + ], + ) + da[:2] + da[0, 0] + da[:, [2, 1]] + +Attributes are persisted in all indexing operations. + +.. warning:: + + Positional indexing deviates from the NumPy when indexing with multiple + arrays like ``da[[0, 1], [0, 1]]``, as described in + :ref:`vectorized_indexing`. + +Xarray also supports label-based indexing, just like pandas. Because +we use a :py:class:`pandas.Index` under the hood, label based indexing is very +fast. To do label based indexing, use the :py:attr:`~xarray.DataArray.loc` attribute: + +.. ipython:: python + + da.loc["2000-01-01":"2000-01-02", "IA"] + +In this example, the selected is a subpart of the array +in the range '2000-01-01':'2000-01-02' along the first coordinate `time` +and with 'IA' value from the second coordinate `space`. + +You can perform any of the `label indexing operations supported by pandas`__, +including indexing with individual, slices and lists/arrays of labels, as well as +indexing with boolean arrays. Like pandas, label based indexing in xarray is +*inclusive* of both the start and stop bounds. + +__ https://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-label + +Setting values with label based indexing is also supported: + +.. ipython:: python + + da.loc["2000-01-01", ["IL", "IN"]] = -10 + da + + +Indexing with dimension names +----------------------------- + +With the dimension names, we do not have to rely on dimension order and can +use them explicitly to slice data. There are two ways to do this: + +1. Use the :py:meth:`~xarray.DataArray.sel` and :py:meth:`~xarray.DataArray.isel` + convenience methods: + + .. ipython:: python + + # index by integer array indices + da.isel(space=0, time=slice(None, 2)) + + # index by dimension coordinate labels + da.sel(time=slice("2000-01-01", "2000-01-02")) + +2. Use a dictionary as the argument for array positional or label based array + indexing: + + .. ipython:: python + + # index by integer array indices + da[dict(space=0, time=slice(None, 2))] + + # index by dimension coordinate labels + da.loc[dict(time=slice("2000-01-01", "2000-01-02"))] + +The arguments to these methods can be any objects that could index the array +along the dimension given by the keyword, e.g., labels for an individual value, +:py:class:`Python slice` objects or 1-dimensional arrays. + + +.. note:: + + We would love to be able to do indexing with labeled dimension names inside + brackets, but unfortunately, `Python does not yet support indexing with + keyword arguments`__ like ``da[space=0]`` + +__ https://legacy.python.org/dev/peps/pep-0472/ + + +.. _nearest neighbor lookups: + +Nearest neighbor lookups +------------------------ + +The label based selection methods :py:meth:`~xarray.Dataset.sel`, +:py:meth:`~xarray.Dataset.reindex` and :py:meth:`~xarray.Dataset.reindex_like` all +support ``method`` and ``tolerance`` keyword argument. The method parameter allows for +enabling nearest neighbor (inexact) lookups by use of the methods ``'pad'``, +``'backfill'`` or ``'nearest'``: + +.. ipython:: python + + da = xr.DataArray([1, 2, 3], [("x", [0, 1, 2])]) + da.sel(x=[1.1, 1.9], method="nearest") + da.sel(x=0.1, method="backfill") + da.reindex(x=[0.5, 1, 1.5, 2, 2.5], method="pad") + +Tolerance limits the maximum distance for valid matches with an inexact lookup: + +.. ipython:: python + + da.reindex(x=[1.1, 1.5], method="nearest", tolerance=0.2) + +The method parameter is not yet supported if any of the arguments +to ``.sel()`` is a ``slice`` object: + +.. ipython:: + :verbatim: + + In [1]: da.sel(x=slice(1, 3), method="nearest") + NotImplementedError + +However, you don't need to use ``method`` to do inexact slicing. Slicing +already returns all values inside the range (inclusive), as long as the index +labels are monotonic increasing: + +.. ipython:: python + + da.sel(x=slice(0.9, 3.1)) + +Indexing axes with monotonic decreasing labels also works, as long as the +``slice`` or ``.loc`` arguments are also decreasing: + +.. ipython:: python + + reversed_da = da[::-1] + reversed_da.loc[3.1:0.9] + + +.. note:: + + If you want to interpolate along coordinates rather than looking up the + nearest neighbors, use :py:meth:`~xarray.Dataset.interp` and + :py:meth:`~xarray.Dataset.interp_like`. + See :ref:`interpolation ` for the details. + + +Dataset indexing +---------------- + +We can also use these methods to index all variables in a dataset +simultaneously, returning a new dataset: + +.. ipython:: python + + da = xr.DataArray( + np.random.rand(4, 3), + [ + ("time", pd.date_range("2000-01-01", periods=4)), + ("space", ["IA", "IL", "IN"]), + ], + ) + ds = da.to_dataset(name="foo") + ds.isel(space=[0], time=[0]) + ds.sel(time="2000-01-01") + +Positional indexing on a dataset is not supported because the ordering of +dimensions in a dataset is somewhat ambiguous (it can vary between different +arrays). However, you can do normal indexing with dimension names: + +.. ipython:: python + + ds[dict(space=[0], time=[0])] + ds.loc[dict(time="2000-01-01")] + +Dropping labels and dimensions +------------------------------ + +The :py:meth:`~xarray.Dataset.drop_sel` method returns a new object with the listed +index labels along a dimension dropped: + +.. ipython:: python + + ds.drop_sel(space=["IN", "IL"]) + +``drop_sel`` is both a ``Dataset`` and ``DataArray`` method. + +Use :py:meth:`~xarray.Dataset.drop_dims` to drop a full dimension from a Dataset. +Any variables with these dimensions are also dropped: + +.. ipython:: python + + ds.drop_dims("time") + +.. _masking with where: + +Masking with ``where`` +---------------------- + +Indexing methods on xarray objects generally return a subset of the original data. +However, it is sometimes useful to select an object with the same shape as the +original data, but with some elements masked. To do this type of selection in +xarray, use :py:meth:`~xarray.DataArray.where`: + +.. ipython:: python + + da = xr.DataArray(np.arange(16).reshape(4, 4), dims=["x", "y"]) + da.where(da.x + da.y < 4) + +This is particularly useful for ragged indexing of multi-dimensional data, +e.g., to apply a 2D mask to an image. Note that ``where`` follows all the +usual xarray broadcasting and alignment rules for binary operations (e.g., +``+``) between the object being indexed and the condition, as described in +:ref:`comput`: + +.. ipython:: python + + da.where(da.y < 2) + +By default ``where`` maintains the original size of the data. For cases +where the selected data size is much smaller than the original data, +use of the option ``drop=True`` clips coordinate +elements that are fully masked: + +.. ipython:: python + + da.where(da.y < 2, drop=True) + +.. _selecting values with isin: + +Selecting values with ``isin`` +------------------------------ + +To check whether elements of an xarray object contain a single object, you can +compare with the equality operator ``==`` (e.g., ``arr == 3``). To check +multiple values, use :py:meth:`~xarray.DataArray.isin`: + +.. ipython:: python + + da = xr.DataArray([1, 2, 3, 4, 5], dims=["x"]) + da.isin([2, 4]) + +:py:meth:`~xarray.DataArray.isin` works particularly well with +:py:meth:`~xarray.DataArray.where` to support indexing by arrays that are not +already labels of an array: + +.. ipython:: python + + lookup = xr.DataArray([-1, -2, -3, -4, -5], dims=["x"]) + da.where(lookup.isin([-2, -4]), drop=True) + +However, some caution is in order: when done repeatedly, this type of indexing +is significantly slower than using :py:meth:`~xarray.DataArray.sel`. + +.. _vectorized_indexing: + +Vectorized Indexing +------------------- + +Like numpy and pandas, xarray supports indexing many array elements at once in a +`vectorized` manner. + +If you only provide integers, slices, or unlabeled arrays (array without +dimension names, such as ``np.ndarray``, ``list``, but not +:py:meth:`~xarray.DataArray` or :py:meth:`~xarray.Variable`) indexing can be +understood as orthogonally. Each indexer component selects independently along +the corresponding dimension, similar to how vector indexing works in Fortran or +MATLAB, or after using the :py:func:`numpy.ix_` helper: + +.. ipython:: python + + da = xr.DataArray( + np.arange(12).reshape((3, 4)), + dims=["x", "y"], + coords={"x": [0, 1, 2], "y": ["a", "b", "c", "d"]}, + ) + da + da[[0, 2, 2], [1, 3]] + +For more flexibility, you can supply :py:meth:`~xarray.DataArray` objects +as indexers. +Dimensions on resultant arrays are given by the ordered union of the indexers' +dimensions: + +.. ipython:: python + + ind_x = xr.DataArray([0, 1], dims=["x"]) + ind_y = xr.DataArray([0, 1], dims=["y"]) + da[ind_x, ind_y] # orthogonal indexing + +Slices or sequences/arrays without named-dimensions are treated as if they have +the same dimension which is indexed along: + +.. ipython:: python + + # Because [0, 1] is used to index along dimension 'x', + # it is assumed to have dimension 'x' + da[[0, 1], ind_x] + +Furthermore, you can use multi-dimensional :py:meth:`~xarray.DataArray` +as indexers, where the resultant array dimension is also determined by +indexers' dimension: + +.. ipython:: python + + ind = xr.DataArray([[0, 1], [0, 1]], dims=["a", "b"]) + da[ind] + +Similar to how `NumPy's advanced indexing`_ works, vectorized +indexing for xarray is based on our +:ref:`broadcasting rules `. +See :ref:`indexing.rules` for the complete specification. + +.. _NumPy's advanced indexing: https://numpy.org/doc/stable/reference/arrays.indexing.html + +Vectorized indexing also works with ``isel``, ``loc``, and ``sel``: + +.. ipython:: python + + ind = xr.DataArray([[0, 1], [0, 1]], dims=["a", "b"]) + da.isel(y=ind) # same as da[:, ind] + + ind = xr.DataArray([["a", "b"], ["b", "a"]], dims=["a", "b"]) + da.loc[:, ind] # same as da.sel(y=ind) + +These methods may also be applied to ``Dataset`` objects + +.. ipython:: python + + ds = da.to_dataset(name="bar") + ds.isel(x=xr.DataArray([0, 1, 2], dims=["points"])) + +Vectorized indexing may be used to extract information from the nearest +grid cells of interest, for example, the nearest climate model grid cells +to a collection specified weather station latitudes and longitudes. +To trigger vectorized indexing behavior +you will need to provide the selection dimensions with a new +shared output dimension name. In the example below, the selections +of the closest latitude and longitude are renamed to an output +dimension named "points": + + +.. ipython:: python + + ds = xr.tutorial.open_dataset("air_temperature") + + # Define target latitude and longitude (where weather stations might be) + target_lon = xr.DataArray([200, 201, 202, 205], dims="points") + target_lat = xr.DataArray([31, 41, 42, 42], dims="points") + + # Retrieve data at the grid cells nearest to the target latitudes and longitudes + da = ds["air"].sel(lon=target_lon, lat=target_lat, method="nearest") + da + +.. tip:: + + If you are lazily loading your data from disk, not every form of vectorized + indexing is supported (or if supported, may not be supported efficiently). + You may find increased performance by loading your data into memory first, + e.g., with :py:meth:`~xarray.Dataset.load`. + +.. note:: + + If an indexer is a :py:meth:`~xarray.DataArray`, its coordinates should not + conflict with the selected subpart of the target array (except for the + explicitly indexed dimensions with ``.loc``/``.sel``). + Otherwise, ``IndexError`` will be raised. + + +.. _assigning_values: + +Assigning values with indexing +------------------------------ + +To select and assign values to a portion of a :py:meth:`~xarray.DataArray` you +can use indexing with ``.loc`` : + +.. ipython:: python + + ds = xr.tutorial.open_dataset("air_temperature") + + # add an empty 2D dataarray + ds["empty"] = xr.full_like(ds.air.mean("time"), fill_value=0) + + # modify one grid point using loc() + ds["empty"].loc[dict(lon=260, lat=30)] = 100 + + # modify a 2D region using loc() + lc = ds.coords["lon"] + la = ds.coords["lat"] + ds["empty"].loc[ + dict(lon=lc[(lc > 220) & (lc < 260)], lat=la[(la > 20) & (la < 60)]) + ] = 100 + +or :py:meth:`~xarray.where`: + +.. ipython:: python + + # modify one grid point using xr.where() + ds["empty"] = xr.where( + (ds.coords["lat"] == 20) & (ds.coords["lon"] == 260), 100, ds["empty"] + ) + + # or modify a 2D region using xr.where() + mask = ( + (ds.coords["lat"] > 20) + & (ds.coords["lat"] < 60) + & (ds.coords["lon"] > 220) + & (ds.coords["lon"] < 260) + ) + ds["empty"] = xr.where(mask, 100, ds["empty"]) + + + +Vectorized indexing can also be used to assign values to xarray object. + +.. ipython:: python + + da = xr.DataArray( + np.arange(12).reshape((3, 4)), + dims=["x", "y"], + coords={"x": [0, 1, 2], "y": ["a", "b", "c", "d"]}, + ) + da + da[0] = -1 # assignment with broadcasting + da + + ind_x = xr.DataArray([0, 1], dims=["x"]) + ind_y = xr.DataArray([0, 1], dims=["y"]) + da[ind_x, ind_y] = -2 # assign -2 to (ix, iy) = (0, 0) and (1, 1) + da + + da[ind_x, ind_y] += 100 # increment is also possible + da + +Like ``numpy.ndarray``, value assignment sometimes works differently from what one may expect. + +.. ipython:: python + + da = xr.DataArray([0, 1, 2, 3], dims=["x"]) + ind = xr.DataArray([0, 0, 0], dims=["x"]) + da[ind] -= 1 + da + +Where the 0th element will be subtracted 1 only once. +This is because ``v[0] = v[0] - 1`` is called three times, rather than +``v[0] = v[0] - 1 - 1 - 1``. +See `Assigning values to indexed arrays`__ for the details. + +__ https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-indexed-arrays + + +.. note:: + Dask array does not support value assignment + (see :ref:`dask` for the details). + +.. note:: + + Coordinates in both the left- and right-hand-side arrays should not + conflict with each other. + Otherwise, ``IndexError`` will be raised. + +.. warning:: + + Do not try to assign values when using any of the indexing methods ``isel`` + or ``sel``:: + + # DO NOT do this + da.isel(space=0) = 0 + + Instead, values can be assigned using dictionary-based indexing:: + + da[dict(space=0)] = 0 + + Assigning values with the chained indexing using ``.sel`` or ``.isel`` fails silently. + + .. ipython:: python + + da = xr.DataArray([0, 1, 2, 3], dims=["x"]) + # DO NOT do this + da.isel(x=[0, 1, 2])[1] = -1 + da + +You can also assign values to all variables of a :py:class:`Dataset` at once: + +.. ipython:: python + :okwarning: + + ds_org = xr.tutorial.open_dataset("eraint_uvz").isel( + latitude=slice(56, 59), longitude=slice(255, 258), level=0 + ) + # set all values to 0 + ds = xr.zeros_like(ds_org) + ds + + # by integer + ds[dict(latitude=2, longitude=2)] = 1 + ds["u"] + ds["v"] + + # by label + ds.loc[dict(latitude=47.25, longitude=[11.25, 12])] = 100 + ds["u"] + + # dataset as new values + new_dat = ds_org.loc[dict(latitude=48, longitude=[11.25, 12])] + new_dat + ds.loc[dict(latitude=47.25, longitude=[11.25, 12])] = new_dat + ds["u"] + +The dimensions can differ between the variables in the dataset, but all variables need to have at least the dimensions specified in the indexer dictionary. +The new values must be either a scalar, a :py:class:`DataArray` or a :py:class:`Dataset` itself that contains all variables that also appear in the dataset to be modified. + +.. _more_advanced_indexing: + +More advanced indexing +----------------------- + +The use of :py:meth:`~xarray.DataArray` objects as indexers enables very +flexible indexing. The following is an example of the pointwise indexing: + +.. ipython:: python + + da = xr.DataArray(np.arange(56).reshape((7, 8)), dims=["x", "y"]) + da + da.isel(x=xr.DataArray([0, 1, 6], dims="z"), y=xr.DataArray([0, 1, 0], dims="z")) + + +where three elements at ``(ix, iy) = ((0, 0), (1, 1), (6, 0))`` are selected +and mapped along a new dimension ``z``. + +If you want to add a coordinate to the new dimension ``z``, +you can supply a :py:class:`~xarray.DataArray` with a coordinate, + +.. ipython:: python + + da.isel( + x=xr.DataArray([0, 1, 6], dims="z", coords={"z": ["a", "b", "c"]}), + y=xr.DataArray([0, 1, 0], dims="z"), + ) + +Analogously, label-based pointwise-indexing is also possible by the ``.sel`` +method: + +.. ipython:: python + + da = xr.DataArray( + np.random.rand(4, 3), + [ + ("time", pd.date_range("2000-01-01", periods=4)), + ("space", ["IA", "IL", "IN"]), + ], + ) + times = xr.DataArray( + pd.to_datetime(["2000-01-03", "2000-01-02", "2000-01-01"]), dims="new_time" + ) + da.sel(space=xr.DataArray(["IA", "IL", "IN"], dims=["new_time"]), time=times) + +.. _align and reindex: + +Align and reindex +----------------- + +Xarray's ``reindex``, ``reindex_like`` and ``align`` impose a ``DataArray`` or +``Dataset`` onto a new set of coordinates corresponding to dimensions. The +original values are subset to the index labels still found in the new labels, +and values corresponding to new labels not found in the original object are +in-filled with `NaN`. + +Xarray operations that combine multiple objects generally automatically align +their arguments to share the same indexes. However, manual alignment can be +useful for greater control and for increased performance. + +To reindex a particular dimension, use :py:meth:`~xarray.DataArray.reindex`: + +.. ipython:: python + + da.reindex(space=["IA", "CA"]) + +The :py:meth:`~xarray.DataArray.reindex_like` method is a useful shortcut. +To demonstrate, we will make a subset DataArray with new values: + +.. ipython:: python + + foo = da.rename("foo") + baz = (10 * da[:2, :2]).rename("baz") + baz + +Reindexing ``foo`` with ``baz`` selects out the first two values along each +dimension: + +.. ipython:: python + + foo.reindex_like(baz) + +The opposite operation asks us to reindex to a larger shape, so we fill in +the missing values with `NaN`: + +.. ipython:: python + + baz.reindex_like(foo) + +The :py:func:`~xarray.align` function lets us perform more flexible database-like +``'inner'``, ``'outer'``, ``'left'`` and ``'right'`` joins: + +.. ipython:: python + + xr.align(foo, baz, join="inner") + xr.align(foo, baz, join="outer") + +Both ``reindex_like`` and ``align`` work interchangeably between +:py:class:`~xarray.DataArray` and :py:class:`~xarray.Dataset` objects, and with any number of matching dimension names: + +.. ipython:: python + + ds + ds.reindex_like(baz) + other = xr.DataArray(["a", "b", "c"], dims="other") + # this is a no-op, because there are no shared dimension names + ds.reindex_like(other) + +.. _indexing.missing_coordinates: + +Missing coordinate labels +------------------------- + +Coordinate labels for each dimension are optional (as of xarray v0.9). Label +based indexing with ``.sel`` and ``.loc`` uses standard positional, +integer-based indexing as a fallback for dimensions without a coordinate label: + +.. ipython:: python + + da = xr.DataArray([1, 2, 3], dims="x") + da.sel(x=[0, -1]) + +Alignment between xarray objects where one or both do not have coordinate labels +succeeds only if all dimensions of the same name have the same length. +Otherwise, it raises an informative error: + +.. ipython:: + :verbatim: + + In [62]: xr.align(da, da[:2]) + ValueError: arguments without labels along dimension 'x' cannot be aligned because they have different dimension sizes: {2, 3} + +Underlying Indexes +------------------ + +Xarray uses the :py:class:`pandas.Index` internally to perform indexing +operations. If you need to access the underlying indexes, they are available +through the :py:attr:`~xarray.DataArray.indexes` attribute. + +.. ipython:: python + + da = xr.DataArray( + np.random.rand(4, 3), + [ + ("time", pd.date_range("2000-01-01", periods=4)), + ("space", ["IA", "IL", "IN"]), + ], + ) + da + da.indexes + da.indexes["time"] + +Use :py:meth:`~xarray.DataArray.get_index` to get an index for a dimension, +falling back to a default :py:class:`pandas.RangeIndex` if it has no coordinate +labels: + +.. ipython:: python + + da = xr.DataArray([1, 2, 3], dims="x") + da + da.get_index("x") + + +.. _copies_vs_views: + +Copies vs. Views +---------------- + +Whether array indexing returns a view or a copy of the underlying +data depends on the nature of the labels. + +For positional (integer) +indexing, xarray follows the same `rules`_ as NumPy: + +* Positional indexing with only integers and slices returns a view. +* Positional indexing with arrays or lists returns a copy. + +The rules for label based indexing are more complex: + +* Label-based indexing with only slices returns a view. +* Label-based indexing with arrays returns a copy. +* Label-based indexing with scalars returns a view or a copy, depending + upon if the corresponding positional indexer can be represented as an + integer or a slice object. The exact rules are determined by pandas. + +Whether data is a copy or a view is more predictable in xarray than in pandas, so +unlike pandas, xarray does not produce `SettingWithCopy warnings`_. However, you +should still avoid assignment with chained indexing. + +Note that other operations (such as :py:meth:`~xarray.DataArray.values`) may also return views rather than copies. + +.. _SettingWithCopy warnings: https://pandas.pydata.org/pandas-docs/stable/indexing.html#returning-a-view-versus-a-copy +.. _rules: https://numpy.org/doc/stable/user/basics.copies.html + +.. _multi-level indexing: + +Multi-level indexing +-------------------- + +Just like pandas, advanced indexing on multi-level indexes is possible with +``loc`` and ``sel``. You can slice a multi-index by providing multiple indexers, +i.e., a tuple of slices, labels, list of labels, or any selector allowed by +pandas: + +.. ipython:: python + + midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + mda = xr.DataArray(np.random.rand(6, 3), [("x", midx), ("y", range(3))]) + mda + mda.sel(x=(list("ab"), [0])) + +You can also select multiple elements by providing a list of labels or tuples or +a slice of tuples: + +.. ipython:: python + + mda.sel(x=[("a", 0), ("b", 1)]) + +Additionally, xarray supports dictionaries: + +.. ipython:: python + + mda.sel(x={"one": "a", "two": 0}) + +For convenience, ``sel`` also accepts multi-index levels directly +as keyword arguments: + +.. ipython:: python + + mda.sel(one="a", two=0) + +Note that using ``sel`` it is not possible to mix a dimension +indexer with level indexers for that dimension +(e.g., ``mda.sel(x={'one': 'a'}, two=0)`` will raise a ``ValueError``). + +Like pandas, xarray handles partial selection on multi-index (level drop). +As shown below, it also renames the dimension / coordinate when the +multi-index is reduced to a single index. + +.. ipython:: python + + mda.loc[{"one": "a"}, ...] + +Unlike pandas, xarray does not guess whether you provide index levels or +dimensions when using ``loc`` in some ambiguous cases. For example, for +``mda.loc[{'one': 'a', 'two': 0}]`` and ``mda.loc['a', 0]`` xarray +always interprets ('one', 'two') and ('a', 0) as the names and +labels of the 1st and 2nd dimension, respectively. You must specify all +dimensions or use the ellipsis in the ``loc`` specifier, e.g. in the example +above, ``mda.loc[{'one': 'a', 'two': 0}, :]`` or ``mda.loc[('a', 0), ...]``. + + +.. _indexing.rules: + +Indexing rules +-------------- + +Here we describe the full rules xarray uses for vectorized indexing. Note that +this is for the purposes of explanation: for the sake of efficiency and to +support various backends, the actual implementation is different. + +0. (Only for label based indexing.) Look up positional indexes along each + dimension from the corresponding :py:class:`pandas.Index`. + +1. A full slice object ``:`` is inserted for each dimension without an indexer. + +2. ``slice`` objects are converted into arrays, given by + ``np.arange(*slice.indices(...))``. + +3. Assume dimension names for array indexers without dimensions, such as + ``np.ndarray`` and ``list``, from the dimensions to be indexed along. + For example, ``v.isel(x=[0, 1])`` is understood as + ``v.isel(x=xr.DataArray([0, 1], dims=['x']))``. + +4. For each variable in a ``Dataset`` or ``DataArray`` (the array and its + coordinates): + + a. Broadcast all relevant indexers based on their dimension names + (see :ref:`compute.broadcasting` for full details). + + b. Index the underling array by the broadcast indexers, using NumPy's + advanced indexing rules. + +5. If any indexer DataArray has coordinates and no coordinate with the + same name exists, attach them to the indexed object. + +.. note:: + + Only 1-dimensional boolean arrays can be used as indexers. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/interpolation.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/interpolation.rst new file mode 100644 index 0000000..311e1bf --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/interpolation.rst @@ -0,0 +1,331 @@ +.. _interp: + +Interpolating data +================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +Xarray offers flexible interpolation routines, which have a similar interface +to our :ref:`indexing `. + +.. note:: + + ``interp`` requires `scipy` installed. + + +Scalar and 1-dimensional interpolation +-------------------------------------- + +Interpolating a :py:class:`~xarray.DataArray` works mostly like labeled +indexing of a :py:class:`~xarray.DataArray`, + +.. ipython:: python + + da = xr.DataArray( + np.sin(0.3 * np.arange(12).reshape(4, 3)), + [("time", np.arange(4)), ("space", [0.1, 0.2, 0.3])], + ) + # label lookup + da.sel(time=3) + + # interpolation + da.interp(time=2.5) + + +Similar to the indexing, :py:meth:`~xarray.DataArray.interp` also accepts an +array-like, which gives the interpolated result as an array. + +.. ipython:: python + + # label lookup + da.sel(time=[2, 3]) + + # interpolation + da.interp(time=[2.5, 3.5]) + +To interpolate data with a :py:doc:`numpy.datetime64 ` coordinate you can pass a string. + +.. ipython:: python + + da_dt64 = xr.DataArray( + [1, 3], [("time", pd.date_range("1/1/2000", "1/3/2000", periods=2))] + ) + da_dt64.interp(time="2000-01-02") + +The interpolated data can be merged into the original :py:class:`~xarray.DataArray` +by specifying the time periods required. + +.. ipython:: python + + da_dt64.interp(time=pd.date_range("1/1/2000", "1/3/2000", periods=3)) + +Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also +allowed. See :ref:`CFTimeIndex` for examples. + +.. note:: + + Currently, our interpolation only works for regular grids. + Therefore, similarly to :py:meth:`~xarray.DataArray.sel`, + only 1D coordinates along a dimension can be used as the + original coordinate to be interpolated. + + +Multi-dimensional Interpolation +------------------------------- + +Like :py:meth:`~xarray.DataArray.sel`, :py:meth:`~xarray.DataArray.interp` +accepts multiple coordinates. In this case, multidimensional interpolation +is carried out. + +.. ipython:: python + + # label lookup + da.sel(time=2, space=0.1) + + # interpolation + da.interp(time=2.5, space=0.15) + +Array-like coordinates are also accepted: + +.. ipython:: python + + # label lookup + da.sel(time=[2, 3], space=[0.1, 0.2]) + + # interpolation + da.interp(time=[1.5, 2.5], space=[0.15, 0.25]) + + +:py:meth:`~xarray.DataArray.interp_like` method is a useful shortcut. This +method interpolates an xarray object onto the coordinates of another xarray +object. For example, if we want to compute the difference between +two :py:class:`~xarray.DataArray` s (``da`` and ``other``) staying on slightly +different coordinates, + +.. ipython:: python + + other = xr.DataArray( + np.sin(0.4 * np.arange(9).reshape(3, 3)), + [("time", [0.9, 1.9, 2.9]), ("space", [0.15, 0.25, 0.35])], + ) + +it might be a good idea to first interpolate ``da`` so that it will stay on the +same coordinates of ``other``, and then subtract it. +:py:meth:`~xarray.DataArray.interp_like` can be used for such a case, + +.. ipython:: python + + # interpolate da along other's coordinates + interpolated = da.interp_like(other) + interpolated + +It is now possible to safely compute the difference ``other - interpolated``. + + +Interpolation methods +--------------------- + +We use :py:class:`scipy.interpolate.interp1d` for 1-dimensional interpolation. +For multi-dimensional interpolation, an attempt is first made to decompose the +interpolation in a series of 1-dimensional interpolations, in which case +:py:class:`scipy.interpolate.interp1d` is used. If a decomposition cannot be +made (e.g. with advanced interpolation), :py:func:`scipy.interpolate.interpn` is +used. + +The interpolation method can be specified by the optional ``method`` argument. + +.. ipython:: python + + da = xr.DataArray( + np.sin(np.linspace(0, 2 * np.pi, 10)), + dims="x", + coords={"x": np.linspace(0, 1, 10)}, + ) + + da.plot.line("o", label="original") + da.interp(x=np.linspace(0, 1, 100)).plot.line(label="linear (default)") + da.interp(x=np.linspace(0, 1, 100), method="cubic").plot.line(label="cubic") + @savefig interpolation_sample1.png width=4in + plt.legend() + +Additional keyword arguments can be passed to scipy's functions. + +.. ipython:: python + + # fill 0 for the outside of the original coordinates. + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={"fill_value": 0.0}) + # 1-dimensional extrapolation + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={"fill_value": "extrapolate"}) + # multi-dimensional extrapolation + da = xr.DataArray( + np.sin(0.3 * np.arange(12).reshape(4, 3)), + [("time", np.arange(4)), ("space", [0.1, 0.2, 0.3])], + ) + + da.interp( + time=4, space=np.linspace(-0.1, 0.5, 10), kwargs={"fill_value": "extrapolate"} + ) + + +Advanced Interpolation +---------------------- + +:py:meth:`~xarray.DataArray.interp` accepts :py:class:`~xarray.DataArray` +as similar to :py:meth:`~xarray.DataArray.sel`, which enables us more advanced interpolation. +Based on the dimension of the new coordinate passed to :py:meth:`~xarray.DataArray.interp`, the dimension of the result are determined. + +For example, if you want to interpolate a two dimensional array along a particular dimension, as illustrated below, +you can pass two 1-dimensional :py:class:`~xarray.DataArray` s with +a common dimension as new coordinate. + +.. image:: ../_static/advanced_selection_interpolation.svg + :height: 200px + :width: 400 px + :alt: advanced indexing and interpolation + :align: center + +For example: + +.. ipython:: python + + da = xr.DataArray( + np.sin(0.3 * np.arange(20).reshape(5, 4)), + [("x", np.arange(5)), ("y", [0.1, 0.2, 0.3, 0.4])], + ) + # advanced indexing + x = xr.DataArray([0, 2, 4], dims="z") + y = xr.DataArray([0.1, 0.2, 0.3], dims="z") + da.sel(x=x, y=y) + + # advanced interpolation, without extrapolation + x = xr.DataArray([0.5, 1.5, 2.5, 3.5], dims="z") + y = xr.DataArray([0.15, 0.25, 0.35, 0.45], dims="z") + da.interp(x=x, y=y) + +where values on the original coordinates +``(x, y) = ((0.5, 0.15), (1.5, 0.25), (2.5, 0.35), (3.5, 0.45))`` are obtained +by the 2-dimensional interpolation and mapped along a new dimension ``z``. Since +no keyword arguments are passed to the interpolation routine, no extrapolation +is performed resulting in a ``nan`` value. + +If you want to add a coordinate to the new dimension ``z``, you can supply +:py:class:`~xarray.DataArray` s with a coordinate. Extrapolation can be achieved +by passing additional arguments to SciPy's ``interpnd`` function, + +.. ipython:: python + + x = xr.DataArray([0.5, 1.5, 2.5, 3.5], dims="z", coords={"z": ["a", "b", "c", "d"]}) + y = xr.DataArray( + [0.15, 0.25, 0.35, 0.45], dims="z", coords={"z": ["a", "b", "c", "d"]} + ) + da.interp(x=x, y=y, kwargs={"fill_value": None}) + +For the details of the advanced indexing, +see :ref:`more advanced indexing `. + + +Interpolating arrays with NaN +----------------------------- + +Our :py:meth:`~xarray.DataArray.interp` works with arrays with NaN +the same way that +`scipy.interpolate.interp1d `_ and +`scipy.interpolate.interpn `_ do. +``linear`` and ``nearest`` methods return arrays including NaN, +while other methods such as ``cubic`` or ``quadratic`` return all NaN arrays. + +.. ipython:: python + + da = xr.DataArray([0, 2, np.nan, 3, 3.25], dims="x", coords={"x": range(5)}) + da.interp(x=[0.5, 1.5, 2.5]) + da.interp(x=[0.5, 1.5, 2.5], method="cubic") + +To avoid this, you can drop NaN by :py:meth:`~xarray.DataArray.dropna`, and +then make the interpolation + +.. ipython:: python + + dropped = da.dropna("x") + dropped + dropped.interp(x=[0.5, 1.5, 2.5], method="cubic") + +If NaNs are distributed randomly in your multidimensional array, +dropping all the columns containing more than one NaNs by +:py:meth:`~xarray.DataArray.dropna` may lose a significant amount of information. +In such a case, you can fill NaN by :py:meth:`~xarray.DataArray.interpolate_na`, +which is similar to :py:meth:`pandas.Series.interpolate`. + +.. ipython:: python + + filled = da.interpolate_na(dim="x") + filled + +This fills NaN by interpolating along the specified dimension. +After filling NaNs, you can interpolate: + +.. ipython:: python + + filled.interp(x=[0.5, 1.5, 2.5], method="cubic") + +For the details of :py:meth:`~xarray.DataArray.interpolate_na`, +see :ref:`Missing values `. + + +Example +------- + +Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. + +.. ipython:: python + + # Raw data + ds = xr.tutorial.open_dataset("air_temperature").isel(time=0) + fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) + ds.air.plot(ax=axes[0]) + axes[0].set_title("Raw data") + + # Interpolated data + new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.sizes["lon"] * 4) + new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.sizes["lat"] * 4) + dsi = ds.interp(lat=new_lat, lon=new_lon) + dsi.air.plot(ax=axes[1]) + @savefig interpolation_sample3.png width=8in + axes[1].set_title("Interpolated data") + +Our advanced interpolation can be used to remap the data to the new coordinate. +Consider the new coordinates x and z on the two dimensional plane. +The remapping can be done as follows + +.. ipython:: python + + # new coordinate + x = np.linspace(240, 300, 100) + z = np.linspace(20, 70, 100) + # relation between new and original coordinates + lat = xr.DataArray(z, dims=["z"], coords={"z": z}) + lon = xr.DataArray( + (x[:, np.newaxis] - 270) / np.cos(z * np.pi / 180) + 270, + dims=["x", "z"], + coords={"x": x, "z": z}, + ) + + fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) + ds.air.plot(ax=axes[0]) + # draw the new coordinate on the original coordinates. + for idx in [0, 33, 66, 99]: + axes[0].plot(lon.isel(x=idx), lat, "--k") + for idx in [0, 33, 66, 99]: + axes[0].plot(*xr.broadcast(lon.isel(z=idx), lat.isel(z=idx)), "--k") + axes[0].set_title("Raw data") + + dsi = ds.interp(lon=lon, lat=lat) + dsi.air.plot(ax=axes[1]) + @savefig interpolation_sample4.png width=8in + axes[1].set_title("Remapped data") diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/io.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/io.rst new file mode 100644 index 0000000..fd6b770 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/io.rst @@ -0,0 +1,1319 @@ +.. currentmodule:: xarray +.. _io: + +Reading and writing files +========================= + +Xarray supports direct serialization and IO to several file formats, from +simple :ref:`io.pickle` files to the more flexible :ref:`io.netcdf` +format (recommended). + +.. ipython:: python + :suppress: + + import os + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +.. _io.netcdf: + +netCDF +------ + +The recommended way to store xarray data structures is `netCDF`__, which +is a binary file format for self-described datasets that originated +in the geosciences. Xarray is based on the netCDF data model, so netCDF files +on disk directly correspond to :py:class:`Dataset` objects (more accurately, +a group in a netCDF file directly corresponds to a :py:class:`Dataset` object. +See :ref:`io.netcdf_groups` for more.) + +NetCDF is supported on almost all platforms, and parsers exist +for the vast majority of scientific programming languages. Recent versions of +netCDF are based on the even more widely used HDF5 file-format. + +__ https://www.unidata.ucar.edu/software/netcdf/ + +.. tip:: + + If you aren't familiar with this data format, the `netCDF FAQ`_ is a good + place to start. + +.. _netCDF FAQ: https://www.unidata.ucar.edu/software/netcdf/docs/faq.html#What-Is-netCDF + +Reading and writing netCDF files with xarray requires scipy, h5netcdf, or the +`netCDF4-Python`__ library to be installed. SciPy only supports reading and writing +of netCDF V3 files. + +__ https://github.com/Unidata/netcdf4-python + +We can save a Dataset to disk using the +:py:meth:`Dataset.to_netcdf` method: + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + + ds.to_netcdf("saved_on_disk.nc") + +By default, the file is saved as netCDF4 (assuming netCDF4-Python is +installed). You can control the format and engine used to write the file with +the ``format`` and ``engine`` arguments. + +.. tip:: + + Using the `h5netcdf `_ package + by passing ``engine='h5netcdf'`` to :py:meth:`open_dataset` can + sometimes be quicker than the default ``engine='netcdf4'`` that uses the + `netCDF4 `_ package. + + +We can load netCDF files to create a new Dataset using +:py:func:`open_dataset`: + +.. ipython:: python + + ds_disk = xr.open_dataset("saved_on_disk.nc") + ds_disk + +.. ipython:: python + :suppress: + + # Close "saved_on_disk.nc", but retain the file until after closing or deleting other + # datasets that will refer to it. + ds_disk.close() + +Similarly, a DataArray can be saved to disk using the +:py:meth:`DataArray.to_netcdf` method, and loaded +from disk using the :py:func:`open_dataarray` function. As netCDF files +correspond to :py:class:`Dataset` objects, these functions internally +convert the ``DataArray`` to a ``Dataset`` before saving, and then convert back +when loading, ensuring that the ``DataArray`` that is loaded is always exactly +the same as the one that was saved. + +A dataset can also be loaded or written to a specific group within a netCDF +file. To load from a group, pass a ``group`` keyword argument to the +``open_dataset`` function. The group can be specified as a path-like +string, e.g., to access subgroup 'bar' within group 'foo' pass +'/foo/bar' as the ``group`` argument. When writing multiple groups in one file, +pass ``mode='a'`` to ``to_netcdf`` to ensure that each call does not delete the +file. + +Data is *always* loaded lazily from netCDF files. You can manipulate, slice and subset +Dataset and DataArray objects, and no array values are loaded into memory until +you try to perform some sort of actual computation. For an example of how these +lazy arrays work, see the OPeNDAP section below. + +There may be minor differences in the :py:class:`Dataset` object returned +when reading a NetCDF file with different engines. + +It is important to note that when you modify values of a Dataset, even one +linked to files on disk, only the in-memory copy you are manipulating in xarray +is modified: the original file on disk is never touched. + +.. tip:: + + Xarray's lazy loading of remote or on-disk datasets is often but not always + desirable. Before performing computationally intense operations, it is + often a good idea to load a Dataset (or DataArray) entirely into memory by + invoking the :py:meth:`Dataset.load` method. + +Datasets have a :py:meth:`Dataset.close` method to close the associated +netCDF file. However, it's often cleaner to use a ``with`` statement: + +.. ipython:: python + + # this automatically closes the dataset after use + with xr.open_dataset("saved_on_disk.nc") as ds: + print(ds.keys()) + +Although xarray provides reasonable support for incremental reads of files on +disk, it does not support incremental writes, which can be a useful strategy +for dealing with datasets too big to fit into memory. Instead, xarray integrates +with dask.array (see :ref:`dask`), which provides a fully featured engine for +streaming computation. + +It is possible to append or overwrite netCDF variables using the ``mode='a'`` +argument. When using this option, all variables in the dataset will be written +to the original netCDF file, regardless if they exist in the original dataset. + + +.. _io.netcdf_groups: + +Groups +~~~~~~ + +NetCDF groups are not supported as part of the :py:class:`Dataset` data model. +Instead, groups can be loaded individually as Dataset objects. +To do so, pass a ``group`` keyword argument to the +:py:func:`open_dataset` function. The group can be specified as a path-like +string, e.g., to access subgroup ``'bar'`` within group ``'foo'`` pass +``'/foo/bar'`` as the ``group`` argument. + +In a similar way, the ``group`` keyword argument can be given to the +:py:meth:`Dataset.to_netcdf` method to write to a group +in a netCDF file. +When writing multiple groups in one file, pass ``mode='a'`` to +:py:meth:`Dataset.to_netcdf` to ensure that each call does not delete the file. +For example: + +.. ipython:: + :verbatim: + + In [1]: ds1 = xr.Dataset({"a": 0}) + + In [2]: ds2 = xr.Dataset({"b": 1}) + + In [3]: ds1.to_netcdf("file.nc", group="A") + + In [4]: ds2.to_netcdf("file.nc", group="B", mode="a") + +We can verify that two groups have been saved using the ncdump command-line utility. + +.. code:: bash + + $ ncdump file.nc + netcdf file { + + group: A { + variables: + int64 a ; + data: + + a = 0 ; + } // group A + + group: B { + variables: + int64 b ; + data: + + b = 1 ; + } // group B + } + +Either of these groups can be loaded from the file as an independent :py:class:`Dataset` object: + +.. ipython:: + :verbatim: + + In [1]: group1 = xr.open_dataset("file.nc", group="A") + + In [2]: group1 + Out[2]: + + Dimensions: () + Data variables: + a int64 ... + + In [3]: group2 = xr.open_dataset("file.nc", group="B") + + In [4]: group2 + Out[4]: + + Dimensions: () + Data variables: + b int64 ... + +.. note:: + + For native handling of multiple groups with xarray, including I/O, you might be interested in the experimental + `xarray-datatree `_ package. + + +.. _io.encoding: + +Reading encoded data +~~~~~~~~~~~~~~~~~~~~ + +NetCDF files follow some conventions for encoding datetime arrays (as numbers +with a "units" attribute) and for packing and unpacking data (as +described by the "scale_factor" and "add_offset" attributes). If the argument +``decode_cf=True`` (default) is given to :py:func:`open_dataset`, xarray will attempt +to automatically decode the values in the netCDF objects according to +`CF conventions`_. Sometimes this will fail, for example, if a variable +has an invalid "units" or "calendar" attribute. For these cases, you can +turn this decoding off manually. + +.. _CF conventions: http://cfconventions.org/ + +You can view this encoding information (among others) in the +:py:attr:`DataArray.encoding` and +:py:attr:`DataArray.encoding` attributes: + +.. ipython:: python + + ds_disk["y"].encoding + ds_disk.encoding + +Note that all operations that manipulate variables other than indexing +will remove encoding information. + +In some cases it is useful to intentionally reset a dataset's original encoding values. +This can be done with either the :py:meth:`Dataset.drop_encoding` or +:py:meth:`DataArray.drop_encoding` methods. + +.. ipython:: python + + ds_no_encoding = ds_disk.drop_encoding() + ds_no_encoding.encoding + +.. _combining multiple files: + +Reading multi-file datasets +........................... + +NetCDF files are often encountered in collections, e.g., with different files +corresponding to different model runs or one file per timestamp. +Xarray can straightforwardly combine such files into a single Dataset by making use of +:py:func:`concat`, :py:func:`merge`, :py:func:`combine_nested` and +:py:func:`combine_by_coords`. For details on the difference between these +functions see :ref:`combining data`. + +Xarray includes support for manipulating datasets that don't fit into memory +with dask_. If you have dask installed, you can open multiple files +simultaneously in parallel using :py:func:`open_mfdataset`:: + + xr.open_mfdataset('my/files/*.nc', parallel=True) + +This function automatically concatenates and merges multiple files into a +single xarray dataset. +It is the recommended way to open multiple files with xarray. +For more details on parallel reading, see :ref:`combining.multi`, :ref:`dask.io` and a +`blog post`_ by Stephan Hoyer. +:py:func:`open_mfdataset` takes many kwargs that allow you to +control its behaviour (for e.g. ``parallel``, ``combine``, ``compat``, ``join``, ``concat_dim``). +See its docstring for more details. + + +.. note:: + + A common use-case involves a dataset distributed across a large number of files with + each file containing a large number of variables. Commonly, a few of these variables + need to be concatenated along a dimension (say ``"time"``), while the rest are equal + across the datasets (ignoring floating point differences). The following command + with suitable modifications (such as ``parallel=True``) works well with such datasets:: + + xr.open_mfdataset('my/files/*.nc', concat_dim="time", combine="nested", + data_vars='minimal', coords='minimal', compat='override') + + This command concatenates variables along the ``"time"`` dimension, but only those that + already contain the ``"time"`` dimension (``data_vars='minimal', coords='minimal'``). + Variables that lack the ``"time"`` dimension are taken from the first dataset + (``compat='override'``). + + +.. _dask: http://dask.org +.. _blog post: http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ + +Sometimes multi-file datasets are not conveniently organized for easy use of :py:func:`open_mfdataset`. +One can use the ``preprocess`` argument to provide a function that takes a dataset +and returns a modified Dataset. +:py:func:`open_mfdataset` will call ``preprocess`` on every dataset +(corresponding to each file) prior to combining them. + + +If :py:func:`open_mfdataset` does not meet your needs, other approaches are possible. +The general pattern for parallel reading of multiple files +using dask, modifying those datasets and then combining into a single ``Dataset`` is:: + + def modify(ds): + # modify ds here + return ds + + + # this is basically what open_mfdataset does + open_kwargs = dict(decode_cf=True, decode_times=False) + open_tasks = [dask.delayed(xr.open_dataset)(f, **open_kwargs) for f in file_names] + tasks = [dask.delayed(modify)(task) for task in open_tasks] + datasets = dask.compute(tasks) # get a list of xarray.Datasets + combined = xr.combine_nested(datasets) # or some combination of concat, merge + + +As an example, here's how we could approximate ``MFDataset`` from the netCDF4 +library:: + + from glob import glob + import xarray as xr + + def read_netcdfs(files, dim): + # glob expands paths with * to a list of files, like the unix shell + paths = sorted(glob(files)) + datasets = [xr.open_dataset(p) for p in paths] + combined = xr.concat(datasets, dim) + return combined + + combined = read_netcdfs('/all/my/files/*.nc', dim='time') + +This function will work in many cases, but it's not very robust. First, it +never closes files, which means it will fail if you need to load more than +a few thousand files. Second, it assumes that you want all the data from each +file and that it can all fit into memory. In many situations, you only need +a small subset or an aggregated summary of the data from each file. + +Here's a slightly more sophisticated example of how to remedy these +deficiencies:: + + def read_netcdfs(files, dim, transform_func=None): + def process_one_path(path): + # use a context manager, to ensure the file gets closed after use + with xr.open_dataset(path) as ds: + # transform_func should do some sort of selection or + # aggregation + if transform_func is not None: + ds = transform_func(ds) + # load all data from the transformed dataset, to ensure we can + # use it after closing each original file + ds.load() + return ds + + paths = sorted(glob(files)) + datasets = [process_one_path(p) for p in paths] + combined = xr.concat(datasets, dim) + return combined + + # here we suppose we only care about the combined mean of each file; + # you might also use indexing operations like .sel to subset datasets + combined = read_netcdfs('/all/my/files/*.nc', dim='time', + transform_func=lambda ds: ds.mean()) + +This pattern works well and is very robust. We've used similar code to process +tens of thousands of files constituting 100s of GB of data. + + +.. _io.netcdf.writing_encoded: + +Writing encoded data +~~~~~~~~~~~~~~~~~~~~ + +Conversely, you can customize how xarray writes netCDF files on disk by +providing explicit encodings for each dataset variable. The ``encoding`` +argument takes a dictionary with variable names as keys and variable specific +encodings as values. These encodings are saved as attributes on the netCDF +variables on disk, which allows xarray to faithfully read encoded data back into +memory. + +It is important to note that using encodings is entirely optional: if you do not +supply any of these encoding options, xarray will write data to disk using a +default encoding, or the options in the ``encoding`` attribute, if set. +This works perfectly fine in most cases, but encoding can be useful for +additional control, especially for enabling compression. + +In the file on disk, these encodings are saved as attributes on each variable, which +allow xarray and other CF-compliant tools for working with netCDF files to correctly +read the data. + +Scaling and type conversions +............................ + +These encoding options (based on `CF Conventions on packed data`_) work on any +version of the netCDF file format: + +- ``dtype``: Any valid NumPy dtype or string convertible to a dtype, e.g., ``'int16'`` + or ``'float32'``. This controls the type of the data written on disk. +- ``_FillValue``: Values of ``NaN`` in xarray variables are remapped to this value when + saved on disk. This is important when converting floating point with missing values + to integers on disk, because ``NaN`` is not a valid value for integer dtypes. By + default, variables with float types are attributed a ``_FillValue`` of ``NaN`` in the + output file, unless explicitly disabled with an encoding ``{'_FillValue': None}``. +- ``scale_factor`` and ``add_offset``: Used to convert from encoded data on disk to + to the decoded data in memory, according to the formula + ``decoded = scale_factor * encoded + add_offset``. Please note that ``scale_factor`` + and ``add_offset`` must be of same type and determine the type of the decoded data. + +These parameters can be fruitfully combined to compress discretized data on disk. For +example, to save the variable ``foo`` with a precision of 0.1 in 16-bit integers while +converting ``NaN`` to ``-9999``, we would use +``encoding={'foo': {'dtype': 'int16', 'scale_factor': 0.1, '_FillValue': -9999}}``. +Compression and decompression with such discretization is extremely fast. + +.. _CF Conventions on packed data: https://cfconventions.org/cf-conventions/cf-conventions.html#packed-data + +.. _io.string-encoding: + +String encoding +............... + +Xarray can write unicode strings to netCDF files in two ways: + +- As variable length strings. This is only supported on netCDF4 (HDF5) files. +- By encoding strings into bytes, and writing encoded bytes as a character + array. The default encoding is UTF-8. + +By default, we use variable length strings for compatible files and fall-back +to using encoded character arrays. Character arrays can be selected even for +netCDF4 files by setting the ``dtype`` field in ``encoding`` to ``S1`` +(corresponding to NumPy's single-character bytes dtype). + +If character arrays are used: + +- The string encoding that was used is stored on + disk in the ``_Encoding`` attribute, which matches an ad-hoc convention + `adopted by the netCDF4-Python library `_. + At the time of this writing (October 2017), a standard convention for indicating + string encoding for character arrays in netCDF files was + `still under discussion `_. + Technically, you can use + `any string encoding recognized by Python `_ if you feel the need to deviate from UTF-8, + by setting the ``_Encoding`` field in ``encoding``. But + `we don't recommend it `_. +- The character dimension name can be specified by the ``char_dim_name`` field of a variable's + ``encoding``. If the name of the character dimension is not specified, the default is + ``f'string{data.shape[-1]}'``. When decoding character arrays from existing files, the + ``char_dim_name`` is added to the variables ``encoding`` to preserve if encoding happens, but + the field can be edited by the user. + +.. warning:: + + Missing values in bytes or unicode string arrays (represented by ``NaN`` in + xarray) are currently written to disk as empty strings ``''``. This means + missing values will not be restored when data is loaded from disk. + This behavior is likely to change in the future (:issue:`1647`). + Unfortunately, explicitly setting a ``_FillValue`` for string arrays to handle + missing values doesn't work yet either, though we also hope to fix this in the + future. + +Chunk based compression +....................... + +``zlib``, ``complevel``, ``fletcher32``, ``contiguous`` and ``chunksizes`` +can be used for enabling netCDF4/HDF5's chunk based compression, as described +in the `documentation for createVariable`_ for netCDF4-Python. This only works +for netCDF4 files and thus requires using ``format='netCDF4'`` and either +``engine='netcdf4'`` or ``engine='h5netcdf'``. + +.. _documentation for createVariable: https://unidata.github.io/netcdf4-python/#netCDF4.Dataset.createVariable + +Chunk based gzip compression can yield impressive space savings, especially +for sparse data, but it comes with significant performance overhead. HDF5 +libraries can only read complete chunks back into memory, and maximum +decompression speed is in the range of 50-100 MB/s. Worse, HDF5's compression +and decompression currently cannot be parallelized with dask. For these reasons, we +recommend trying discretization based compression (described above) first. + +Time units +.......... + +The ``units`` and ``calendar`` attributes control how xarray serializes ``datetime64`` and +``timedelta64`` arrays to datasets on disk as numeric values. The ``units`` encoding +should be a string like ``'days since 1900-01-01'`` for ``datetime64`` data or a string +like ``'days'`` for ``timedelta64`` data. ``calendar`` should be one of the calendar types +supported by netCDF4-python: 'standard', 'gregorian', 'proleptic_gregorian' 'noleap', +'365_day', '360_day', 'julian', 'all_leap', '366_day'. + +By default, xarray uses the ``'proleptic_gregorian'`` calendar and units of the smallest time +difference between values, with a reference time of the first time value. + + +.. _io.coordinates: + +Coordinates +........... + +You can control the ``coordinates`` attribute written to disk by specifying ``DataArray.encoding["coordinates"]``. +If not specified, xarray automatically sets ``DataArray.encoding["coordinates"]`` to a space-delimited list +of names of coordinate variables that share dimensions with the ``DataArray`` being written. +This allows perfect roundtripping of xarray datasets but may not be desirable. +When an xarray ``Dataset`` contains non-dimensional coordinates that do not share dimensions with any of +the variables, these coordinate variable names are saved under a "global" ``"coordinates"`` attribute. +This is not CF-compliant but again facilitates roundtripping of xarray datasets. + +Invalid netCDF files +~~~~~~~~~~~~~~~~~~~~ + +The library ``h5netcdf`` allows writing some dtypes (booleans, complex, ...) that aren't +allowed in netCDF4 (see +`h5netcdf documentation `_). +This feature is available through :py:meth:`DataArray.to_netcdf` and +:py:meth:`Dataset.to_netcdf` when used with ``engine="h5netcdf"`` +and currently raises a warning unless ``invalid_netcdf=True`` is set: + +.. ipython:: python + :okwarning: + + # Writing complex valued data + da = xr.DataArray([1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j]) + da.to_netcdf("complex.nc", engine="h5netcdf", invalid_netcdf=True) + + # Reading it back + reopened = xr.open_dataarray("complex.nc", engine="h5netcdf") + reopened + +.. ipython:: python + :suppress: + + reopened.close() + os.remove("complex.nc") + +.. warning:: + + Note that this produces a file that is likely to be not readable by other netCDF + libraries! + +.. _io.hdf5: + +HDF5 +---- +`HDF5`_ is both a file format and a data model for storing information. HDF5 stores +data hierarchically, using groups to create a nested structure. HDF5 is a more +general version of the netCDF4 data model, so the nested structure is one of many +similarities between the two data formats. + +Reading HDF5 files in xarray requires the ``h5netcdf`` engine, which can be installed +with ``conda install h5netcdf``. Once installed we can use xarray to open HDF5 files: + +.. code:: python + + xr.open_dataset("/path/to/my/file.h5") + +The similarities between HDF5 and netCDF4 mean that HDF5 data can be written with the +same :py:meth:`Dataset.to_netcdf` method as used for netCDF4 data: + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + + ds.to_netcdf("saved_on_disk.h5") + +Groups +~~~~~~ + +If you have multiple or highly nested groups, xarray by default may not read the group +that you want. A particular group of an HDF5 file can be specified using the ``group`` +argument: + +.. code:: python + + xr.open_dataset("/path/to/my/file.h5", group="/my/group") + +While xarray cannot interrogate an HDF5 file to determine which groups are available, +the HDF5 Python reader `h5py`_ can be used instead. + +Natively the xarray data structures can only handle one level of nesting, organized as +DataArrays inside of Datasets. If your HDF5 file has additional levels of hierarchy you +can only access one group and a time and will need to specify group names. + +.. note:: + + For native handling of multiple HDF5 groups with xarray, including I/O, you might be + interested in the experimental + `xarray-datatree `_ package. + + +.. _HDF5: https://hdfgroup.github.io/hdf5/index.html +.. _h5py: https://www.h5py.org/ + + +.. _io.zarr: + +Zarr +---- + +`Zarr`_ is a Python package that provides an implementation of chunked, compressed, +N-dimensional arrays. +Zarr has the ability to store arrays in a range of ways, including in memory, +in files, and in cloud-based object storage such as `Amazon S3`_ and +`Google Cloud Storage`_. +Xarray's Zarr backend allows xarray to leverage these capabilities, including +the ability to store and analyze datasets far too large fit onto disk +(particularly :ref:`in combination with dask `). + +Xarray can't open just any zarr dataset, because xarray requires special +metadata (attributes) describing the dataset dimensions and coordinates. +At this time, xarray can only open zarr datasets with these special attributes, +such as zarr datasets written by xarray, +`netCDF `_, +or `GDAL `_. +For implementation details, see :ref:`zarr_encoding`. + +To write a dataset with zarr, we use the :py:meth:`Dataset.to_zarr` method. + +To write to a local directory, we pass a path to a directory: + +.. ipython:: python + :suppress: + + ! rm -rf path/to/directory.zarr + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + ds.to_zarr("path/to/directory.zarr") + +(The suffix ``.zarr`` is optional--just a reminder that a zarr store lives +there.) If the directory does not exist, it will be created. If a zarr +store is already present at that path, an error will be raised, preventing it +from being overwritten. To override this behavior and overwrite an existing +store, add ``mode='w'`` when invoking :py:meth:`~Dataset.to_zarr`. + +DataArrays can also be saved to disk using the :py:meth:`DataArray.to_zarr` method, +and loaded from disk using the :py:func:`open_dataarray` function with `engine='zarr'`. +Similar to :py:meth:`DataArray.to_netcdf`, :py:meth:`DataArray.to_zarr` will +convert the ``DataArray`` to a ``Dataset`` before saving, and then convert back +when loading, ensuring that the ``DataArray`` that is loaded is always exactly +the same as the one that was saved. + +.. note:: + + xarray does not write `NCZarr `_ attributes. + Therefore, NCZarr data must be opened in read-only mode. + +To store variable length strings, convert them to object arrays first with +``dtype=object``. + +To read back a zarr dataset that has been created this way, we use the +:py:func:`open_zarr` method: + +.. ipython:: python + + ds_zarr = xr.open_zarr("path/to/directory.zarr") + ds_zarr + +Cloud Storage Buckets +~~~~~~~~~~~~~~~~~~~~~ + +It is possible to read and write xarray datasets directly from / to cloud +storage buckets using zarr. This example uses the `gcsfs`_ package to provide +an interface to `Google Cloud Storage`_. + +General `fsspec`_ URLs, those that begin with ``s3://`` or ``gcs://`` for example, +are parsed and the store set up for you automatically when reading. +You should include any arguments to the storage backend as the +key ```storage_options``, part of ``backend_kwargs``. + +.. code:: python + + ds_gcs = xr.open_dataset( + "gcs:///path.zarr", + backend_kwargs={ + "storage_options": {"project": "", "token": None} + }, + engine="zarr", + ) + + +This also works with ``open_mfdataset``, allowing you to pass a list of paths or +a URL to be interpreted as a glob string. + +For writing, you must explicitly set up a ``MutableMapping`` +instance and pass this, as follows: + +.. code:: python + + import gcsfs + + fs = gcsfs.GCSFileSystem(project="", token=None) + gcsmap = gcsfs.mapping.GCSMap("", gcs=fs, check=True, create=False) + # write to the bucket + ds.to_zarr(store=gcsmap) + # read it back + ds_gcs = xr.open_zarr(gcsmap) + +(or use the utility function ``fsspec.get_mapper()``). + +.. _fsspec: https://filesystem-spec.readthedocs.io/en/latest/ +.. _Zarr: https://zarr.readthedocs.io/ +.. _Amazon S3: https://aws.amazon.com/s3/ +.. _Google Cloud Storage: https://cloud.google.com/storage/ +.. _gcsfs: https://github.com/fsspec/gcsfs + +Zarr Compressors and Filters +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are many different `options for compression and filtering possible with +zarr `_. + +These options can be passed to the ``to_zarr`` method as variable encoding. +For example: + +.. ipython:: python + :suppress: + + ! rm -rf foo.zarr + +.. ipython:: python + + import zarr + + compressor = zarr.Blosc(cname="zstd", clevel=3, shuffle=2) + ds.to_zarr("foo.zarr", encoding={"foo": {"compressor": compressor}}) + +.. note:: + + Not all native zarr compression and filtering options have been tested with + xarray. + +.. _io.zarr.consolidated_metadata: + +Consolidated Metadata +~~~~~~~~~~~~~~~~~~~~~ + +Xarray needs to read all of the zarr metadata when it opens a dataset. +In some storage mediums, such as with cloud object storage (e.g. `Amazon S3`_), +this can introduce significant overhead, because two separate HTTP calls to the +object store must be made for each variable in the dataset. +By default Xarray uses a feature called +*consolidated metadata*, storing all metadata for the entire dataset with a +single key (by default called ``.zmetadata``). This typically drastically speeds +up opening the store. (For more information on this feature, consult the +`zarr docs on consolidating metadata `_.) + +By default, xarray writes consolidated metadata and attempts to read stores +with consolidated metadata, falling back to use non-consolidated metadata for +reads. Because this fall-back option is so much slower, xarray issues a +``RuntimeWarning`` with guidance when reading with consolidated metadata fails: + + Failed to open Zarr store with consolidated metadata, falling back to try + reading non-consolidated metadata. This is typically much slower for + opening a dataset. To silence this warning, consider: + + 1. Consolidating metadata in this existing store with + :py:func:`zarr.consolidate_metadata`. + 2. Explicitly setting ``consolidated=False``, to avoid trying to read + consolidate metadata. + 3. Explicitly setting ``consolidated=True``, to raise an error in this case + instead of falling back to try reading non-consolidated metadata. + +.. _io.zarr.appending: + +Modifying existing Zarr stores +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray supports several ways of incrementally writing variables to a Zarr +store. These options are useful for scenarios when it is infeasible or +undesirable to write your entire dataset at once. + +1. Use ``mode='a'`` to add or overwrite entire variables, +2. Use ``append_dim`` to resize and append to existing variables, and +3. Use ``region`` to write to limited regions of existing arrays. + +.. tip:: + + For ``Dataset`` objects containing dask arrays, a + single call to ``to_zarr()`` will write all of your data in parallel. + +.. warning:: + + Alignment of coordinates is currently not checked when modifying an + existing Zarr store. It is up to the user to ensure that coordinates are + consistent. + +To add or overwrite entire variables, simply call :py:meth:`~Dataset.to_zarr` +with ``mode='a'`` on a Dataset containing the new variables, passing in an +existing Zarr store or path to a Zarr store. + +To resize and then append values along an existing dimension in a store, set +``append_dim``. This is a good option if data always arrives in a particular +order, e.g., for time-stepping a simulation: + +.. ipython:: python + :suppress: + + ! rm -rf path/to/directory.zarr + +.. ipython:: python + + ds1 = xr.Dataset( + {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, + coords={ + "x": [10, 20, 30, 40], + "y": [1, 2, 3, 4, 5], + "t": pd.date_range("2001-01-01", periods=2), + }, + ) + ds1.to_zarr("path/to/directory.zarr") + ds2 = xr.Dataset( + {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, + coords={ + "x": [10, 20, 30, 40], + "y": [1, 2, 3, 4, 5], + "t": pd.date_range("2001-01-03", periods=2), + }, + ) + ds2.to_zarr("path/to/directory.zarr", append_dim="t") + +Finally, you can use ``region`` to write to limited regions of existing arrays +in an existing Zarr store. This is a good option for writing data in parallel +from independent processes. + +To scale this up to writing large datasets, the first step is creating an +initial Zarr store without writing all of its array data. This can be done by +first creating a ``Dataset`` with dummy values stored in :ref:`dask `, +and then calling ``to_zarr`` with ``compute=False`` to write only metadata +(including ``attrs``) to Zarr: + +.. ipython:: python + :suppress: + + ! rm -rf path/to/directory.zarr + +.. ipython:: python + + import dask.array + + # The values of this dask array are entirely irrelevant; only the dtype, + # shape and chunks are used + dummies = dask.array.zeros(30, chunks=10) + ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)}) + path = "path/to/directory.zarr" + # Now we write the metadata without computing any array values + ds.to_zarr(path, compute=False) + +Now, a Zarr store with the correct variable shapes and attributes exists that +can be filled out by subsequent calls to ``to_zarr``. +Setting ``region="auto"`` will open the existing store and determine the +correct alignment of the new data with the existing coordinates, or as an +explicit mapping from dimension names to Python ``slice`` objects indicating +where the data should be written (in index space, not label space), e.g., + +.. ipython:: python + + # For convenience, we'll slice a single dataset, but in the real use-case + # we would create them separately possibly even from separate processes. + ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)}) + # Any of the following region specifications are valid + ds.isel(x=slice(0, 10)).to_zarr(path, region="auto") + ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"}) + ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)}) + +Concurrent writes with ``region`` are safe as long as they modify distinct +chunks in the underlying Zarr arrays (or use an appropriate ``lock``). + +As a safety check to make it harder to inadvertently override existing values, +if you set ``region`` then *all* variables included in a Dataset must have +dimensions included in ``region``. Other variables (typically coordinates) +need to be explicitly dropped and/or written in a separate calls to ``to_zarr`` +with ``mode='a'``. + +.. _io.zarr.writing_chunks: + +Specifying chunks in a zarr store +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Chunk sizes may be specified in one of three ways when writing to a zarr store: + +1. Manual chunk sizing through the use of the ``encoding`` argument in :py:meth:`Dataset.to_zarr`: +2. Automatic chunking based on chunks in dask arrays +3. Default chunk behavior determined by the zarr library + +The resulting chunks will be determined based on the order of the above list; dask +chunks will be overridden by manually-specified chunks in the encoding argument, +and the presence of either dask chunks or chunks in the ``encoding`` attribute will +supersede the default chunking heuristics in zarr. + +Importantly, this logic applies to every array in the zarr store individually, +including coordinate arrays. Therefore, if a dataset contains one or more dask +arrays, it may still be desirable to specify a chunk size for the coordinate arrays +(for example, with a chunk size of `-1` to include the full coordinate). + +To specify chunks manually using the ``encoding`` argument, provide a nested +dictionary with the structure ``{'variable_or_coord_name': {'chunks': chunks_tuple}}``. + +.. note:: + + The positional ordering of the chunks in the encoding argument must match the + positional ordering of the dimensions in each array. Watch out for arrays with + differently-ordered dimensions within a single Dataset. + +For example, let's say we're working with a dataset with dimensions +``('time', 'x', 'y')``, a variable ``Tair`` which is chunked in ``x`` and ``y``, +and two multi-dimensional coordinates ``xc`` and ``yc``: + +.. ipython:: python + + ds = xr.tutorial.open_dataset("rasm") + + ds["Tair"] = ds["Tair"].chunk({"x": 100, "y": 100}) + + ds + +These multi-dimensional coordinates are only two-dimensional and take up very little +space on disk or in memory, yet when writing to disk the default zarr behavior is to +split them into chunks: + +.. ipython:: python + + ds.to_zarr("path/to/directory.zarr", mode="w") + ! ls -R path/to/directory.zarr + + +This may cause unwanted overhead on some systems, such as when reading from a cloud +storage provider. To disable this chunking, we can specify a chunk size equal to the +length of each dimension by using the shorthand chunk size ``-1``: + +.. ipython:: python + + ds.to_zarr( + "path/to/directory.zarr", + encoding={"xc": {"chunks": (-1, -1)}, "yc": {"chunks": (-1, -1)}}, + mode="w", + ) + ! ls -R path/to/directory.zarr + + +The number of chunks on Tair matches our dask chunks, while there is now only a single +chunk in the directory stores of each coordinate. + +.. _io.iris: + +Iris +---- + +The Iris_ tool allows easy reading of common meteorological and climate model formats +(including GRIB and UK MetOffice PP files) into ``Cube`` objects which are in many ways very +similar to ``DataArray`` objects, while enforcing a CF-compliant data model. If iris is +installed, xarray can convert a ``DataArray`` into a ``Cube`` using +:py:meth:`DataArray.to_iris`: + +.. ipython:: python + + da = xr.DataArray( + np.random.rand(4, 5), + dims=["x", "y"], + coords=dict(x=[10, 20, 30, 40], y=pd.date_range("2000-01-01", periods=5)), + ) + + cube = da.to_iris() + cube + +Conversely, we can create a new ``DataArray`` object from a ``Cube`` using +:py:meth:`DataArray.from_iris`: + +.. ipython:: python + + da_cube = xr.DataArray.from_iris(cube) + da_cube + + +.. _Iris: https://scitools.org.uk/iris + + +OPeNDAP +------- + +Xarray includes support for `OPeNDAP`__ (via the netCDF4 library or Pydap), which +lets us access large datasets over HTTP. + +__ https://www.opendap.org/ + +For example, we can open a connection to GBs of weather data produced by the +`PRISM`__ project, and hosted by `IRI`__ at Columbia: + +__ https://www.prism.oregonstate.edu/ +__ https://iri.columbia.edu/ + +.. ipython source code for this section + we don't use this to avoid hitting the DAP server on every doc build. + + remote_data = xr.open_dataset( + 'http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods', + decode_times=False) + tmax = remote_data.tmax[:500, ::3, ::3] + tmax + + @savefig opendap-prism-tmax.png + tmax[0].plot() + +.. ipython:: + :verbatim: + + In [3]: remote_data = xr.open_dataset( + ...: "http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods", + ...: decode_times=False, + ...: ) + + In [4]: remote_data + Out[4]: + + Dimensions: (T: 1422, X: 1405, Y: 621) + Coordinates: + * X (X) float32 -125.0 -124.958 -124.917 -124.875 -124.833 -124.792 -124.75 ... + * T (T) float32 -779.5 -778.5 -777.5 -776.5 -775.5 -774.5 -773.5 -772.5 -771.5 ... + * Y (Y) float32 49.9167 49.875 49.8333 49.7917 49.75 49.7083 49.6667 49.625 ... + Data variables: + ppt (T, Y, X) float64 ... + tdmean (T, Y, X) float64 ... + tmax (T, Y, X) float64 ... + tmin (T, Y, X) float64 ... + Attributes: + Conventions: IRIDL + expires: 1375315200 + +.. TODO: update this example to show off decode_cf? + +.. note:: + + Like many real-world datasets, this dataset does not entirely follow + `CF conventions`_. Unexpected formats will usually cause xarray's automatic + decoding to fail. The way to work around this is to either set + ``decode_cf=False`` in ``open_dataset`` to turn off all use of CF + conventions, or by only disabling the troublesome parser. + In this case, we set ``decode_times=False`` because the time axis here + provides the calendar attribute in a format that xarray does not expect + (the integer ``360`` instead of a string like ``'360_day'``). + +We can select and slice this data any number of times, and nothing is loaded +over the network until we look at particular values: + +.. ipython:: + :verbatim: + + In [4]: tmax = remote_data["tmax"][:500, ::3, ::3] + + In [5]: tmax + Out[5]: + + [48541500 values with dtype=float64] + Coordinates: + * Y (Y) float32 49.9167 49.7917 49.6667 49.5417 49.4167 49.2917 ... + * X (X) float32 -125.0 -124.875 -124.75 -124.625 -124.5 -124.375 ... + * T (T) float32 -779.5 -778.5 -777.5 -776.5 -775.5 -774.5 -773.5 ... + Attributes: + pointwidth: 120 + standard_name: air_temperature + units: Celsius_scale + expires: 1443657600 + + # the data is downloaded automatically when we make the plot + In [6]: tmax[0].plot() + +.. image:: ../_static/opendap-prism-tmax.png + +Some servers require authentication before we can access the data. For this +purpose we can explicitly create a :py:class:`backends.PydapDataStore` +and pass in a `Requests`__ session object. For example for +HTTP Basic authentication:: + + import xarray as xr + import requests + + session = requests.Session() + session.auth = ('username', 'password') + + store = xr.backends.PydapDataStore.open('http://example.com/data', + session=session) + ds = xr.open_dataset(store) + +`Pydap's cas module`__ has functions that generate custom sessions for +servers that use CAS single sign-on. For example, to connect to servers +that require NASA's URS authentication:: + + import xarray as xr + from pydata.cas.urs import setup_session + + ds_url = 'https://gpm1.gesdisc.eosdis.nasa.gov/opendap/hyrax/example.nc' + + session = setup_session('username', 'password', check_url=ds_url) + store = xr.backends.PydapDataStore.open(ds_url, session=session) + + ds = xr.open_dataset(store) + +__ https://docs.python-requests.org +__ https://www.pydap.org/en/latest/client.html#authentication + +.. _io.pickle: + +Pickle +------ + +The simplest way to serialize an xarray object is to use Python's built-in pickle +module: + +.. ipython:: python + + import pickle + + # use the highest protocol (-1) because it is way faster than the default + # text based pickle format + pkl = pickle.dumps(ds, protocol=-1) + + pickle.loads(pkl) + +Pickling is important because it doesn't require any external libraries +and lets you use xarray objects with Python modules like +:py:mod:`multiprocessing` or :ref:`Dask `. However, pickling is +**not recommended for long-term storage**. + +Restoring a pickle requires that the internal structure of the types for the +pickled data remain unchanged. Because the internal design of xarray is still +being refined, we make no guarantees (at this point) that objects pickled with +this version of xarray will work in future versions. + +.. note:: + + When pickling an object opened from a NetCDF file, the pickle file will + contain a reference to the file on disk. If you want to store the actual + array values, load it into memory first with :py:meth:`Dataset.load` + or :py:meth:`Dataset.compute`. + +.. _dictionary io: + +Dictionary +---------- + +We can convert a ``Dataset`` (or a ``DataArray``) to a dict using +:py:meth:`Dataset.to_dict`: + +.. ipython:: python + + ds = xr.Dataset({"foo": ("x", np.arange(30))}) + ds + + d = ds.to_dict() + d + +We can create a new xarray object from a dict using +:py:meth:`Dataset.from_dict`: + +.. ipython:: python + + ds_dict = xr.Dataset.from_dict(d) + ds_dict + +Dictionary support allows for flexible use of xarray objects. It doesn't +require external libraries and dicts can easily be pickled, or converted to +json, or geojson. All the values are converted to lists, so dicts might +be quite large. + +To export just the dataset schema without the data itself, use the +``data=False`` option: + +.. ipython:: python + + ds.to_dict(data=False) + +.. ipython:: python + :suppress: + + # We're now done with the dataset named `ds`. Although the `with` statement closed + # the dataset, displaying the unpickled pickle of `ds` re-opened "saved_on_disk.nc". + # However, `ds` (rather than the unpickled dataset) refers to the open file. Delete + # `ds` to close the file. + del ds + os.remove("saved_on_disk.nc") + +This can be useful for generating indices of dataset contents to expose to +search indices or other automated data discovery tools. + +.. _io.rasterio: + +Rasterio +-------- + +GDAL readable raster data using `rasterio`_ such as GeoTIFFs can be opened using the `rioxarray`_ extension. +`rioxarray`_ can also handle geospatial related tasks such as re-projecting and clipping. + +.. ipython:: + :verbatim: + + In [1]: import rioxarray + + In [2]: rds = rioxarray.open_rasterio("RGB.byte.tif") + + In [3]: rds + Out[3]: + + [1703814 values with dtype=uint8] + Coordinates: + * band (band) int64 1 2 3 + * y (y) float64 2.827e+06 2.826e+06 ... 2.612e+06 2.612e+06 + * x (x) float64 1.021e+05 1.024e+05 ... 3.389e+05 3.392e+05 + spatial_ref int64 0 + Attributes: + STATISTICS_MAXIMUM: 255 + STATISTICS_MEAN: 29.947726688477 + STATISTICS_MINIMUM: 0 + STATISTICS_STDDEV: 52.340921626611 + transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.0417827... + _FillValue: 0.0 + scale_factor: 1.0 + add_offset: 0.0 + grid_mapping: spatial_ref + + In [4]: rds.rio.crs + Out[4]: CRS.from_epsg(32618) + + In [5]: rds4326 = rds.rio.reproject("epsg:4326") + + In [6]: rds4326.rio.crs + Out[6]: CRS.from_epsg(4326) + + In [7]: rds4326.rio.to_raster("RGB.byte.4326.tif") + + +.. _rasterio: https://rasterio.readthedocs.io/en/latest/ +.. _rioxarray: https://corteva.github.io/rioxarray/stable/ +.. _test files: https://github.com/rasterio/rasterio/blob/master/tests/data/RGB.byte.tif +.. _pyproj: https://github.com/pyproj4/pyproj + +.. _io.cfgrib: + +.. ipython:: python + :suppress: + + import shutil + + shutil.rmtree("foo.zarr") + shutil.rmtree("path/to/directory.zarr") + +GRIB format via cfgrib +---------------------- + +Xarray supports reading GRIB files via ECMWF cfgrib_ python driver, +if it is installed. To open a GRIB file supply ``engine='cfgrib'`` +to :py:func:`open_dataset` after installing cfgrib_: + +.. ipython:: + :verbatim: + + In [1]: ds_grib = xr.open_dataset("example.grib", engine="cfgrib") + +We recommend installing cfgrib via conda:: + + conda install -c conda-forge cfgrib + +.. _cfgrib: https://github.com/ecmwf/cfgrib + + +CSV and other formats supported by pandas +----------------------------------------- + +For more options (tabular formats and CSV files in particular), consider +exporting your objects to pandas and using its broad range of `IO tools`_. +For CSV files, one might also consider `xarray_extras`_. + +.. _xarray_extras: https://xarray-extras.readthedocs.io/en/latest/api/csv.html + +.. _IO tools: http://pandas.pydata.org/pandas-docs/stable/io.html + + +Third party libraries +--------------------- + +More formats are supported by extension libraries: + +- `xarray-mongodb `_: Store xarray objects on MongoDB diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/options.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/options.rst new file mode 100644 index 0000000..12844ec --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/options.rst @@ -0,0 +1,36 @@ +.. currentmodule:: xarray + +.. _options: + +Configuration +============= + +Xarray offers a small number of configuration options through :py:func:`set_options`. With these, you can + +1. Control the ``repr``: + + - ``display_expand_attrs`` + - ``display_expand_coords`` + - ``display_expand_data`` + - ``display_expand_data_vars`` + - ``display_max_rows`` + - ``display_style`` + +2. Control behaviour during operations: ``arithmetic_join``, ``keep_attrs``, ``use_bottleneck``. +3. Control colormaps for plots:``cmap_divergent``, ``cmap_sequential``. +4. Aspects of file reading: ``file_cache_maxsize``, ``warn_on_unclosed_files``. + + +You can set these options either globally + +:: + + xr.set_options(arithmetic_join="exact") + +or locally as a context manager: + +:: + + with xr.set_options(arithmetic_join="exact"): + # do operation here + pass diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/pandas.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/pandas.rst new file mode 100644 index 0000000..26fa7ea --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/pandas.rst @@ -0,0 +1,265 @@ +.. currentmodule:: xarray +.. _pandas: + +=================== +Working with pandas +=================== + +One of the most important features of xarray is the ability to convert to and +from :py:mod:`pandas` objects to interact with the rest of the PyData +ecosystem. For example, for plotting labeled data, we highly recommend +using the `visualization built in to pandas itself`__ or provided by the pandas +aware libraries such as `Seaborn`__. + +__ https://pandas.pydata.org/pandas-docs/stable/visualization.html +__ https://seaborn.pydata.org/ + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +Hierarchical and tidy data +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Tabular data is easiest to work with when it meets the criteria for +`tidy data`__: + +* Each column holds a different variable. +* Each rows holds a different observation. + +__ https://www.jstatsoft.org/v59/i10/ + +In this "tidy data" format, we can represent any :py:class:`Dataset` and +:py:class:`DataArray` in terms of :py:class:`~pandas.DataFrame` and +:py:class:`~pandas.Series`, respectively (and vice-versa). The representation +works by flattening non-coordinates to 1D, and turning the tensor product of +coordinate indexes into a :py:class:`pandas.MultiIndex`. + +Dataset and DataFrame +--------------------- + +To convert any dataset to a ``DataFrame`` in tidy form, use the +:py:meth:`Dataset.to_dataframe()` method: + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.randn(2, 3))}, + coords={ + "x": [10, 20], + "y": ["a", "b", "c"], + "along_x": ("x", np.random.randn(2)), + "scalar": 123, + }, + ) + ds + df = ds.to_dataframe() + df + +We see that each variable and coordinate in the Dataset is now a column in the +DataFrame, with the exception of indexes which are in the index. +To convert the ``DataFrame`` to any other convenient representation, +use ``DataFrame`` methods like :py:meth:`~pandas.DataFrame.reset_index`, +:py:meth:`~pandas.DataFrame.stack` and :py:meth:`~pandas.DataFrame.unstack`. + +For datasets containing dask arrays where the data should be lazily loaded, see the +:py:meth:`Dataset.to_dask_dataframe()` method. + +To create a ``Dataset`` from a ``DataFrame``, use the +:py:meth:`Dataset.from_dataframe` class method or the equivalent +:py:meth:`pandas.DataFrame.to_xarray` method: + +.. ipython:: python + + xr.Dataset.from_dataframe(df) + +Notice that that dimensions of variables in the ``Dataset`` have now +expanded after the round-trip conversion to a ``DataFrame``. This is because +every object in a ``DataFrame`` must have the same indices, so we need to +broadcast the data of each array to the full size of the new ``MultiIndex``. + +Likewise, all the coordinates (other than indexes) ended up as variables, +because pandas does not distinguish non-index coordinates. + +DataArray and Series +-------------------- + +``DataArray`` objects have a complementary representation in terms of a +:py:class:`~pandas.Series`. Using a Series preserves the ``Dataset`` to +``DataArray`` relationship, because ``DataFrames`` are dict-like containers +of ``Series``. The methods are very similar to those for working with +DataFrames: + +.. ipython:: python + + s = ds["foo"].to_series() + s + # or equivalently, with Series.to_xarray() + xr.DataArray.from_series(s) + +Both the ``from_series`` and ``from_dataframe`` methods use reindexing, so they +work even if not the hierarchical index is not a full tensor product: + +.. ipython:: python + + s[::2] + s[::2].to_xarray() + +Lossless and reversible conversion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The previous ``Dataset`` example shows that the conversion is not reversible (lossy roundtrip) and +that the size of the ``Dataset`` increases. + +Particularly after a roundtrip, the following deviations are noted: + +- a non-dimension Dataset ``coordinate`` is converted into ``variable`` +- a non-dimension DataArray ``coordinate`` is not converted +- ``dtype`` is not allways the same (e.g. "str" is converted to "object") +- ``attrs`` metadata is not conserved + +To avoid these problems, the third-party `ntv-pandas `__ library offers lossless and reversible conversions between +``Dataset``/ ``DataArray`` and pandas ``DataFrame`` objects. + +This solution is particularly interesting for converting any ``DataFrame`` into a ``Dataset`` (the converter find the multidimensional structure hidden by the tabular structure). + +The `ntv-pandas examples `__ show how to improve the conversion for the previous ``Dataset`` example and for more complex examples. + +Multi-dimensional data +~~~~~~~~~~~~~~~~~~~~~~ + +Tidy data is great, but it sometimes you want to preserve dimensions instead of +automatically stacking them into a ``MultiIndex``. + +:py:meth:`DataArray.to_pandas()` is a shortcut that lets you convert a +DataArray directly into a pandas object with the same dimensionality, if +available in pandas (i.e., a 1D array is converted to a +:py:class:`~pandas.Series` and 2D to :py:class:`~pandas.DataFrame`): + +.. ipython:: python + + arr = xr.DataArray( + np.random.randn(2, 3), coords=[("x", [10, 20]), ("y", ["a", "b", "c"])] + ) + df = arr.to_pandas() + df + +To perform the inverse operation of converting any pandas objects into a data +array with the same shape, simply use the :py:class:`DataArray` +constructor: + +.. ipython:: python + + xr.DataArray(df) + +Both the ``DataArray`` and ``Dataset`` constructors directly convert pandas +objects into xarray objects with the same shape. This means that they +preserve all use of multi-indexes: + +.. ipython:: python + + index = pd.MultiIndex.from_arrays( + [["a", "a", "b"], [0, 1, 2]], names=["one", "two"] + ) + df = pd.DataFrame({"x": 1, "y": 2}, index=index) + ds = xr.Dataset(df) + ds + +However, you will need to set dimension names explicitly, either with the +``dims`` argument on in the ``DataArray`` constructor or by calling +:py:class:`~Dataset.rename` on the new object. + +.. _panel transition: + +Transitioning from pandas.Panel to xarray +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``Panel``, pandas' data structure for 3D arrays, was always a second class +data structure compared to the Series and DataFrame. To allow pandas +developers to focus more on its core functionality built around the +DataFrame, pandas removed ``Panel`` in favor of directing users who use +multi-dimensional arrays to xarray. + +Xarray has most of ``Panel``'s features, a more explicit API (particularly around +indexing), and the ability to scale to >3 dimensions with the same interface. + +As discussed in the :ref:`data structures section of the docs `, there are two primary data structures in +xarray: ``DataArray`` and ``Dataset``. You can imagine a ``DataArray`` as a +n-dimensional pandas ``Series`` (i.e. a single typed array), and a ``Dataset`` +as the ``DataFrame`` equivalent (i.e. a dict of aligned ``DataArray`` objects). + +So you can represent a Panel, in two ways: + +- As a 3-dimensional ``DataArray``, +- Or as a ``Dataset`` containing a number of 2-dimensional DataArray objects. + +Let's take a look: + +.. ipython:: python + + data = np.random.RandomState(0).rand(2, 3, 4) + items = list("ab") + major_axis = list("mno") + minor_axis = pd.date_range(start="2000", periods=4, name="date") + +With old versions of pandas (prior to 0.25), this could stored in a ``Panel``: + +.. ipython:: + :verbatim: + + In [1]: pd.Panel(data, items, major_axis, minor_axis) + Out[1]: + + Dimensions: 2 (items) x 3 (major_axis) x 4 (minor_axis) + Items axis: a to b + Major_axis axis: m to o + Minor_axis axis: 2000-01-01 00:00:00 to 2000-01-04 00:00:00 + +To put this data in a ``DataArray``, write: + +.. ipython:: python + + array = xr.DataArray(data, [items, major_axis, minor_axis]) + array + +As you can see, there are three dimensions (each is also a coordinate). Two of +the axes of were unnamed, so have been assigned ``dim_0`` and ``dim_1`` +respectively, while the third retains its name ``date``. + +You can also easily convert this data into ``Dataset``: + +.. ipython:: python + + array.to_dataset(dim="dim_0") + +Here, there are two data variables, each representing a DataFrame on panel's +``items`` axis, and labeled as such. Each variable is a 2D array of the +respective values along the ``items`` dimension. + +While the xarray docs are relatively complete, a few items stand out for Panel users: + +- A DataArray's data is stored as a numpy array, and so can only contain a single + type. As a result, a Panel that contains :py:class:`~pandas.DataFrame` objects + with multiple types will be converted to ``dtype=object``. A ``Dataset`` of + multiple ``DataArray`` objects each with its own dtype will allow original + types to be preserved. +- :ref:`Indexing ` is similar to pandas, but more explicit and + leverages xarray's naming of dimensions. +- Because of those features, making much higher dimensional data is very + practical. +- Variables in ``Dataset`` objects can use a subset of its dimensions. For + example, you can have one dataset with Person x Score x Time, and another with + Person x Score. +- You can use coordinates are used for both dimensions and for variables which + _label_ the data variables, so you could have a coordinate Age, that labelled + the Person dimension of a Dataset of Person x Score x Time. + +While xarray may take some getting used to, it's worth it! If anything is unclear, +please `post an issue on GitHub `__ or +`StackOverflow `__, +and we'll endeavor to respond to the specific case or improve the general docs. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/plotting.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/plotting.rst new file mode 100644 index 0000000..2bc049f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/plotting.rst @@ -0,0 +1,1086 @@ +.. currentmodule:: xarray +.. _plotting: + +Plotting +======== + +Introduction +------------ + +Labeled data enables expressive computations. These same +labels can also be used to easily create informative plots. + +Xarray's plotting capabilities are centered around +:py:class:`DataArray` objects. +To plot :py:class:`Dataset` objects +simply access the relevant DataArrays, i.e. ``dset['var1']``. +Dataset specific plotting routines are also available (see :ref:`plot-dataset`). +Here we focus mostly on arrays 2d or larger. If your data fits +nicely into a pandas DataFrame then you're better off using one of the more +developed tools there. + +Xarray plotting functionality is a thin wrapper around the popular +`matplotlib `_ library. +Matplotlib syntax and function names were copied as much as possible, which +makes for an easy transition between the two. +Matplotlib must be installed before xarray can plot. + +To use xarray's plotting capabilities with time coordinates containing +``cftime.datetime`` objects +`nc-time-axis `_ v1.3.0 or later +needs to be installed. + +For more extensive plotting applications consider the following projects: + +- `Seaborn `_: "provides + a high-level interface for drawing attractive statistical graphics." + Integrates well with pandas. + +- `HoloViews `_ + and `GeoViews `_: "Composable, declarative + data structures for building even complex visualizations easily." Includes + native support for xarray objects. + +- `hvplot `_: ``hvplot`` makes it very easy to produce + dynamic plots (backed by ``Holoviews`` or ``Geoviews``) by adding a ``hvplot`` + accessor to DataArrays. + +- `Cartopy `_: Provides cartographic + tools. + +Imports +~~~~~~~ + +.. ipython:: python + :suppress: + + # Use defaults so we don't get gridlines in generated docs + import matplotlib as mpl + + mpl.rcdefaults() + +The following imports are necessary for all of the examples. + +.. ipython:: python + + import numpy as np + import pandas as pd + import matplotlib.pyplot as plt + import xarray as xr + +For these examples we'll use the North American air temperature dataset. + +.. ipython:: python + + airtemps = xr.tutorial.open_dataset("air_temperature") + airtemps + + # Convert to celsius + air = airtemps.air - 273.15 + + # copy attributes to get nice figure labels and change Kelvin to Celsius + air.attrs = airtemps.air.attrs + air.attrs["units"] = "deg C" + +.. note:: + Until :issue:`1614` is solved, you might need to copy over the metadata in ``attrs`` to get informative figure labels (as was done above). + + +DataArrays +---------- + +One Dimension +~~~~~~~~~~~~~ + +================ + Simple Example +================ + +The simplest way to make a plot is to call the :py:func:`DataArray.plot()` method. + +.. ipython:: python + :okwarning: + + air1d = air.isel(lat=10, lon=10) + + @savefig plotting_1d_simple.png width=4in + air1d.plot() + +Xarray uses the coordinate name along with metadata ``attrs.long_name``, +``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) +to label the axes. +The names ``long_name``, ``standard_name`` and ``units`` are copied from the +`CF-conventions spec `_. +When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. +The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. + +.. ipython:: python + + air1d.attrs + +====================== + Additional Arguments +====================== + +Additional arguments are passed directly to the matplotlib function which +does the work. +For example, :py:func:`xarray.plot.line` calls +matplotlib.pyplot.plot_ passing in the index and the array values as x and y, respectively. +So to make a line plot with blue triangles a matplotlib format string +can be used: + +.. _matplotlib.pyplot.plot: https://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot + +.. ipython:: python + :okwarning: + + @savefig plotting_1d_additional_args.png width=4in + air1d[:200].plot.line("b-^") + +.. note:: + Not all xarray plotting methods support passing positional arguments + to the wrapped matplotlib functions, but they do all + support keyword arguments. + +Keyword arguments work the same way, and are more explicit. + +.. ipython:: python + :okwarning: + + @savefig plotting_example_sin3.png width=4in + air1d[:200].plot.line(color="purple", marker="o") + +========================= + Adding to Existing Axis +========================= + +To add the plot to an existing axis pass in the axis as a keyword argument +``ax``. This works for all xarray plotting methods. +In this example ``axs`` is an array consisting of the left and right +axes created by ``plt.subplots``. + +.. ipython:: python + :okwarning: + + fig, axs = plt.subplots(ncols=2) + + axs + + air1d.plot(ax=axs[0]) + air1d.plot.hist(ax=axs[1]) + + plt.tight_layout() + + @savefig plotting_example_existing_axes.png width=6in + plt.draw() + +On the right is a histogram created by :py:func:`xarray.plot.hist`. + +.. _plotting.figsize: + +============================= + Controlling the figure size +============================= + +You can pass a ``figsize`` argument to all xarray's plotting methods to +control the figure size. For convenience, xarray's plotting methods also +support the ``aspect`` and ``size`` arguments which control the size of the +resulting image via the formula ``figsize = (aspect * size, size)``: + +.. ipython:: python + :okwarning: + + air1d.plot(aspect=2, size=3) + @savefig plotting_example_size_and_aspect.png + plt.tight_layout() + +.. ipython:: python + :suppress: + + # create a dummy figure so sphinx plots everything below normally + plt.figure() + +This feature also works with :ref:`plotting.faceting`. For facet plots, +``size`` and ``aspect`` refer to a single panel (so that ``aspect * size`` +gives the width of each facet in inches), while ``figsize`` refers to the +entire figure (as for matplotlib's ``figsize`` argument). + +.. note:: + + If ``figsize`` or ``size`` are used, a new figure is created, + so this is mutually exclusive with the ``ax`` argument. + +.. note:: + + The convention used by xarray (``figsize = (aspect * size, size)``) is + borrowed from seaborn: it is therefore `not equivalent to matplotlib's`_. + +.. _not equivalent to matplotlib's: https://github.com/mwaskom/seaborn/issues/746 + + +.. _plotting.multiplelines: + +========================= + Determine x-axis values +========================= + +Per default dimension coordinates are used for the x-axis (here the time coordinates). +However, you can also use non-dimension coordinates, MultiIndex levels, and dimensions +without coordinates along the x-axis. To illustrate this, let's calculate a 'decimal day' (epoch) +from the time and assign it as a non-dimension coordinate: + +.. ipython:: python + :okwarning: + + decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta("1d") + air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day.data)) + air1d_multi + +To use ``'decimal_day'`` as x coordinate it must be explicitly specified: + +.. ipython:: python + :okwarning: + + air1d_multi.plot(x="decimal_day") + +Creating a new MultiIndex named ``'date'`` from ``'time'`` and ``'decimal_day'``, +it is also possible to use a MultiIndex level as x-axis: + +.. ipython:: python + :okwarning: + + air1d_multi = air1d_multi.set_index(date=("time", "decimal_day")) + air1d_multi.plot(x="decimal_day") + +Finally, if a dataset does not have any coordinates it enumerates all data points: + +.. ipython:: python + :okwarning: + + air1d_multi = air1d_multi.drop_vars(["date", "time", "decimal_day"]) + air1d_multi.plot() + +The same applies to 2D plots below. + +==================================================== + Multiple lines showing variation along a dimension +==================================================== + +It is possible to make line plots of two-dimensional data by calling :py:func:`xarray.plot.line` +with appropriate arguments. Consider the 3D variable ``air`` defined above. We can use line +plots to check the variation of air temperature at three different latitudes along a longitude line: + +.. ipython:: python + :okwarning: + + @savefig plotting_example_multiple_lines_x_kwarg.png + air.isel(lon=10, lat=[19, 21, 22]).plot.line(x="time") + +It is required to explicitly specify either + +1. ``x``: the dimension to be used for the x-axis, or +2. ``hue``: the dimension you want to represent by multiple lines. + +Thus, we could have made the previous plot by specifying ``hue='lat'`` instead of ``x='time'``. +If required, the automatic legend can be turned off using ``add_legend=False``. Alternatively, +``hue`` can be passed directly to :py:func:`xarray.plot.line` as `air.isel(lon=10, lat=[19,21,22]).plot.line(hue='lat')`. + + +======================== + Dimension along y-axis +======================== + +It is also possible to make line plots such that the data are on the x-axis and a dimension is on the y-axis. This can be done by specifying the appropriate ``y`` keyword argument. + +.. ipython:: python + :okwarning: + + @savefig plotting_example_xy_kwarg.png + air.isel(time=10, lon=[10, 11]).plot(y="lat", hue="lon") + +============ + Step plots +============ + +As an alternative, also a step plot similar to matplotlib's ``plt.step`` can be +made using 1D data. + +.. ipython:: python + :okwarning: + + @savefig plotting_example_step.png width=4in + air1d[:20].plot.step(where="mid") + +The argument ``where`` defines where the steps should be placed, options are +``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy +when plotting data grouped with :py:meth:`Dataset.groupby_bins`. + +.. ipython:: python + :okwarning: + + air_grp = air.mean(["time", "lon"]).groupby_bins("lat", [0, 23.5, 66.5, 90]) + air_mean = air_grp.mean() + air_std = air_grp.std() + air_mean.plot.step() + (air_mean + air_std).plot.step(ls=":") + (air_mean - air_std).plot.step(ls=":") + plt.ylim(-20, 30) + @savefig plotting_example_step_groupby.png width=4in + plt.title("Zonal mean temperature") + +In this case, the actual boundaries of the bins are used and the ``where`` argument +is ignored. + + +Other axes kwargs +~~~~~~~~~~~~~~~~~ + + +The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. + +.. ipython:: python + :okwarning: + + @savefig plotting_example_xincrease_yincrease_kwarg.png + air.isel(time=10, lon=[10, 11]).plot.line( + y="lat", hue="lon", xincrease=False, yincrease=False + ) + +In addition, one can use ``xscale, yscale`` to set axes scaling; +``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. +These accept the same values as the matplotlib methods ``ax.set_(x,y)scale()``, +``ax.set_(x,y)ticks()``, ``ax.set_(x,y)lim()``, respectively. + + +Two Dimensions +~~~~~~~~~~~~~~ + +================ + Simple Example +================ + +The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` +by default when the data is two-dimensional. + +.. ipython:: python + :okwarning: + + air2d = air.isel(time=500) + + @savefig 2d_simple.png width=4in + air2d.plot() + +All 2d plots in xarray allow the use of the keyword arguments ``yincrease`` +and ``xincrease``. + +.. ipython:: python + :okwarning: + + @savefig 2d_simple_yincrease.png width=4in + air2d.plot(yincrease=False) + +.. note:: + + We use :py:func:`xarray.plot.pcolormesh` as the default two-dimensional plot + method because it is more flexible than :py:func:`xarray.plot.imshow`. + However, for large arrays, ``imshow`` can be much faster than ``pcolormesh``. + If speed is important to you and you are plotting a regular mesh, consider + using ``imshow``. + +================ + Missing Values +================ + +Xarray plots data with :ref:`missing_values`. + +.. ipython:: python + :okwarning: + + bad_air2d = air2d.copy() + + bad_air2d[dict(lat=slice(0, 10), lon=slice(0, 25))] = np.nan + + @savefig plotting_missing_values.png width=4in + bad_air2d.plot() + +======================== + Nonuniform Coordinates +======================== + +It's not necessary for the coordinates to be evenly spaced. Both +:py:func:`xarray.plot.pcolormesh` (default) and :py:func:`xarray.plot.contourf` can +produce plots with nonuniform coordinates. + +.. ipython:: python + :okwarning: + + b = air2d.copy() + # Apply a nonlinear transformation to one of the coords + b.coords["lat"] = np.log(b.coords["lat"]) + + @savefig plotting_nonuniform_coords.png width=4in + b.plot() + +==================== + Other types of plot +==================== + +There are several other options for plotting 2D data. + +Contour plot using :py:meth:`DataArray.plot.contour()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contour.png width=4in + air2d.plot.contour() + +Filled contour plot using :py:meth:`DataArray.plot.contourf()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contourf.png width=4in + air2d.plot.contourf() + +Surface plot using :py:meth:`DataArray.plot.surface()` + +.. ipython:: python + :okwarning: + + @savefig plotting_surface.png width=4in + # transpose just to make the example look a bit nicer + air2d.T.plot.surface() + +==================== + Calling Matplotlib +==================== + +Since this is a thin wrapper around matplotlib, all the functionality of +matplotlib is available. + +.. ipython:: python + :okwarning: + + air2d.plot(cmap=plt.cm.Blues) + plt.title("These colors prove North America\nhas fallen in the ocean") + plt.ylabel("latitude") + plt.xlabel("longitude") + plt.tight_layout() + + @savefig plotting_2d_call_matplotlib.png width=4in + plt.draw() + +.. note:: + + Xarray methods update label information and generally play around with the + axes. So any kind of updates to the plot + should be done *after* the call to the xarray's plot. + In the example below, ``plt.xlabel`` effectively does nothing, since + ``d_ylog.plot()`` updates the xlabel. + + .. ipython:: python + :okwarning: + + plt.xlabel("Never gonna see this.") + air2d.plot() + + @savefig plotting_2d_call_matplotlib2.png width=4in + plt.draw() + +=========== + Colormaps +=========== + +Xarray borrows logic from Seaborn to infer what kind of color map to use. For +example, consider the original data in Kelvins rather than Celsius: + +.. ipython:: python + :okwarning: + + @savefig plotting_kelvin.png width=4in + airtemps.air.isel(time=0).plot() + +The Celsius data contain 0, so a diverging color map was used. The +Kelvins do not have 0, so the default color map was used. + +.. _robust-plotting: + +======== + Robust +======== + +Outliers often have an extreme effect on the output of the plot. +Here we add two bad data points. This affects the color scale, +washing out the plot. + +.. ipython:: python + :okwarning: + + air_outliers = airtemps.air.isel(time=0).copy() + air_outliers[0, 0] = 100 + air_outliers[-1, -1] = 400 + + @savefig plotting_robust1.png width=4in + air_outliers.plot() + +This plot shows that we have outliers. The easy way to visualize +the data without the outliers is to pass the parameter +``robust=True``. +This will use the 2nd and 98th +percentiles of the data to compute the color limits. + +.. ipython:: python + :okwarning: + + @savefig plotting_robust2.png width=4in + air_outliers.plot(robust=True) + +Observe that the ranges of the color bar have changed. The arrows on the +color bar indicate +that the colors include data points outside the bounds. + +==================== + Discrete Colormaps +==================== + +It is often useful, when visualizing 2d data, to use a discrete colormap, +rather than the default continuous colormaps that matplotlib uses. The +``levels`` keyword argument can be used to generate plots with discrete +colormaps. For example, to make a plot with 8 discrete color intervals: + +.. ipython:: python + :okwarning: + + @savefig plotting_discrete_levels.png width=4in + air2d.plot(levels=8) + +It is also possible to use a list of levels to specify the boundaries of the +discrete colormap: + +.. ipython:: python + :okwarning: + + @savefig plotting_listed_levels.png width=4in + air2d.plot(levels=[0, 12, 18, 30]) + +You can also specify a list of discrete colors through the ``colors`` argument: + +.. ipython:: python + :okwarning: + + flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"] + @savefig plotting_custom_colors_levels.png width=4in + air2d.plot(levels=[0, 12, 18, 30], colors=flatui) + +Finally, if you have `Seaborn `_ +installed, you can also specify a seaborn color palette to the ``cmap`` +argument. Note that ``levels`` *must* be specified with seaborn color palettes +if using ``imshow`` or ``pcolormesh`` (but not with ``contour`` or ``contourf``, +since levels are chosen automatically). + +.. ipython:: python + :okwarning: + + @savefig plotting_seaborn_palette.png width=4in + air2d.plot(levels=10, cmap="husl") + plt.draw() + +.. _plotting.faceting: + +Faceting +~~~~~~~~ + +Faceting here refers to splitting an array along one or two dimensions and +plotting each group. +Xarray's basic plotting is useful for plotting two dimensional arrays. What +about three or four dimensional arrays? That's where facets become helpful. +The general approach to plotting here is called “small multiples”, where the +same kind of plot is repeated multiple times, and the specific use of small +multiples to display the same relationship conditioned on one or more other +variables is often called a “trellis plot”. + +Consider the temperature data set. There are 4 observations per day for two +years which makes for 2920 values along the time dimension. +One way to visualize this data is to make a +separate plot for each time period. + +The faceted dimension should not have too many values; +faceting on the time dimension will produce 2920 plots. That's +too much to be helpful. To handle this situation try performing +an operation that reduces the size of the data in some way. For example, we +could compute the average air temperature for each month and reduce the +size of this dimension from 2920 -> 12. A simpler way is +to just take a slice on that dimension. +So let's use a slice to pick 6 times throughout the first year. + +.. ipython:: python + + t = air.isel(time=slice(0, 365 * 4, 250)) + t.coords + +================ + Simple Example +================ + +The easiest way to create faceted plots is to pass in ``row`` or ``col`` +arguments to the xarray plotting methods/functions. This returns a +:py:class:`xarray.plot.FacetGrid` object. + +.. ipython:: python + :okwarning: + + @savefig plot_facet_dataarray.png + g_simple = t.plot(x="lon", y="lat", col="time", col_wrap=3) + +Faceting also works for line plots. + +.. ipython:: python + :okwarning: + + @savefig plot_facet_dataarray_line.png + g_simple_line = t.isel(lat=slice(0, None, 4)).plot( + x="lon", hue="lat", col="time", col_wrap=3 + ) + +=============== + 4 dimensional +=============== + +For 4 dimensional arrays we can use the rows and columns of the grids. +Here we create a 4 dimensional array by taking the original data and adding +a fixed amount. Now we can see how the temperature maps would compare if +one were much hotter. + +.. ipython:: python + :okwarning: + + t2 = t.isel(time=slice(0, 2)) + t4d = xr.concat([t2, t2 + 40], pd.Index(["normal", "hot"], name="fourth_dim")) + # This is a 4d array + t4d.coords + + @savefig plot_facet_4d.png + t4d.plot(x="lon", y="lat", col="time", row="fourth_dim") + +================ + Other features +================ + +Faceted plotting supports other arguments common to xarray 2d plots. + +.. ipython:: python + :suppress: + + plt.close("all") + +.. ipython:: python + :okwarning: + + hasoutliers = t.isel(time=slice(0, 5)).copy() + hasoutliers[0, 0, 0] = -100 + hasoutliers[-1, -1, -1] = 400 + + @savefig plot_facet_robust.png + g = hasoutliers.plot.pcolormesh( + x="lon", + y="lat", + col="time", + col_wrap=3, + robust=True, + cmap="viridis", + cbar_kwargs={"label": "this has outliers"}, + ) + +=================== + FacetGrid Objects +=================== + +The object returned, ``g`` in the above examples, is a :py:class:`~xarray.plot.FacetGrid` object +that links a :py:class:`DataArray` to a matplotlib figure with a particular structure. +This object can be used to control the behavior of the multiple plots. +It borrows an API and code from `Seaborn's FacetGrid +`_. +The structure is contained within the ``axs`` and ``name_dicts`` +attributes, both 2d NumPy object arrays. + +.. ipython:: python + + g.axs + + g.name_dicts + +It's possible to select the :py:class:`xarray.DataArray` or +:py:class:`xarray.Dataset` corresponding to the FacetGrid through the +``name_dicts``. + +.. ipython:: python + + g.data.loc[g.name_dicts[0, 0]] + +Here is an example of using the lower level API and then modifying the axes after +they have been plotted. + +.. ipython:: python + :okwarning: + + g = t.plot.imshow(x="lon", y="lat", col="time", col_wrap=3, robust=True) + + for i, ax in enumerate(g.axs.flat): + ax.set_title("Air Temperature %d" % i) + + bottomright = g.axs[-1, -1] + bottomright.annotate("bottom right", (240, 40)) + + @savefig plot_facet_iterator.png + plt.draw() + + +:py:class:`~xarray.plot.FacetGrid` objects have methods that let you customize the automatically generated +axis labels, axis ticks and plot titles. See :py:meth:`~xarray.plot.FacetGrid.set_titles`, +:py:meth:`~xarray.plot.FacetGrid.set_xlabels`, :py:meth:`~xarray.plot.FacetGrid.set_ylabels` and +:py:meth:`~xarray.plot.FacetGrid.set_ticks` for more information. +Plotting functions can be applied to each subset of the data by calling +:py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`. + +TODO: add an example of using the ``map`` method to plot dataset variables +(e.g., with ``plt.quiver``). + +.. _plot-dataset: + +Datasets +-------- + +Xarray has limited support for plotting Dataset variables against each other. +Consider this dataset + +.. ipython:: python + + ds = xr.tutorial.scatter_example_dataset(seed=42) + ds + + +Scatter +~~~~~~~ + +Let's plot the ``A`` DataArray as a function of the ``y`` coord + +.. ipython:: python + :okwarning: + + ds.A + + @savefig da_A_y.png + ds.A.plot.scatter(x="y") + +Same plot can be displayed using the dataset: + +.. ipython:: python + :okwarning: + + @savefig ds_A_y.png + ds.plot.scatter(x="y", y="A") + +Now suppose we want to scatter the ``A`` DataArray against the ``B`` DataArray + +.. ipython:: python + :okwarning: + + @savefig ds_simple_scatter.png + ds.plot.scatter(x="A", y="B") + +The ``hue`` kwarg lets you vary the color by variable value + +.. ipython:: python + :okwarning: + + @savefig ds_hue_scatter.png + ds.plot.scatter(x="A", y="B", hue="w") + +You can force a legend instead of a colorbar by setting ``add_legend=True, add_colorbar=False``. + +.. ipython:: python + :okwarning: + + @savefig ds_discrete_legend_hue_scatter.png + ds.plot.scatter(x="A", y="B", hue="w", add_legend=True, add_colorbar=False) + +.. ipython:: python + :okwarning: + + @savefig ds_discrete_colorbar_hue_scatter.png + ds.plot.scatter(x="A", y="B", hue="w", add_legend=False, add_colorbar=True) + +The ``markersize`` kwarg lets you vary the point's size by variable value. +You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes. + +.. ipython:: python + :okwarning: + + @savefig ds_hue_size_scatter.png + ds.plot.scatter(x="A", y="B", hue="y", markersize="z") + +The ``z`` kwarg lets you plot the data along the z-axis as well. + +.. ipython:: python + :okwarning: + + @savefig ds_hue_size_scatter_z.png + ds.plot.scatter(x="A", y="B", z="z", hue="y", markersize="x") + +Faceting is also possible + +.. ipython:: python + :okwarning: + + @savefig ds_facet_scatter.png + ds.plot.scatter(x="A", y="B", hue="y", markersize="x", row="x", col="w") + +And adding the z-axis + +.. ipython:: python + :okwarning: + + @savefig ds_facet_scatter_z.png + ds.plot.scatter(x="A", y="B", z="z", hue="y", markersize="x", row="x", col="w") + +For more advanced scatter plots, we recommend converting the relevant data variables +to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. + +Quiver +~~~~~~ + +Visualizing vector fields is supported with quiver plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_quiver.png + ds.isel(w=1, z=1).plot.quiver(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the arrow vectors. Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_quiver.png + ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4) + +``scale`` is required for faceted quiver plots. +The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. + +Streamplot +~~~~~~~~~~ + +Visualizing vector fields is also supported with streamline plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_streamplot.png + ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. +Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_streamplot.png + ds.plot.streamplot(x="x", y="y", u="A", v="B", col="w", row="z") + +.. _plot-maps: + +Maps +---- + +To follow this section you'll need to have Cartopy installed and working. + +This script will plot the air temperature on a map. + +.. ipython:: python + :okwarning: + + import cartopy.crs as ccrs + + air = xr.tutorial.open_dataset("air_temperature").air + + p = air.isel(time=0).plot( + subplot_kws=dict(projection=ccrs.Orthographic(-80, 35), facecolor="gray"), + transform=ccrs.PlateCarree(), + ) + p.axes.set_global() + + @savefig plotting_maps_cartopy.png width=100% + p.axes.coastlines() + +When faceting on maps, the projection can be transferred to the ``plot`` +function using the ``subplot_kws`` keyword. The axes for the subplots created +by faceting are accessible in the object returned by ``plot``: + +.. ipython:: python + :okwarning: + + p = air.isel(time=[0, 4]).plot( + transform=ccrs.PlateCarree(), + col="time", + subplot_kws={"projection": ccrs.Orthographic(-80, 35)}, + ) + for ax in p.axs.flat: + ax.coastlines() + ax.gridlines() + @savefig plotting_maps_cartopy_facetting.png width=100% + plt.draw() + + +Details +------- + +Ways to Use +~~~~~~~~~~~ + +There are three ways to use the xarray plotting functionality: + +1. Use ``plot`` as a convenience method for a DataArray. + +2. Access a specific plotting method from the ``plot`` attribute of a + DataArray. + +3. Directly from the xarray plot submodule. + +These are provided for user convenience; they all call the same code. + +.. ipython:: python + :okwarning: + + import xarray.plot as xplt + + da = xr.DataArray(range(5)) + fig, axs = plt.subplots(ncols=2, nrows=2) + da.plot(ax=axs[0, 0]) + da.plot.line(ax=axs[0, 1]) + xplt.plot(da, ax=axs[1, 0]) + xplt.line(da, ax=axs[1, 1]) + plt.tight_layout() + @savefig plotting_ways_to_use.png width=6in + plt.draw() + +Here the output is the same. Since the data is 1 dimensional the line plot +was used. + +The convenience method :py:meth:`xarray.DataArray.plot` dispatches to an appropriate +plotting function based on the dimensions of the ``DataArray`` and whether +the coordinates are sorted and uniformly spaced. This table +describes what gets plotted: + +=============== =========================== +Dimensions Plotting function +--------------- --------------------------- +1 :py:func:`xarray.plot.line` +2 :py:func:`xarray.plot.pcolormesh` +Anything else :py:func:`xarray.plot.hist` +=============== =========================== + +Coordinates +~~~~~~~~~~~ + +If you'd like to find out what's really going on in the coordinate system, +read on. + +.. ipython:: python + + a0 = xr.DataArray(np.zeros((4, 3, 2)), dims=("y", "x", "z"), name="temperature") + a0[0, 0, 0] = 1 + a = a0.isel(z=0) + a + +The plot will produce an image corresponding to the values of the array. +Hence the top left pixel will be a different color than the others. +Before reading on, you may want to look at the coordinates and +think carefully about what the limits, labels, and orientation for +each of the axes should be. + +.. ipython:: python + :okwarning: + + @savefig plotting_example_2d_simple.png width=4in + a.plot() + +It may seem strange that +the values on the y axis are decreasing with -0.5 on the top. This is because +the pixels are centered over their coordinates, and the +axis labels and ranges correspond to the values of the +coordinates. + +Multidimensional coordinates +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +See also: :ref:`/examples/multidimensional-coords.ipynb`. + +You can plot irregular grids defined by multidimensional coordinates with +xarray, but you'll have to tell the plot function to use these coordinates +instead of the default ones: + +.. ipython:: python + :okwarning: + + lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4)) + lon += lat / 10 + lat += lon / 10 + da = xr.DataArray( + np.arange(20).reshape(4, 5), + dims=["y", "x"], + coords={"lat": (("y", "x"), lat), "lon": (("y", "x"), lon)}, + ) + + @savefig plotting_example_2d_irreg.png width=4in + da.plot.pcolormesh(x="lon", y="lat") + +Note that in this case, xarray still follows the pixel centered convention. +This might be undesirable in some cases, for example when your data is defined +on a polar projection (:issue:`781`). This is why the default is to not follow +this convention when plotting on a map: + +.. ipython:: python + :okwarning: + + import cartopy.crs as ccrs + + ax = plt.subplot(projection=ccrs.PlateCarree()) + da.plot.pcolormesh(x="lon", y="lat", ax=ax) + ax.scatter(lon, lat, transform=ccrs.PlateCarree()) + ax.coastlines() + @savefig plotting_example_2d_irreg_map.png width=4in + ax.gridlines(draw_labels=True) + +You can however decide to infer the cell boundaries and use the +``infer_intervals`` keyword: + +.. ipython:: python + :okwarning: + + ax = plt.subplot(projection=ccrs.PlateCarree()) + da.plot.pcolormesh(x="lon", y="lat", ax=ax, infer_intervals=True) + ax.scatter(lon, lat, transform=ccrs.PlateCarree()) + ax.coastlines() + @savefig plotting_example_2d_irreg_map_infer.png width=4in + ax.gridlines(draw_labels=True) + +.. note:: + The data model of xarray does not support datasets with `cell boundaries`_ + yet. If you want to use these coordinates, you'll have to make the plots + outside the xarray framework. + +.. _cell boundaries: https://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#cell-boundaries + +One can also make line plots with multidimensional coordinates. In this case, ``hue`` must be a dimension name, not a coordinate name. + +.. ipython:: python + :okwarning: + + f, ax = plt.subplots(2, 1) + da.plot.line(x="lon", hue="y", ax=ax[0]) + @savefig plotting_example_2d_hue_xy.png + da.plot.line(x="lon", hue="x", ax=ax[1]) diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/reshaping.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/reshaping.rst new file mode 100644 index 0000000..14b3435 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/reshaping.rst @@ -0,0 +1,350 @@ +.. _reshape: + +############################### +Reshaping and reorganizing data +############################### + +Reshaping and reorganizing data refers to the process of changing the structure or organization of data by modifying dimensions, array shapes, order of values, or indexes. Xarray provides several methods to accomplish these tasks. + +These methods are particularly useful for reshaping xarray objects for use in machine learning packages, such as scikit-learn, that usually require two-dimensional numpy arrays as inputs. Reshaping can also be required before passing data to external visualization tools, for example geospatial data might expect input organized into a particular format corresponding to stacks of satellite images. + +Importing the library +--------------------- + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +Reordering dimensions +--------------------- + +To reorder dimensions on a :py:class:`~xarray.DataArray` or across all variables +on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`. An +ellipsis (`...`) can be used to represent all other dimensions: + +.. ipython:: python + + ds = xr.Dataset({"foo": (("x", "y", "z"), [[[42]]]), "bar": (("y", "z"), [[24]])}) + ds.transpose("y", "z", "x") + ds.transpose(..., "x") # equivalent + ds.transpose() # reverses all dimensions + +Expand and squeeze dimensions +----------------------------- + +To expand a :py:class:`~xarray.DataArray` or all +variables on a :py:class:`~xarray.Dataset` along a new dimension, +use :py:meth:`~xarray.DataArray.expand_dims` + +.. ipython:: python + + expanded = ds.expand_dims("w") + expanded + +This method attaches a new dimension with size 1 to all data variables. + +To remove such a size-1 dimension from the :py:class:`~xarray.DataArray` +or :py:class:`~xarray.Dataset`, +use :py:meth:`~xarray.DataArray.squeeze` + +.. ipython:: python + + expanded.squeeze("w") + +Converting between datasets and arrays +-------------------------------------- + +To convert from a Dataset to a DataArray, use :py:meth:`~xarray.Dataset.to_dataarray`: + +.. ipython:: python + + arr = ds.to_dataarray() + arr + +This method broadcasts all data variables in the dataset against each other, +then concatenates them along a new dimension into a new array while preserving +coordinates. + +To convert back from a DataArray to a Dataset, use +:py:meth:`~xarray.DataArray.to_dataset`: + +.. ipython:: python + + arr.to_dataset(dim="variable") + +The broadcasting behavior of ``to_dataarray`` means that the resulting array +includes the union of data variable dimensions: + +.. ipython:: python + + ds2 = xr.Dataset({"a": 0, "b": ("x", [3, 4, 5])}) + + # the input dataset has 4 elements + ds2 + + # the resulting array has 6 elements + ds2.to_dataarray() + +Otherwise, the result could not be represented as an orthogonal array. + +If you use ``to_dataset`` without supplying the ``dim`` argument, the DataArray will be converted into a Dataset of one variable: + +.. ipython:: python + + arr.to_dataset(name="combined") + +.. _reshape.stack: + +Stack and unstack +----------------- + +As part of xarray's nascent support for :py:class:`pandas.MultiIndex`, we have +implemented :py:meth:`~xarray.DataArray.stack` and +:py:meth:`~xarray.DataArray.unstack` method, for combining or splitting dimensions: + +.. ipython:: python + + array = xr.DataArray( + np.random.randn(2, 3), coords=[("x", ["a", "b"]), ("y", [0, 1, 2])] + ) + stacked = array.stack(z=("x", "y")) + stacked + stacked.unstack("z") + +As elsewhere in xarray, an ellipsis (`...`) can be used to represent all unlisted dimensions: + +.. ipython:: python + + stacked = array.stack(z=[..., "x"]) + stacked + +These methods are modeled on the :py:class:`pandas.DataFrame` methods of the +same name, although in xarray they always create new dimensions rather than +adding to the existing index or columns. + +Like :py:meth:`DataFrame.unstack`, xarray's ``unstack`` +always succeeds, even if the multi-index being unstacked does not contain all +possible levels. Missing levels are filled in with ``NaN`` in the resulting object: + +.. ipython:: python + + stacked2 = stacked[::2] + stacked2 + stacked2.unstack("z") + +However, xarray's ``stack`` has an important difference from pandas: unlike +pandas, it does not automatically drop missing values. Compare: + +.. ipython:: python + + array = xr.DataArray([[np.nan, 1], [2, 3]], dims=["x", "y"]) + array.stack(z=("x", "y")) + array.to_pandas().stack() + +We departed from pandas's behavior here because predictable shapes for new +array dimensions is necessary for :ref:`dask`. + +.. _reshape.stacking_different: + +Stacking different variables together +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +These stacking and unstacking operations are particularly useful for reshaping +xarray objects for use in machine learning packages, such as `scikit-learn +`_, that usually require two-dimensional numpy +arrays as inputs. For datasets with only one variable, we only need ``stack`` +and ``unstack``, but combining multiple variables in a +:py:class:`xarray.Dataset` is more complicated. If the variables in the dataset +have matching numbers of dimensions, we can call +:py:meth:`~xarray.Dataset.to_dataarray` and then stack along the the new coordinate. +But :py:meth:`~xarray.Dataset.to_dataarray` will broadcast the dataarrays together, +which will effectively tile the lower dimensional variable along the missing +dimensions. The method :py:meth:`xarray.Dataset.to_stacked_array` allows +combining variables of differing dimensions without this wasteful copying while +:py:meth:`xarray.DataArray.to_unstacked_dataset` reverses this operation. +Just as with :py:meth:`xarray.Dataset.stack` the stacked coordinate is +represented by a :py:class:`pandas.MultiIndex` object. These methods are used +like this: + +.. ipython:: python + + data = xr.Dataset( + data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])}, + coords={"y": ["u", "v", "w"]}, + ) + data + stacked = data.to_stacked_array("z", sample_dims=["x"]) + stacked + unstacked = stacked.to_unstacked_dataset("z") + unstacked + +In this example, ``stacked`` is a two dimensional array that we can easily pass to a scikit-learn or another generic +numerical method. + +.. note:: + + Unlike with ``stack``, in ``to_stacked_array``, the user specifies the dimensions they **do not** want stacked. + For a machine learning task, these unstacked dimensions can be interpreted as the dimensions over which samples are + drawn, whereas the stacked coordinates are the features. Naturally, all variables should possess these sampling + dimensions. + + +.. _reshape.set_index: + +Set and reset index +------------------- + +Complementary to stack / unstack, xarray's ``.set_index``, ``.reset_index`` and +``.reorder_levels`` allow easy manipulation of ``DataArray`` or ``Dataset`` +multi-indexes without modifying the data and its dimensions. + +You can create a multi-index from several 1-dimensional variables and/or +coordinates using :py:meth:`~xarray.DataArray.set_index`: + +.. ipython:: python + + da = xr.DataArray( + np.random.rand(4), + coords={ + "band": ("x", ["a", "a", "b", "b"]), + "wavenumber": ("x", np.linspace(200, 400, 4)), + }, + dims="x", + ) + da + mda = da.set_index(x=["band", "wavenumber"]) + mda + +These coordinates can now be used for indexing, e.g., + +.. ipython:: python + + mda.sel(band="a") + +Conversely, you can use :py:meth:`~xarray.DataArray.reset_index` +to extract multi-index levels as coordinates (this is mainly useful +for serialization): + +.. ipython:: python + + mda.reset_index("x") + +:py:meth:`~xarray.DataArray.reorder_levels` allows changing the order +of multi-index levels: + +.. ipython:: python + + mda.reorder_levels(x=["wavenumber", "band"]) + +As of xarray v0.9 coordinate labels for each dimension are optional. +You can also use ``.set_index`` / ``.reset_index`` to add / remove +labels for one or several dimensions: + +.. ipython:: python + + array = xr.DataArray([1, 2, 3], dims="x") + array + array["c"] = ("x", ["a", "b", "c"]) + array.set_index(x="c") + array = array.set_index(x="c") + array = array.reset_index("x", drop=True) + +.. _reshape.shift_and_roll: + +Shift and roll +-------------- + +To adjust coordinate labels, you can use the :py:meth:`~xarray.Dataset.shift` and +:py:meth:`~xarray.Dataset.roll` methods: + +.. ipython:: python + + array = xr.DataArray([1, 2, 3, 4], dims="x") + array.shift(x=2) + array.roll(x=2, roll_coords=True) + +.. _reshape.sort: + +Sort +---- + +One may sort a DataArray/Dataset via :py:meth:`~xarray.DataArray.sortby` and +:py:meth:`~xarray.Dataset.sortby`. The input can be an individual or list of +1D ``DataArray`` objects: + +.. ipython:: python + + ds = xr.Dataset( + { + "A": (("x", "y"), [[1, 2], [3, 4]]), + "B": (("x", "y"), [[5, 6], [7, 8]]), + }, + coords={"x": ["b", "a"], "y": [1, 0]}, + ) + dax = xr.DataArray([100, 99], [("x", [0, 1])]) + day = xr.DataArray([90, 80], [("y", [0, 1])]) + ds.sortby([day, dax]) + +As a shortcut, you can refer to existing coordinates by name: + +.. ipython:: python + + ds.sortby("x") + ds.sortby(["y", "x"]) + ds.sortby(["y", "x"], ascending=False) + +.. _reshape.coarsen: + +Reshaping via coarsen +--------------------- + +Whilst :py:class:`~xarray.DataArray.coarsen` is normally used for reducing your data's resolution by applying a reduction function +(see the :ref:`page on computation`), +it can also be used to reorganise your data without applying a computation via :py:meth:`~xarray.core.rolling.DataArrayCoarsen.construct`. + +Taking our example tutorial air temperature dataset over the Northern US + +.. ipython:: python + :suppress: + + # Use defaults so we don't get gridlines in generated docs + import matplotlib as mpl + + mpl.rcdefaults() + +.. ipython:: python + + air = xr.tutorial.open_dataset("air_temperature")["air"] + + @savefig pre_coarsening.png + air.isel(time=0).plot(x="lon", y="lat") + +we can split this up into sub-regions of size ``(9, 18)`` points using :py:meth:`~xarray.core.rolling.DataArrayCoarsen.construct`: + +.. ipython:: python + + regions = air.coarsen(lat=9, lon=18, boundary="pad").construct( + lon=("x_coarse", "x_fine"), lat=("y_coarse", "y_fine") + ) + regions + +9 new regions have been created, each of size 9 by 18 points. +The ``boundary="pad"`` kwarg ensured that all regions are the same size even though the data does not evenly divide into these sizes. + +By plotting these 9 regions together via :ref:`faceting` we can see how they relate to the original data. + +.. ipython:: python + + @savefig post_coarsening.png + regions.isel(time=0).plot( + x="x_fine", y="y_fine", col="x_coarse", row="y_coarse", yincrease=False + ) + +We are now free to easily apply any custom computation to each coarsened region of our new dataarray. +This would involve specifying that applied functions should act over the ``"x_fine"`` and ``"y_fine"`` dimensions, +but broadcast over the ``"x_coarse"`` and ``"y_coarse"`` dimensions. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/terminology.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/terminology.rst new file mode 100644 index 0000000..5593731 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/terminology.rst @@ -0,0 +1,257 @@ +.. currentmodule:: xarray +.. _terminology: + +Terminology +=========== + +*Xarray terminology differs slightly from CF, mathematical conventions, and +pandas; so we've put together a glossary of its terms. Here,* ``arr`` +*refers to an xarray* :py:class:`DataArray` *in the examples. For more +complete examples, please consult the relevant documentation.* + +.. glossary:: + + DataArray + A multi-dimensional array with labeled or named + dimensions. ``DataArray`` objects add metadata such as dimension names, + coordinates, and attributes (defined below) to underlying "unlabeled" + data structures such as numpy and Dask arrays. If its optional ``name`` + property is set, it is a *named DataArray*. + + Dataset + A dict-like collection of ``DataArray`` objects with aligned + dimensions. Thus, most operations that can be performed on the + dimensions of a single ``DataArray`` can be performed on a + dataset. Datasets have data variables (see **Variable** below), + dimensions, coordinates, and attributes. + + Variable + A `NetCDF-like variable + `_ + consisting of dimensions, data, and attributes which describe a single + array. The main functional difference between variables and numpy arrays + is that numerical operations on variables implement array broadcasting + by dimension name. Each ``DataArray`` has an underlying variable that + can be accessed via ``arr.variable``. However, a variable is not fully + described outside of either a ``Dataset`` or a ``DataArray``. + + .. note:: + + The :py:class:`Variable` class is low-level interface and can + typically be ignored. However, the word "variable" appears often + enough in the code and documentation that is useful to understand. + + Dimension + In mathematics, the *dimension* of data is loosely the number of degrees + of freedom for it. A *dimension axis* is a set of all points in which + all but one of these degrees of freedom is fixed. We can think of each + dimension axis as having a name, for example the "x dimension". In + xarray, a ``DataArray`` object's *dimensions* are its named dimension + axes ``da.dims``, and the name of the ``i``-th dimension is ``da.dims[i]``. + If an array is created without specifying dimension names, the default dimension + names will be ``dim_0``, ``dim_1``, and so forth. + + Coordinate + An array that labels a dimension or set of dimensions of another + ``DataArray``. In the usual one-dimensional case, the coordinate array's + values can loosely be thought of as tick labels along a dimension. We + distinguish :term:`Dimension coordinate` vs. :term:`Non-dimension + coordinate` and :term:`Indexed coordinate` vs. :term:`Non-indexed + coordinate`. A coordinate named ``x`` can be retrieved from + ``arr.coords[x]``. A ``DataArray`` can have more coordinates than + dimensions because a single dimension can be labeled by multiple + coordinate arrays. However, only one coordinate array can be a assigned + as a particular dimension's dimension coordinate array. + + Dimension coordinate + A one-dimensional coordinate array assigned to ``arr`` with both a name + and dimension name in ``arr.dims``. Usually (but not always), a + dimension coordinate is also an :term:`Indexed coordinate` so that it can + be used for label-based indexing and alignment, like the index found on + a :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. + + Non-dimension coordinate + A coordinate array assigned to ``arr`` with a name in ``arr.coords`` but + *not* in ``arr.dims``. These coordinates arrays can be one-dimensional + or multidimensional, and they are useful for auxiliary labeling. As an + example, multidimensional coordinates are often used in geoscience + datasets when :doc:`the data's physical coordinates (such as latitude + and longitude) differ from their logical coordinates + <../examples/multidimensional-coords>`. Printing ``arr.coords`` will + print all of ``arr``'s coordinate names, with the corresponding + dimension(s) in parentheses. For example, ``coord_name (dim_name) 1 2 3 + ...``. + + Indexed coordinate + A coordinate which has an associated :term:`Index`. Generally this means + that the coordinate labels can be used for indexing (selection) and/or + alignment. An indexed coordinate may have one or more arbitrary + dimensions although in most cases it is also a :term:`Dimension + coordinate`. It may or may not be grouped with other indexed coordinates + depending on whether they share the same index. Indexed coordinates are + marked by an asterisk ``*`` when printing a ``DataArray`` or ``Dataset``. + + Non-indexed coordinate + A coordinate which has no associated :term:`Index`. It may still + represent fixed labels along one or more dimensions but it cannot be + used for label-based indexing and alignment. + + Index + An *index* is a data structure optimized for efficient data selection + and alignment within a discrete or continuous space that is defined by + coordinate labels (unless it is a functional index). By default, Xarray + creates a :py:class:`~xarray.indexes.PandasIndex` object (i.e., a + :py:class:`pandas.Index` wrapper) for each :term:`Dimension coordinate`. + For more advanced use cases (e.g., staggered or irregular grids, + geospatial indexes), Xarray also accepts any instance of a specialized + :py:class:`~xarray.indexes.Index` subclass that is associated to one or + more arbitrary coordinates. The index associated with the coordinate + ``x`` can be retrieved by ``arr.xindexes[x]`` (or ``arr.indexes["x"]`` + if the index is convertible to a :py:class:`pandas.Index` object). If + two coordinates ``x`` and ``y`` share the same index, + ``arr.xindexes[x]`` and ``arr.xindexes[y]`` both return the same + :py:class:`~xarray.indexes.Index` object. + + name + The names of dimensions, coordinates, DataArray objects and data + variables can be anything as long as they are :term:`hashable`. However, + it is preferred to use :py:class:`str` typed names. + + scalar + By definition, a scalar is not an :term:`array` and when converted to + one, it has 0 dimensions. That means that, e.g., :py:class:`int`, + :py:class:`float`, and :py:class:`str` objects are "scalar" while + :py:class:`list` or :py:class:`tuple` are not. + + duck array + `Duck arrays`__ are array implementations that behave + like numpy arrays. They have to define the ``shape``, ``dtype`` and + ``ndim`` properties. For integration with ``xarray``, the ``__array__``, + ``__array_ufunc__`` and ``__array_function__`` protocols are also required. + + __ https://numpy.org/neps/nep-0022-ndarray-duck-typing-overview.html + + .. ipython:: python + :suppress: + + import numpy as np + import xarray as xr + + Aligning + Aligning refers to the process of ensuring that two or more DataArrays or Datasets + have the same dimensions and coordinates, so that they can be combined or compared properly. + + .. ipython:: python + + x = xr.DataArray( + [[25, 35], [10, 24]], + dims=("lat", "lon"), + coords={"lat": [35.0, 40.0], "lon": [100.0, 120.0]}, + ) + y = xr.DataArray( + [[20, 5], [7, 13]], + dims=("lat", "lon"), + coords={"lat": [35.0, 42.0], "lon": [100.0, 120.0]}, + ) + x + y + + Broadcasting + A technique that allows operations to be performed on arrays with different shapes and dimensions. + When performing operations on arrays with different shapes and dimensions, xarray will automatically attempt to broadcast the + arrays to a common shape before the operation is applied. + + .. ipython:: python + + # 'a' has shape (3,) and 'b' has shape (4,) + a = xr.DataArray(np.array([1, 2, 3]), dims=["x"]) + b = xr.DataArray(np.array([4, 5, 6, 7]), dims=["y"]) + + # 2D array with shape (3, 4) + a + b + + Merging + Merging is used to combine two or more Datasets or DataArrays that have different variables or coordinates along + the same dimensions. When merging, xarray aligns the variables and coordinates of the different datasets along + the specified dimensions and creates a new ``Dataset`` containing all the variables and coordinates. + + .. ipython:: python + + # create two 1D arrays with names + arr1 = xr.DataArray( + [1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]}, name="arr1" + ) + arr2 = xr.DataArray( + [4, 5, 6], dims=["x"], coords={"x": [20, 30, 40]}, name="arr2" + ) + + # merge the two arrays into a new dataset + merged_ds = xr.Dataset({"arr1": arr1, "arr2": arr2}) + merged_ds + + Concatenating + Concatenating is used to combine two or more Datasets or DataArrays along a dimension. When concatenating, + xarray arranges the datasets or dataarrays along a new dimension, and the resulting ``Dataset`` or ``Dataarray`` + will have the same variables and coordinates along the other dimensions. + + .. ipython:: python + + a = xr.DataArray([[1, 2], [3, 4]], dims=("x", "y")) + b = xr.DataArray([[5, 6], [7, 8]], dims=("x", "y")) + c = xr.concat([a, b], dim="c") + c + + Combining + Combining is the process of arranging two or more DataArrays or Datasets into a single ``DataArray`` or + ``Dataset`` using some combination of merging and concatenation operations. + + .. ipython:: python + + ds1 = xr.Dataset( + {"data": xr.DataArray([[1, 2], [3, 4]], dims=("x", "y"))}, + coords={"x": [1, 2], "y": [3, 4]}, + ) + ds2 = xr.Dataset( + {"data": xr.DataArray([[5, 6], [7, 8]], dims=("x", "y"))}, + coords={"x": [2, 3], "y": [4, 5]}, + ) + + # combine the datasets + combined_ds = xr.combine_by_coords([ds1, ds2]) + combined_ds + + lazy + Lazily-evaluated operations do not load data into memory until necessary.Instead of doing calculations + right away, xarray lets you plan what calculations you want to do, like finding the + average temperature in a dataset.This planning is called "lazy evaluation." Later, when + you're ready to see the final result, you tell xarray, "Okay, go ahead and do those calculations now!" + That's when xarray starts working through the steps you planned and gives you the answer you wanted.This + lazy approach helps save time and memory because xarray only does the work when you actually need the + results. + + labeled + Labeled data has metadata describing the context of the data, not just the raw data values. + This contextual information can be labels for array axes (i.e. dimension names) tick labels along axes (stored as Coordinate variables) or unique names for each array. These labels + provide context and meaning to the data, making it easier to understand and work with. If you have + temperature data for different cities over time. Using xarray, you can label the dimensions: one for + cities and another for time. + + serialization + Serialization is the process of converting your data into a format that makes it easy to save and share. + When you serialize data in xarray, you're taking all those temperature measurements, along with their + labels and other information, and turning them into a format that can be stored in a file or sent over + the internet. xarray objects can be serialized into formats which store the labels alongside the data. + Some supported serialization formats are files that can then be stored or transferred (e.g. netCDF), + whilst others are protocols that allow for data access over a network (e.g. Zarr). + + indexing + :ref:`Indexing` is how you select subsets of your data which you are interested in. + + - Label-based Indexing: Selecting data by passing a specific label and comparing it to the labels + stored in the associated coordinates. You can use labels to specify what you want like "Give me the + temperature for New York on July 15th." + + - Positional Indexing: You can use numbers to refer to positions in the data like "Give me the third temperature value" This is useful when you know the order of your data but don't need to remember the exact labels. + + - Slicing: You can take a "slice" of your data, like you might want all temperatures from July 1st + to July 10th. xarray supports slicing for both positional and label-based indexing. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/testing.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/testing.rst new file mode 100644 index 0000000..13279ec --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/testing.rst @@ -0,0 +1,303 @@ +.. _testing: + +Testing your code +================= + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +.. _testing.hypothesis: + +Hypothesis testing +------------------ + +.. note:: + + Testing with hypothesis is a fairly advanced topic. Before reading this section it is recommended that you take a look + at our guide to xarray's :ref:`data structures`, are familiar with conventional unit testing in + `pytest `_, and have seen the + `hypothesis library documentation `_. + +`The hypothesis library `_ is a powerful tool for property-based testing. +Instead of writing tests for one example at a time, it allows you to write tests parameterized by a source of many +dynamically generated examples. For example you might have written a test which you wish to be parameterized by the set +of all possible integers via :py:func:`hypothesis.strategies.integers()`. + +Property-based testing is extremely powerful, because (unlike more conventional example-based testing) it can find bugs +that you did not even think to look for! + +Strategies +~~~~~~~~~~ + +Each source of examples is called a "strategy", and xarray provides a range of custom strategies which produce xarray +data structures containing arbitrary data. You can use these to efficiently test downstream code, +quickly ensuring that your code can handle xarray objects of all possible structures and contents. + +These strategies are accessible in the :py:mod:`xarray.testing.strategies` module, which provides + +.. currentmodule:: xarray + +.. autosummary:: + + testing.strategies.supported_dtypes + testing.strategies.names + testing.strategies.dimension_names + testing.strategies.dimension_sizes + testing.strategies.attrs + testing.strategies.variables + testing.strategies.unique_subset_of + +These build upon the numpy and array API strategies offered in :py:mod:`hypothesis.extra.numpy` and :py:mod:`hypothesis.extra.array_api`: + +.. ipython:: python + + import hypothesis.extra.numpy as npst + +Generating Examples +~~~~~~~~~~~~~~~~~~~ + +To see an example of what each of these strategies might produce, you can call one followed by the ``.example()`` method, +which is a general hypothesis method valid for all strategies. + +.. ipython:: python + + import xarray.testing.strategies as xrst + + xrst.variables().example() + xrst.variables().example() + xrst.variables().example() + +You can see that calling ``.example()`` multiple times will generate different examples, giving you an idea of the wide +range of data that the xarray strategies can generate. + +In your tests however you should not use ``.example()`` - instead you should parameterize your tests with the +:py:func:`hypothesis.given` decorator: + +.. ipython:: python + + from hypothesis import given + +.. ipython:: python + + @given(xrst.variables()) + def test_function_that_acts_on_variables(var): + assert func(var) == ... + + +Chaining Strategies +~~~~~~~~~~~~~~~~~~~ + +Xarray's strategies can accept other strategies as arguments, allowing you to customise the contents of the generated +examples. + +.. ipython:: python + + # generate a Variable containing an array with a complex number dtype, but all other details still arbitrary + from hypothesis.extra.numpy import complex_number_dtypes + + xrst.variables(dtype=complex_number_dtypes()).example() + +This also works with custom strategies, or strategies defined in other packages. +For example you could imagine creating a ``chunks`` strategy to specify particular chunking patterns for a dask-backed array. + +Fixing Arguments +~~~~~~~~~~~~~~~~ + +If you want to fix one aspect of the data structure, whilst allowing variation in the generated examples +over all other aspects, then use :py:func:`hypothesis.strategies.just()`. + +.. ipython:: python + + import hypothesis.strategies as st + + # Generates only variable objects with dimensions ["x", "y"] + xrst.variables(dims=st.just(["x", "y"])).example() + +(This is technically another example of chaining strategies - :py:func:`hypothesis.strategies.just()` is simply a +special strategy that just contains a single example.) + +To fix the length of dimensions you can instead pass ``dims`` as a mapping of dimension names to lengths +(i.e. following xarray objects' ``.sizes()`` property), e.g. + +.. ipython:: python + + # Generates only variables with dimensions ["x", "y"], of lengths 2 & 3 respectively + xrst.variables(dims=st.just({"x": 2, "y": 3})).example() + +You can also use this to specify that you want examples which are missing some part of the data structure, for instance + +.. ipython:: python + + # Generates a Variable with no attributes + xrst.variables(attrs=st.just({})).example() + +Through a combination of chaining strategies and fixing arguments, you can specify quite complicated requirements on the +objects your chained strategy will generate. + +.. ipython:: python + + fixed_x_variable_y_maybe_z = st.fixed_dictionaries( + {"x": st.just(2), "y": st.integers(3, 4)}, optional={"z": st.just(2)} + ) + fixed_x_variable_y_maybe_z.example() + + special_variables = xrst.variables(dims=fixed_x_variable_y_maybe_z) + + special_variables.example() + special_variables.example() + +Here we have used one of hypothesis' built-in strategies :py:func:`hypothesis.strategies.fixed_dictionaries` to create a +strategy which generates mappings of dimension names to lengths (i.e. the ``size`` of the xarray object we want). +This particular strategy will always generate an ``x`` dimension of length 2, and a ``y`` dimension of +length either 3 or 4, and will sometimes also generate a ``z`` dimension of length 2. +By feeding this strategy for dictionaries into the ``dims`` argument of xarray's :py:func:`~st.variables` strategy, +we can generate arbitrary :py:class:`~xarray.Variable` objects whose dimensions will always match these specifications. + +Generating Duck-type Arrays +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray objects don't have to wrap numpy arrays, in fact they can wrap any array type which presents the same API as a +numpy array (so-called "duck array wrapping", see :ref:`wrapping numpy-like arrays `). + +Imagine we want to write a strategy which generates arbitrary ``Variable`` objects, each of which wraps a +:py:class:`sparse.COO` array instead of a ``numpy.ndarray``. How could we do that? There are two ways: + +1. Create a xarray object with numpy data and use the hypothesis' ``.map()`` method to convert the underlying array to a +different type: + +.. ipython:: python + + import sparse + +.. ipython:: python + + def convert_to_sparse(var): + return var.copy(data=sparse.COO.from_numpy(var.to_numpy())) + +.. ipython:: python + + sparse_variables = xrst.variables(dims=xrst.dimension_names(min_dims=1)).map( + convert_to_sparse + ) + + sparse_variables.example() + sparse_variables.example() + +2. Pass a function which returns a strategy which generates the duck-typed arrays directly to the ``array_strategy_fn`` argument of the xarray strategies: + +.. ipython:: python + + def sparse_random_arrays(shape: tuple[int]) -> sparse._coo.core.COO: + """Strategy which generates random sparse.COO arrays""" + if shape is None: + shape = npst.array_shapes() + else: + shape = st.just(shape) + density = st.integers(min_value=0, max_value=1) + # note sparse.random does not accept a dtype kwarg + return st.builds(sparse.random, shape=shape, density=density) + + + def sparse_random_arrays_fn( + *, shape: tuple[int, ...], dtype: np.dtype + ) -> st.SearchStrategy[sparse._coo.core.COO]: + return sparse_random_arrays(shape=shape) + + +.. ipython:: python + + sparse_random_variables = xrst.variables( + array_strategy_fn=sparse_random_arrays_fn, dtype=st.just(np.dtype("float64")) + ) + sparse_random_variables.example() + +Either approach is fine, but one may be more convenient than the other depending on the type of the duck array which you +want to wrap. + +Compatibility with the Python Array API Standard +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray aims to be compatible with any duck-array type that conforms to the `Python Array API Standard `_ +(see our :ref:`docs on Array API Standard support `). + +.. warning:: + + The strategies defined in :py:mod:`testing.strategies` are **not** guaranteed to use array API standard-compliant + dtypes by default. + For example arrays with the dtype ``np.dtype('float16')`` may be generated by :py:func:`testing.strategies.variables` + (assuming the ``dtype`` kwarg was not explicitly passed), despite ``np.dtype('float16')`` not being in the + array API standard. + +If the array type you want to generate has an array API-compliant top-level namespace +(e.g. that which is conventionally imported as ``xp`` or similar), +you can use this neat trick: + +.. ipython:: python + :okwarning: + + from numpy import array_api as xp # available in numpy 1.26.0 + + from hypothesis.extra.array_api import make_strategies_namespace + + xps = make_strategies_namespace(xp) + + xp_variables = xrst.variables( + array_strategy_fn=xps.arrays, + dtype=xps.scalar_dtypes(), + ) + xp_variables.example() + +Another array API-compliant duck array library would replace the import, e.g. ``import cupy as cp`` instead. + +Testing over Subsets of Dimensions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A common task when testing xarray user code is checking that your function works for all valid input dimensions. +We can chain strategies to achieve this, for which the helper strategy :py:func:`~testing.strategies.unique_subset_of` +is useful. + +It works for lists of dimension names + +.. ipython:: python + + dims = ["x", "y", "z"] + xrst.unique_subset_of(dims).example() + xrst.unique_subset_of(dims).example() + +as well as for mappings of dimension names to sizes + +.. ipython:: python + + dim_sizes = {"x": 2, "y": 3, "z": 4} + xrst.unique_subset_of(dim_sizes).example() + xrst.unique_subset_of(dim_sizes).example() + +This is useful because operations like reductions can be performed over any subset of the xarray object's dimensions. +For example we can write a pytest test that tests that a reduction gives the expected result when applying that reduction +along any possible valid subset of the Variable's dimensions. + +.. code-block:: python + + import numpy.testing as npt + + + @given(st.data(), xrst.variables(dims=xrst.dimension_names(min_dims=1))) + def test_mean(data, var): + """Test that the mean of an xarray Variable is always equal to the mean of the underlying array.""" + + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(xrst.unique_subset_of(var.dims, min_size=1)) + + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) + + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/time-series.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/time-series.rst new file mode 100644 index 0000000..82172aa --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/time-series.rst @@ -0,0 +1,262 @@ +.. _time-series: + +================ +Time series data +================ + +A major use case for xarray is multi-dimensional time-series data. +Accordingly, we've copied many of features that make working with time-series +data in pandas such a joy to xarray. In most cases, we rely on pandas for the +core functionality. + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +Creating datetime64 data +------------------------ + +Xarray uses the numpy dtypes ``datetime64[ns]`` and ``timedelta64[ns]`` to +represent datetime data, which offer vectorized (if sometimes buggy) operations +with numpy and smooth integration with pandas. + +To convert to or create regular arrays of ``datetime64`` data, we recommend +using :py:func:`pandas.to_datetime` and :py:func:`pandas.date_range`: + +.. ipython:: python + + pd.to_datetime(["2000-01-01", "2000-02-02"]) + pd.date_range("2000-01-01", periods=365) + +Alternatively, you can supply arrays of Python ``datetime`` objects. These get +converted automatically when used as arguments in xarray objects: + +.. ipython:: python + + import datetime + + xr.Dataset({"time": datetime.datetime(2000, 1, 1)}) + +When reading or writing netCDF files, xarray automatically decodes datetime and +timedelta arrays using `CF conventions`_ (that is, by using a ``units`` +attribute like ``'days since 2000-01-01'``). + +.. _CF conventions: https://cfconventions.org + +.. note:: + + When decoding/encoding datetimes for non-standard calendars or for dates + before year 1678 or after year 2262, xarray uses the `cftime`_ library. + It was previously packaged with the ``netcdf4-python`` package under the + name ``netcdftime`` but is now distributed separately. ``cftime`` is an + :ref:`optional dependency` of xarray. + +.. _cftime: https://unidata.github.io/cftime + + +You can manual decode arrays in this form by passing a dataset to +:py:func:`~xarray.decode_cf`: + +.. ipython:: python + + attrs = {"units": "hours since 2000-01-01"} + ds = xr.Dataset({"time": ("time", [0, 1, 2, 3], attrs)}) + xr.decode_cf(ds) + +One unfortunate limitation of using ``datetime64[ns]`` is that it limits the +native representation of dates to those that fall between the years 1678 and +2262. When a netCDF file contains dates outside of these bounds, dates will be +returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` +will be used for indexing. :py:class:`~xarray.CFTimeIndex` enables a subset of +the indexing functionality of a :py:class:`pandas.DatetimeIndex` and is only +fully compatible with the standalone version of ``cftime`` (not the version +packaged with earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more +information. + +Datetime indexing +----------------- + +Xarray borrows powerful indexing machinery from pandas (see :ref:`indexing`). + +This allows for several useful and succinct forms of indexing, particularly for +`datetime64` data. For example, we support indexing with strings for single +items and with the `slice` object: + +.. ipython:: python + + time = pd.date_range("2000-01-01", freq="h", periods=365 * 24) + ds = xr.Dataset({"foo": ("time", np.arange(365 * 24)), "time": time}) + ds.sel(time="2000-01") + ds.sel(time=slice("2000-06-01", "2000-06-10")) + +You can also select a particular time by indexing with a +:py:class:`datetime.time` object: + +.. ipython:: python + + ds.sel(time=datetime.time(12)) + +For more details, read the pandas documentation and the section on :ref:`datetime_component_indexing` (i.e. using the ``.dt`` accessor). + +.. _dt_accessor: + +Datetime components +------------------- + +Similar to `pandas accessors`_, the components of datetime objects contained in a +given ``DataArray`` can be quickly computed using a special ``.dt`` accessor. + +.. _pandas accessors: https://pandas.pydata.org/pandas-docs/stable/basics.html#basics-dt-accessors + +.. ipython:: python + + time = pd.date_range("2000-01-01", freq="6h", periods=365 * 4) + ds = xr.Dataset({"foo": ("time", np.arange(365 * 4)), "time": time}) + ds.time.dt.hour + ds.time.dt.dayofweek + +The ``.dt`` accessor works on both coordinate dimensions as well as +multi-dimensional data. + +Xarray also supports a notion of "virtual" or "derived" coordinates for +`datetime components`__ implemented by pandas, including "year", "month", +"day", "hour", "minute", "second", "dayofyear", "week", "dayofweek", "weekday" +and "quarter": + +__ https://pandas.pydata.org/pandas-docs/stable/api.html#time-date-components + +.. ipython:: python + + ds["time.month"] + ds["time.dayofyear"] + +For use as a derived coordinate, xarray adds ``'season'`` to the list of +datetime components supported by pandas: + +.. ipython:: python + + ds["time.season"] + ds["time"].dt.season + +The set of valid seasons consists of 'DJF', 'MAM', 'JJA' and 'SON', labeled by +the first letters of the corresponding months. + +You can use these shortcuts with both Datasets and DataArray coordinates. + +In addition, xarray supports rounding operations ``floor``, ``ceil``, and ``round``. These operations require that you supply a `rounding frequency as a string argument.`__ + +__ https://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases + +.. ipython:: python + + ds["time"].dt.floor("D") + +The ``.dt`` accessor can also be used to generate formatted datetime strings +for arrays utilising the same formatting as the standard `datetime.strftime`_. + +.. _datetime.strftime: https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior + +.. ipython:: python + + ds["time"].dt.strftime("%a, %b %d %H:%M") + +.. _datetime_component_indexing: + +Indexing Using Datetime Components +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You can use use the ``.dt`` accessor when subsetting your data as well. For example, we can subset for the month of January using the following: + +.. ipython:: python + + ds.isel(time=(ds.time.dt.month == 1)) + +You can also search for multiple months (in this case January through March), using ``isin``: + +.. ipython:: python + + ds.isel(time=ds.time.dt.month.isin([1, 2, 3])) + +.. _resampling: + +Resampling and grouped operations +--------------------------------- + +Datetime components couple particularly well with grouped operations (see +:ref:`groupby`) for analyzing features that repeat over time. Here's how to +calculate the mean by time of day: + +.. ipython:: python + :okwarning: + + ds.groupby("time.hour").mean() + +For upsampling or downsampling temporal resolutions, xarray offers a +:py:meth:`~xarray.Dataset.resample` method building on the core functionality +offered by the pandas method of the same name. Resample uses essentially the +same api as ``resample`` `in pandas`_. + +.. _in pandas: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#up-and-downsampling + +For example, we can downsample our dataset from hourly to 6-hourly: + +.. ipython:: python + :okwarning: + + ds.resample(time="6h") + +This will create a specialized ``Resample`` object which saves information +necessary for resampling. All of the reduction methods which work with +``Resample`` objects can also be used for resampling: + +.. ipython:: python + :okwarning: + + ds.resample(time="6h").mean() + +You can also supply an arbitrary reduction function to aggregate over each +resampling group: + +.. ipython:: python + + ds.resample(time="6h").reduce(np.mean) + +You can also resample on the time dimension while applying reducing along other dimensions at the same time +by specifying the `dim` keyword argument + +.. code-block:: python + + ds.resample(time="6h").mean(dim=["time", "latitude", "longitude"]) + +For upsampling, xarray provides six methods: ``asfreq``, ``ffill``, ``bfill``, ``pad``, +``nearest`` and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` +and supports all of its schemes. All of these resampling operations work on both +Dataset and DataArray objects with an arbitrary number of dimensions. + +In order to limit the scope of the methods ``ffill``, ``bfill``, ``pad`` and +``nearest`` the ``tolerance`` argument can be set in coordinate units. +Data that has indices outside of the given ``tolerance`` are set to ``NaN``. + +.. ipython:: python + + ds.resample(time="1h").nearest(tolerance="1h") + +It is often desirable to center the time values after a resampling operation. +That can be accomplished by updating the resampled dataset time coordinate values +using time offset arithmetic via the `pandas.tseries.frequencies.to_offset`_ function. + +.. _pandas.tseries.frequencies.to_offset: https://pandas.pydata.org/docs/reference/api/pandas.tseries.frequencies.to_offset.html + +.. ipython:: python + + resampled_ds = ds.resample(time="6h").mean() + offset = pd.tseries.frequencies.to_offset("6h") / 2 + resampled_ds["time"] = resampled_ds.get_index("time") + offset + resampled_ds + +For more examples of using grouped operations on a time dimension, see +:doc:`../examples/weather-data`. diff --git a/test/fixtures/whole_applications/xarray/doc/user-guide/weather-climate.rst b/test/fixtures/whole_applications/xarray/doc/user-guide/weather-climate.rst new file mode 100644 index 0000000..5014f5a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/user-guide/weather-climate.rst @@ -0,0 +1,246 @@ +.. currentmodule:: xarray + +.. _weather-climate: + +Weather and climate data +======================== + +.. ipython:: python + :suppress: + + import xarray as xr + +Xarray can leverage metadata that follows the `Climate and Forecast (CF) conventions`_ if present. Examples include :ref:`automatic labelling of plots` with descriptive names and units if proper metadata is present and support for non-standard calendars used in climate science through the ``cftime`` module(Explained in the :ref:`CFTimeIndex` section). There are also a number of :ref:`geosciences-focused projects that build on xarray`. + +.. _Climate and Forecast (CF) conventions: https://cfconventions.org + +.. _cf_variables: + +Related Variables +----------------- + +Several CF variable attributes contain lists of other variables +associated with the variable with the attribute. A few of these are +now parsed by xarray, with the attribute value popped to encoding on +read and the variables in that value interpreted as non-dimension +coordinates: + +- ``coordinates`` +- ``bounds`` +- ``grid_mapping`` +- ``climatology`` +- ``geometry`` +- ``node_coordinates`` +- ``node_count`` +- ``part_node_count`` +- ``interior_ring`` +- ``cell_measures`` +- ``formula_terms`` + +This decoding is controlled by the ``decode_coords`` kwarg to +:py:func:`open_dataset` and :py:func:`open_mfdataset`. + +The CF attribute ``ancillary_variables`` was not included in the list +due to the variables listed there being associated primarily with the +variable with the attribute, rather than with the dimensions. + +.. _metpy_accessor: + +CF-compliant coordinate variables +--------------------------------- + +`MetPy`_ adds a ``metpy`` accessor that allows accessing coordinates with appropriate CF metadata using generic names ``x``, ``y``, ``vertical`` and ``time``. There is also a `cartopy_crs` attribute that provides projection information, parsed from the appropriate CF metadata, as a `Cartopy`_ projection object. See the `metpy documentation`_ for more information. + +.. _`MetPy`: https://unidata.github.io/MetPy/dev/index.html +.. _`metpy documentation`: https://unidata.github.io/MetPy/dev/tutorials/xarray_tutorial.html#coordinates +.. _`Cartopy`: https://scitools.org.uk/cartopy/docs/latest/crs/projections.html + +.. _CFTimeIndex: + +Non-standard calendars and dates outside the nanosecond-precision range +----------------------------------------------------------------------- + +Through the standalone ``cftime`` library and a custom subclass of +:py:class:`pandas.Index`, xarray supports a subset of the indexing +functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for +dates from non-standard calendars commonly used in climate science or dates +using a standard calendar, but outside the `nanosecond-precision range`_ +(approximately between years 1678 and 2262). + +.. note:: + + As of xarray version 0.11, by default, :py:class:`cftime.datetime` objects + will be used to represent times (either in indexes, as a + :py:class:`~xarray.CFTimeIndex`, or in data arrays with dtype object) if + any of the following are true: + + - The dates are from a non-standard calendar + - Any dates are outside the nanosecond-precision range. + + Otherwise pandas-compatible dates from a standard calendar will be + represented with the ``np.datetime64[ns]`` data type, enabling the use of a + :py:class:`pandas.DatetimeIndex` or arrays with dtype ``np.datetime64[ns]`` + and their full set of associated features. + + As of pandas version 2.0.0, pandas supports non-nanosecond precision datetime + values. For the time being, xarray still automatically casts datetime values + to nanosecond-precision for backwards compatibility with older pandas + versions; however, this is something we would like to relax going forward. + See :issue:`7493` for more discussion. + +For example, you can create a DataArray indexed by a time +coordinate with dates from a no-leap calendar and a +:py:class:`~xarray.CFTimeIndex` will automatically be used: + +.. ipython:: python + + from itertools import product + from cftime import DatetimeNoLeap + + dates = [ + DatetimeNoLeap(year, month, 1) + for year, month in product(range(1, 3), range(1, 13)) + ] + da = xr.DataArray(np.arange(24), coords=[dates], dims=["time"], name="foo") + +Xarray also includes a :py:func:`~xarray.cftime_range` function, which enables +creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For +instance, we can create the same dates and DataArray we created above using: + +.. ipython:: python + + dates = xr.cftime_range(start="0001", periods=24, freq="MS", calendar="noleap") + da = xr.DataArray(np.arange(24), coords=[dates], dims=["time"], name="foo") + +Mirroring pandas' method with the same name, :py:meth:`~xarray.infer_freq` allows one to +infer the sampling frequency of a :py:class:`~xarray.CFTimeIndex` or a 1-D +:py:class:`~xarray.DataArray` containing cftime objects. It also works transparently with +``np.datetime64[ns]`` and ``np.timedelta64[ns]`` data. + +.. ipython:: python + + xr.infer_freq(dates) + +With :py:meth:`~xarray.CFTimeIndex.strftime` we can also easily generate formatted strings from +the datetime values of a :py:class:`~xarray.CFTimeIndex` directly or through the +``dt`` accessor for a :py:class:`~xarray.DataArray` +using the same formatting as the standard `datetime.strftime`_ convention . + +.. _datetime.strftime: https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior + +.. ipython:: python + + dates.strftime("%c") + da["time"].dt.strftime("%Y%m%d") + +Conversion between non-standard calendar and to/from pandas DatetimeIndexes is +facilitated with the :py:meth:`xarray.Dataset.convert_calendar` method (also available as +:py:meth:`xarray.DataArray.convert_calendar`). Here, like elsewhere in xarray, the ``use_cftime`` +argument controls which datetime backend is used in the output. The default (``None``) is to +use `pandas` when possible, i.e. when the calendar is standard and dates are within 1678 and 2262. + +.. ipython:: python + + dates = xr.cftime_range(start="2001", periods=24, freq="MS", calendar="noleap") + da_nl = xr.DataArray(np.arange(24), coords=[dates], dims=["time"], name="foo") + da_std = da.convert_calendar("standard", use_cftime=True) + +The data is unchanged, only the timestamps are modified. Further options are implemented +for the special ``"360_day"`` calendar and for handling missing dates. There is also +:py:meth:`xarray.Dataset.interp_calendar` (and :py:meth:`xarray.DataArray.interp_calendar`) +for `interpolating` data between calendars. + +For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: + +- `Partial datetime string indexing`_: + +.. ipython:: python + + da.sel(time="0001") + da.sel(time=slice("0001-05", "0002-02")) + +.. note:: + + + For specifying full or partial datetime strings in cftime + indexing, xarray supports two versions of the `ISO 8601 standard`_, the + basic pattern (YYYYMMDDhhmmss) or the extended pattern + (YYYY-MM-DDThh:mm:ss), as well as the default cftime string format + (YYYY-MM-DD hh:mm:ss). This is somewhat more restrictive than pandas; + in other words, some datetime strings that would be valid for a + :py:class:`pandas.DatetimeIndex` are not valid for an + :py:class:`~xarray.CFTimeIndex`. + +- Access of basic datetime components via the ``dt`` accessor (in this case + just "year", "month", "day", "hour", "minute", "second", "microsecond", + "season", "dayofyear", "dayofweek", and "days_in_month") with the addition + of "calendar", absent from pandas: + +.. ipython:: python + + da.time.dt.year + da.time.dt.month + da.time.dt.season + da.time.dt.dayofyear + da.time.dt.dayofweek + da.time.dt.days_in_month + da.time.dt.calendar + +- Rounding of datetimes to fixed frequencies via the ``dt`` accessor: + +.. ipython:: python + + da.time.dt.ceil("3D") + da.time.dt.floor("5D") + da.time.dt.round("2D") + +- Group-by operations based on datetime accessor attributes (e.g. by month of + the year): + +.. ipython:: python + + da.groupby("time.month").sum() + +- Interpolation using :py:class:`cftime.datetime` objects: + +.. ipython:: python + + da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) + +- Interpolation using datetime strings: + +.. ipython:: python + + da.interp(time=["0001-01-15", "0001-02-15"]) + +- Differentiation: + +.. ipython:: python + + da.differentiate("time") + +- Serialization: + +.. ipython:: python + + da.to_netcdf("example-no-leap.nc") + reopened = xr.open_dataset("example-no-leap.nc") + reopened + +.. ipython:: python + :suppress: + + import os + + reopened.close() + os.remove("example-no-leap.nc") + +- And resampling along the time dimension for data indexed by a :py:class:`~xarray.CFTimeIndex`: + +.. ipython:: python + + da.resample(time="81min", closed="right", label="right", offset="3min").mean() + +.. _nanosecond-precision range: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timestamp-limitations +.. _ISO 8601 standard: https://en.wikipedia.org/wiki/ISO_8601 +.. _partial datetime string indexing: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#partial-string-indexing diff --git a/test/fixtures/whole_applications/xarray/doc/videos.yml b/test/fixtures/whole_applications/xarray/doc/videos.yml new file mode 100644 index 0000000..62c8956 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/videos.yml @@ -0,0 +1,38 @@ +- title: "Xdev Python Tutorial Seminar Series 2022 Thinking with Xarray : High-level computation patterns" + src: '' + authors: + - Deepak Cherian +- title: "Xdev Python Tutorial Seminar Series 2021 seminar introducing xarray (2 of 2)" + src: '' + authors: + - Anderson Banihirwe + +- title: "Xdev Python Tutorial Seminar Series 2021 seminar introducing xarray (1 of 2)" + src: '' + authors: + - Anderson Banihirwe + +- title: "Xarray's 2020 virtual tutorial" + src: '' + authors: + - Anderson Banihirwe + - Deepak Cherian + - Martin Durant + +- title: "Xarray's Tutorial presented at the 2020 SciPy Conference" + src: ' ' + authors: + - Joe Hamman + - Deepak Cherian + - Ryan Abernathey + - Stephan Hoyer + +- title: "Scipy 2015 talk introducing xarray to a general audience" + src: '' + authors: + - Stephan Hoyer + +- title: " 2015 Unidata Users Workshop talk and tutorial with (`with answers`_) introducing xarray to users familiar with netCDF" + src: '' + authors: + - Stephan Hoyer diff --git a/test/fixtures/whole_applications/xarray/doc/whats-new.rst b/test/fixtures/whole_applications/xarray/doc/whats-new.rst new file mode 100644 index 0000000..6fec10b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/doc/whats-new.rst @@ -0,0 +1,7900 @@ +.. currentmodule:: xarray + +What's New +========== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xray + import xarray + import xarray as xr + + np.random.seed(123456) + + +.. _whats-new.2024.06.0: + +v2024.06.0 (Jun 13, 2024) +------------------------- +This release brings various performance optimizations and compatibility with the upcoming numpy 2.0 release. + +Thanks to the 22 contributors to this release: +Alfonso Ladino, David Hoese, Deepak Cherian, Eni Awowale, Ilan Gold, Jessica Scheick, Joe Hamman, Justus Magin, Kai Mühlbauer, Mark Harfouche, Mathias Hauser, Matt Savoie, Maximilian Roos, Mike Thramann, Nicolas Karasiak, Owen Littlejohns, Paul Ockenfuß, Philippe THOMY, Scott Henderson, Spencer Clark, Stephan Hoyer and Tom Nicholas + +Performance +~~~~~~~~~~~ + +- Small optimization to the netCDF4 and h5netcdf backends (:issue:`9058`, :pull:`9067`). + By `Deepak Cherian `_. +- Small optimizations to help reduce indexing speed of datasets (:pull:`9002`). + By `Mark Harfouche `_. +- Performance improvement in `open_datatree` method for Zarr, netCDF4 and h5netcdf backends (:issue:`8994`, :pull:`9014`). + By `Alfonso Ladino `_. + + +Bug fixes +~~~~~~~~~ +- Preserve conversion of timezone-aware pandas Datetime arrays to numpy object arrays + (:issue:`9026`, :pull:`9042`). + By `Ilan Gold `_. +- :py:meth:`DataArrayResample.interpolate` and :py:meth:`DatasetResample.interpolate` method now + support arbitrary kwargs such as ``order`` for polynomial interpolation (:issue:`8762`). + By `Nicolas Karasiak `_. + + +Documentation +~~~~~~~~~~~~~ +- Add link to CF Conventions on packed data and sentence on type determination in the I/O user guide (:issue:`9041`, :pull:`9045`). + By `Kai Mühlbauer `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ +- Migrates remainder of ``io.py`` to ``xarray/core/datatree_io.py`` and + ``TreeAttrAccessMixin`` into ``xarray/core/common.py`` (:pull:`9011`). + By `Owen Littlejohns `_ and + `Tom Nicholas `_. +- Compatibility with numpy 2 (:issue:`8844`, :pull:`8854`, :pull:`8946`). + By `Justus Magin `_ and `Stephan Hoyer `_. + + +.. _whats-new.2024.05.0: + +v2024.05.0 (May 12, 2024) +------------------------- + +This release brings support for pandas ExtensionArray objects, optimizations when reading Zarr, the ability to concatenate datasets without pandas indexes, +more compatibility fixes for the upcoming numpy 2.0, and the migration of most of the xarray-datatree project code into xarray ``main``! + +Thanks to the 18 contributors to this release: +Aimilios Tsouvelekakis, Andrey Akinshin, Deepak Cherian, Eni Awowale, Ilan Gold, Illviljan, Justus Magin, Mark Harfouche, Matt Savoie, Maximilian Roos, Noah C. Benson, Pascal Bourgault, Ray Bell, Spencer Clark, Tom Nicholas, ignamv, owenlittlejohns, and saschahofmann. + +New Features +~~~~~~~~~~~~ +- New "random" method for converting to and from 360_day calendars (:pull:`8603`). + By `Pascal Bourgault `_. +- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array + by supporting 1D ``ExtensionArray`` objects internally where possible. Thus, :py:class:`Dataset` objects initialized with a ``pd.Categorical``, + for example, will retain the object. However, one cannot do operations that are not possible on the ``ExtensionArray`` + then, such as broadcasting. (:issue:`5287`, :issue:`8463`, :pull:`8723`) + By `Ilan Gold `_. +- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`) + By `Ignacio Martinez Vazquez `_. +- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg + `create_index_for_new_dim=False`. (:pull:`8960`) + By `Tom Nicholas `_. +- Avoid automatically re-creating 1D pandas indexes in :py:func:`concat()`. Also added option to avoid creating 1D indexes for + new dimension coordinates by passing the new kwarg `create_index_for_new_dim=False`. (:issue:`8871`, :pull:`8872`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- The PyNIO backend has been deleted (:issue:`4491`, :pull:`7301`). + By `Deepak Cherian `_. +- The minimum versions of some dependencies were changed, in particular our minimum supported pandas version is now Pandas 2. + + ===================== ========= ======= + Package Old New + ===================== ========= ======= + dask-core 2022.12 2023.4 + distributed 2022.12 2023.4 + h5py 3.7 3.8 + matplotlib-base 3.6 3.7 + packaging 22.0 23.1 + pandas 1.5 2.0 + pydap 3.3 3.4 + sparse 0.13 0.14 + typing_extensions 4.4 4.5 + zarr 2.13 2.14 + ===================== ========= ======= + +Bug fixes +~~~~~~~~~ +- Following `an upstream bug fix + `_ to + :py:func:`pandas.date_range`, date ranges produced by + :py:func:`xarray.cftime_range` with negative frequencies will now fall fully + within the bounds of the provided start and end dates (:pull:`8999`). + By `Spencer Clark `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Enforces failures on CI when tests raise warnings from within xarray (:pull:`8974`) + By `Maximilian Roos `_ +- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull: `8930`) + By `Eni Awowale `_, `Julia Signell `_ + and `Tom Nicholas `_. +- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) + By `Matt Savoie `_ `Owen Littlejohns + `_ and `Tom Nicholas `_. +- Migrates ``extensions``, ``formatting`` and ``datatree_render`` functionality for + ``DataTree`` into ``xarray/core``. Also migrates ``testing`` functionality into + ``xarray/testing/assertions`` for ``DataTree``. (:pull:`8967`) + By `Owen Littlejohns `_ and + `Tom Nicholas `_. +- Migrates ``ops.py`` functionality into ``xarray/core/datatree_ops.py`` (:pull:`8976`) + By `Matt Savoie `_ and `Tom Nicholas `_. +- Migrates ``iterator`` functionality into ``xarray/core`` (:pull: `8879`) + By `Owen Littlejohns `_, `Matt Savoie + `_ and `Tom Nicholas `_. +- ``transpose``, ``set_dims``, ``stack`` & ``unstack`` now use a ``dim`` kwarg + rather than ``dims`` or ``dimensions``. This is the final change to make xarray methods + consistent with their use of ``dim``. Using the existing kwarg will raise a + warning. + By `Maximilian Roos `_ + +.. _whats-new.2024.03.0: + +v2024.03.0 (Mar 29, 2024) +------------------------- + +This release brings performance improvements for grouped and resampled quantile calculations, CF decoding improvements, +minor optimizations to distributed Zarr writes, and compatibility fixes for Numpy 2.0 and Pandas 3.0. + +Thanks to the 18 contributors to this release: +Anderson Banihirwe, Christoph Hasse, Deepak Cherian, Etienne Schalk, Justus Magin, Kai Mühlbauer, Kevin Schwarzwald, Mark Harfouche, Martin, Matt Savoie, Maximilian Roos, Ray Bell, Roberto Chang, Spencer Clark, Tom Nicholas, crusaderky, owenlittlejohns, saschahofmann + +New Features +~~~~~~~~~~~~ +- Partial writes to existing chunks with ``region`` or ``append_dim`` will now raise an error + (unless ``safe_chunks=False``); previously an error would only be raised on + new variables. (:pull:`8459`, :issue:`8371`, :issue:`8882`) + By `Maximilian Roos `_. +- Grouped and resampling quantile calculations now use the vectorized algorithm in ``flox>=0.9.4`` if present. + By `Deepak Cherian `_. +- Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False`` + (:issue:`6806`, :pull:`8784`). + By `Etienne Schalk `_ and `Deepak Cherian `_. +- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`) + By `Anderson Banihirwe `_. +- Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`) + By `Anderson Banihirwe `_. +- Expand use of ``.oindex`` and ``.vindex`` properties. (:pull: `8790`) + By `Anderson Banihirwe `_ and `Deepak Cherian `_. +- Allow creating :py:class:`xr.Coordinates` objects with no indexes (:pull:`8711`) + By `Benoit Bovy `_ and `Tom Nicholas + `_. +- Enable plotting of ``datetime.dates``. (:issue:`8866`, :pull:`8873`) + By `Sascha Hofmann `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- Don't allow overwriting index variables with ``to_zarr`` region writes. (:issue:`8589`, :pull:`8876`). + By `Deepak Cherian `_. + + +Bug fixes +~~~~~~~~~ +- The default ``freq`` parameter in :py:meth:`xr.date_range` and :py:meth:`xr.cftime_range` is + set to ``'D'`` only if ``periods``, ``start``, or ``end`` are ``None`` (:issue:`8770`, :pull:`8774`). + By `Roberto Chang `_. +- Ensure that non-nanosecond precision :py:class:`numpy.datetime64` and + :py:class:`numpy.timedelta64` values are cast to nanosecond precision values + when used in :py:meth:`DataArray.expand_dims` and + ::py:meth:`Dataset.expand_dims` (:pull:`8781`). By `Spencer + Clark `_. +- CF conform handling of `_FillValue`/`missing_value` and `dtype` in + `CFMaskCoder`/`CFScaleOffsetCoder` (:issue:`2304`, :issue:`5597`, + :issue:`7691`, :pull:`8713`, see also discussion in :pull:`7654`). + By `Kai Mühlbauer `_. +- Do not cast `_FillValue`/`missing_value` in `CFMaskCoder` if `_Unsigned` is provided + (:issue:`8844`, :pull:`8852`). +- Adapt handling of copy keyword argument for numpy >= 2.0dev + (:issue:`8844`, :pull:`8851`, :pull:`8865`). + By `Kai Mühlbauer `_. +- Import trapz/trapezoid depending on numpy version + (:issue:`8844`, :pull:`8865`). + By `Kai Mühlbauer `_. +- Warn and return bytes undecoded in case of UnicodeDecodeError in h5netcdf-backend + (:issue:`5563`, :pull:`8874`). + By `Kai Mühlbauer `_. +- Fix bug incorrectly disallowing creation of a dataset with a multidimensional coordinate variable with the same name as one of its dims. + (:issue:`8884`, :pull:`8886`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`) + By `Matt Savoie `_ and `Tom Nicholas + `_. +- Migrates ``datatree`` functionality into ``xarray/core``. (:pull: `8789`) + By `Owen Littlejohns `_, `Matt Savoie + `_ and `Tom Nicholas `_. + + +.. _whats-new.2024.02.0: + +v2024.02.0 (Feb 19, 2024) +------------------------- + +This release brings size information to the text ``repr``, changes to the accepted frequency +strings, and various bug fixes. + +Thanks to our 12 contributors: + +Anderson Banihirwe, Deepak Cherian, Eivind Jahren, Etienne Schalk, Justus Magin, Marco Wolsza, +Mathias Hauser, Matt Savoie, Maximilian Roos, Rambaud Pierrick, Tom Nicholas + +New Features +~~~~~~~~~~~~ + +- Added a simple ``nbytes`` representation in DataArrays and Dataset ``repr``. + (:issue:`8690`, :pull:`8702`). + By `Etienne Schalk `_. +- Allow negative frequency strings (e.g. ``"-1YE"``). These strings are for example used in + :py:func:`date_range`, and :py:func:`cftime_range` (:pull:`8651`). + By `Mathias Hauser `_. +- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and + :py:meth:`NamedArray.broadcast_to` (:pull:`8380`) + By `Anderson Banihirwe `_. +- Xarray now defers to `flox's heuristics `_ + to set the default `method` for groupby problems. This only applies to ``flox>=0.9``. + By `Deepak Cherian `_. +- All `quantile` methods (e.g. :py:meth:`DataArray.quantile`) now use `numbagg` + for the calculation of nanquantiles (i.e., `skipna=True`) if it is installed. + This is currently limited to the linear interpolation method (`method='linear'`). + (:issue:`7377`, :pull:`8684`) + By `Marco Wolsza `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:func:`infer_freq` always returns the frequency strings as defined in pandas 2.2 + (:issue:`8612`, :pull:`8627`). + By `Mathias Hauser `_. + +Deprecations +~~~~~~~~~~~~ +- The `dt.weekday_name` parameter wasn't functional on modern pandas versions and has been + removed. (:issue:`8610`, :pull:`8664`) + By `Sam Coleman `_. + + +Bug fixes +~~~~~~~~~ + +- Fixed a regression that prevented multi-index level coordinates being serialized after resetting + or dropping the multi-index (:issue:`8628`, :pull:`8672`). + By `Benoit Bovy `_. +- Fix bug with broadcasting when wrapping array API-compliant classes. (:issue:`8665`, :pull:`8669`) + By `Tom Nicholas `_. +- Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant + classes. (:issue:`8666`, :pull:`8668`) + By `Tom Nicholas `_. +- Fix negative slicing of Zarr arrays without dask installed. (:issue:`8252`) + By `Deepak Cherian `_. +- Preserve chunks when writing time-like variables to zarr by enabling lazy CF encoding of time-like + variables (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8575`). + By `Spencer Clark `_ and `Mattia Almansi `_. +- Preserve chunks when writing time-like variables to zarr by enabling their lazy encoding + (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8253`, :pull:`8575`; see also discussion in + :pull:`8253`). + By `Spencer Clark `_ and `Mattia Almansi `_. +- Raise an informative error if dtype encoding of time-like variables would lead to integer overflow + or unsafe conversion from floating point to integer values (:issue:`8542`, :pull:`8575`). + By `Spencer Clark `_. +- Raise an error when unstacking a MultiIndex that has duplicates as this would lead to silent data + loss (:issue:`7104`, :pull:`8737`). + By `Mathias Hauser `_. + +Documentation +~~~~~~~~~~~~~ +- Fix `variables` arg typo in `Dataset.sortby()` docstring (:issue:`8663`, :pull:`8670`) + By `Tom Vo `_. +- Fixed documentation where the use of the depreciated pandas frequency string prevented the + documentation from being built. (:pull:`8638`) + By `Sam Coleman `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- ``DataArray.dt`` now raises an ``AttributeError`` rather than a ``TypeError`` when the data isn't + datetime-like. (:issue:`8718`, :pull:`8724`) + By `Maximilian Roos `_. +- Move ``parallelcompat`` and ``chunk managers`` modules from ``xarray/core`` to + ``xarray/namedarray``. (:pull:`8319`) + By `Tom Nicholas `_ and `Anderson Banihirwe `_. +- Imports ``datatree`` repository and history into internal location. (:pull:`8688`) + By `Matt Savoie `_, `Justus Magin `_ + and `Tom Nicholas `_. +- Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`) + By `Matt Savoie `_ and `Tom Nicholas + `_. +- Refactor :py:meth:`xarray.core.indexing.DaskIndexingAdapter.__getitem__` to remove an unnecessary + rewrite of the indexer key (:issue: `8377`, :pull:`8758`) + By `Anderson Banihirwe `_. + +.. _whats-new.2024.01.1: + +v2024.01.1 (23 Jan, 2024) +------------------------- + +This release is to fix a bug with the rendering of the documentation, but it also includes changes to the handling of pandas frequency strings. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Following pandas, :py:meth:`infer_freq` will return ``"YE"``, instead of ``"Y"`` (formerly ``"A"``). + This is to be consistent with the deprecation of the latter frequency string in pandas 2.2. + This is a follow up to :pull:`8415` (:issue:`8612`, :pull:`8642`). + By `Mathias Hauser `_. + +Deprecations +~~~~~~~~~~~~ + +- Following pandas, the frequency string ``"Y"`` (formerly ``"A"``) is deprecated in + favor of ``"YE"``. These strings are used, for example, in :py:func:`date_range`, + :py:func:`cftime_range`, :py:meth:`DataArray.resample`, and :py:meth:`Dataset.resample` + among others (:issue:`8612`, :pull:`8629`). + By `Mathias Hauser `_. + +Documentation +~~~~~~~~~~~~~ + +- Pin ``sphinx-book-theme`` to ``1.0.1`` to fix a rendering issue with the sidebar in the docs. (:issue:`8619`, :pull:`8632`) + By `Tom Nicholas `_. + +.. _whats-new.2024.01.0: + +v2024.01.0 (17 Jan, 2024) +------------------------- + +This release brings support for weights in correlation and covariance functions, +a new `DataArray.cumulative` aggregation, improvements to `xr.map_blocks`, +an update to our minimum dependencies, and various bugfixes. + +Thanks to our 17 contributors to this release: + +Abel Aoun, Deepak Cherian, Illviljan, Johan Mathe, Justus Magin, Kai Mühlbauer, +Llorenç Lledó, Mark Harfouche, Markel, Mathias Hauser, Maximilian Roos, Michael Niklas, +Niclas Rieger, Sébastien Celles, Tom Nicholas, Trinh Quoc Anh, and crusaderky. + +New Features +~~~~~~~~~~~~ + +- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support using weights (:issue:`8527`, :pull:`7392`). + By `Llorenç Lledó `_. +- Accept the compression arguments new in netCDF 1.6.0 in the netCDF4 backend. + See `netCDF4 documentation `_ for details. + Note that some new compression filters needs plugins to be installed which may not be available in all netCDF distributions. + By `Markel García-Díez `_. (:issue:`6929`, :pull:`7551`) +- Add :py:meth:`DataArray.cumulative` & :py:meth:`Dataset.cumulative` to compute + cumulative aggregations, such as ``sum``, along a dimension — for example + ``da.cumulative('time').sum()``. This is similar to pandas' ``.expanding``, + and mostly equivalent to ``.cumsum`` methods, or to + :py:meth:`DataArray.rolling` with a window length equal to the dimension size. + By `Maximilian Roos `_. (:pull:`8512`) +- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata. + If multiple variables share the same enum in netCDF4, each dataarray will have its own + enum definition in their respective dtype metadata. + By `Abel Aoun `_. (:issue:`8144`, :pull:`8147`) + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed (:pull:`8586`): + + ===================== ========= ======== + Package Old New + ===================== ========= ======== + cartopy 0.20 0.21 + dask-core 2022.7 2022.12 + distributed 2022.7 2022.12 + flox 0.5 0.7 + iris 3.2 3.4 + matplotlib-base 3.5 3.6 + numpy 1.22 1.23 + numba 0.55 0.56 + packaging 21.3 22.0 + seaborn 0.11 0.12 + scipy 1.8 1.10 + typing_extensions 4.3 4.4 + zarr 2.12 2.13 + ===================== ========= ======== + +Deprecations +~~~~~~~~~~~~ + +- The `squeeze` kwarg to GroupBy is now deprecated. (:issue:`2157`, :pull:`8507`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Support non-string hashable dimensions in :py:class:`xarray.DataArray` (:issue:`8546`, :pull:`8559`). + By `Michael Niklas `_. +- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`). + By `Kai Mühlbauer `_. +- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`). + By `Kai Mühlbauer `_. +- Add tests and fixes for empty :py:class:`CFTimeIndex`, including broken html repr (:issue:`7298`, :pull:`8600`). + By `Mathias Hauser `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- The implementation of :py:func:`map_blocks` has changed to minimize graph size and duplication of data. + This should be a strict improvement even though the graphs are not always embarassingly parallel any more. + Please open an issue if you spot a regression. (:pull:`8412`, :issue:`8409`). + By `Deepak Cherian `_. +- Remove null values before plotting. (:pull:`8535`). + By `Jimmy Westling `_. +- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`, + potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to + use non-dask chunked array types. + (:pull:`8019`) By `Tom Nicholas `_. + +.. _whats-new.2023.12.0: + +v2023.12.0 (2023 Dec 08) +------------------------ + +This release brings new `hypothesis `_ strategies for testing, significantly faster rolling aggregations as well as +``ffill`` and ``bfill`` with ``numbagg``, a new :py:meth:`Dataset.eval` method, and improvements to +reading and writing Zarr arrays (including a new ``"a-"`` mode). + +Thanks to our 16 contributors: + +Anderson Banihirwe, Ben Mares, Carl Andersson, Deepak Cherian, Doug Latornell, Gregorio L. Trevisan, Illviljan, Jens Hedegaard Nielsen, Justus Magin, Mathias Hauser, Max Jones, Maximilian Roos, Michael Niklas, Patrick Hoefler, Ryan Abernathey, Tom Nicholas + +New Features +~~~~~~~~~~~~ + +- Added hypothesis strategies for generating :py:class:`xarray.Variable` objects containing arbitrary data, useful for parametrizing downstream tests. + Accessible under :py:mod:`testing.strategies`, and documented in a new page on testing in the User Guide. + (:issue:`6911`, :pull:`8404`) + By `Tom Nicholas `_. +- :py:meth:`rolling` uses `numbagg `_ for + most of its computations by default. Numbagg is up to 5x faster than bottleneck + where parallelization is possible. Where parallelization isn't possible — for + example a 1D array — it's about the same speed as bottleneck, and 2-5x faster + than pandas' default functions. (:pull:`8493`). numbagg is an optional + dependency, so requires installing separately. +- Use a concise format when plotting datetime arrays. (:pull:`8449`). + By `Jimmy Westling `_. +- Avoid overwriting unchanged existing coordinate variables when appending with :py:meth:`Dataset.to_zarr` by setting ``mode='a-'``. + By `Ryan Abernathey `_ and `Deepak Cherian `_. +- :py:meth:`~xarray.DataArray.rank` now operates on dask-backed arrays, assuming + the core dim has exactly one chunk. (:pull:`8475`). + By `Maximilian Roos `_. +- Add a :py:meth:`Dataset.eval` method, similar to the pandas' method of the + same name. (:pull:`7163`). This is currently marked as experimental and + doesn't yet support the ``numexpr`` engine. +- :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` allow passing a + callable, similar to :py:meth:`Dataset.where` & :py:meth:`Dataset.sortby` & others. + (:pull:`8511`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Explicitly warn when creating xarray objects with repeated dimension names. + Such objects will also now raise when :py:meth:`DataArray.get_axis_num` is called, + which means many functions will raise. + This latter change is technically a breaking change, but whilst allowed, + this behaviour was never actually supported! (:issue:`3731`, :pull:`8491`) + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ +- As part of an effort to standardize the API, we're renaming the ``dims`` + keyword arg to ``dim`` for the minority of functions which current use + ``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot` + and we'll gradually roll this out across all functions. The warnings are + currently ``PendingDeprecationWarning``, which are silenced by default. We'll + convert these to ``DeprecationWarning`` in a future release. + By `Maximilian Roos `_. +- Raise a ``FutureWarning`` warning that the type of :py:meth:`Dataset.dims` will be changed + from a mapping of dimension names to lengths to a set of dimension names. + This is to increase consistency with :py:meth:`DataArray.dims`. + To access a mapping of dimension names to lengths please use :py:meth:`Dataset.sizes`. + The same change also applies to `DatasetGroupBy.dims`. + (:issue:`8496`, :pull:`8500`) + By `Tom Nicholas `_. +- :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` are now deprecated, since pending deprecation for + several years. :py:meth:`DataArray.drop_sel` & :py:meth:`DataArray.drop_var` + replace them for labels & variables respectively. (:pull:`8497`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ + +- Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`) +- Fix writing a variable that requires transposing when not writing to a region (:pull:`8484`) + By `Maximilian Roos `_. +- Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit` + was changed to ``Mapping`` (:pull:`8502`). + By `Michael Niklas `_. +- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf` + when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`). + By `Jens Hedegaard Nielsen `_. + +Documentation +~~~~~~~~~~~~~ + +- Added illustration of updating the time coordinate values of a resampled dataset using + time offset arithmetic. + This is the recommended technique to replace the use of the deprecated ``loffset`` parameter + in ``resample`` (:pull:`8479`). + By `Doug Latornell `_. +- Improved error message when attempting to get a variable which doesn't exist from a Dataset. + (:pull:`8474`) + By `Maximilian Roos `_. +- Fix default value of ``combine_attrs`` in :py:func:`xarray.combine_by_coords` (:pull:`8471`) + By `Gregorio L. Trevisan `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg `_ by + default, which is up to 5x faster where parallelization is possible. (:pull:`8339`) + By `Maximilian Roos `_. +- Update mypy version to 1.7 (:issue:`8448`, :pull:`8501`). + By `Michael Niklas `_. + +.. _whats-new.2023.11.0: + +v2023.11.0 (Nov 16, 2023) +------------------------- + + +.. tip:: + + `This is our 10th year anniversary release! `_ Thank you for your love and support. + + +This release brings the ability to use ``opt_einsum`` for :py:func:`xarray.dot` by default, +support for auto-detecting ``region`` when writing partial datasets to Zarr, and the use of h5py +drivers with ``h5netcdf``. + +Thanks to the 19 contributors to this release: +Aman Bagrecha, Anderson Banihirwe, Ben Mares, Deepak Cherian, Dimitri Papadopoulos Orfanos, Ezequiel Cimadevilla Alvarez, +Illviljan, Justus Magin, Katelyn FitzGerald, Kai Muehlbauer, Martin Durant, Maximilian Roos, Metamess, Sam Levang, Spencer Clark, Tom Nicholas, mgunyho, templiert + +New Features +~~~~~~~~~~~~ + +- Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. + By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). +- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`). + By `Ben Mares `_. +- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the + region to write in the original store. Also implement automatic transpose when dimension + order does not match the original store. (:issue:`7702`, :issue:`8421`, :pull:`8434`). + By `Sam Levang `_. +- Allow the usage of h5py drivers (eg: ros3) via h5netcdf (:pull:`8360`). + By `Ezequiel Cimadevilla `_. +- Enable VLEN string fill_values, preserve VLEN string dtypes (:issue:`1647`, :issue:`7652`, :issue:`7868`, :pull:`7869`). + By `Kai Mühlbauer `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- drop support for `cdms2 `_. Please use + `xcdat `_ instead (:pull:`8441`). + By `Justus Magin `_. +- Following pandas, :py:meth:`infer_freq` will return ``"Y"``, ``"YS"``, + ``"QE"``, ``"ME"``, ``"h"``, ``"min"``, ``"s"``, ``"ms"``, ``"us"``, or + ``"ns"`` instead of ``"A"``, ``"AS"``, ``"Q"``, ``"M"``, ``"H"``, ``"T"``, + ``"S"``, ``"L"``, ``"U"``, or ``"N"``. This is to be consistent with the + deprecation of the latter frequency strings (:issue:`8394`, :pull:`8415`). By + `Spencer Clark `_. +- Bump minimum tested pint version to ``>=0.22``. By `Deepak Cherian `_. +- Minimum supported versions for the following packages have changed: ``h5py >=3.7``, ``h5netcdf>=1.1``. + By `Kai Mühlbauer `_. + +Deprecations +~~~~~~~~~~~~ +- The PseudoNetCDF backend has been removed. By `Deepak Cherian `_. +- Supplying dimension-ordered sequences to :py:meth:`DataArray.chunk` & + :py:meth:`Dataset.chunk` is deprecated in favor of supplying a dictionary of + dimensions, or a single ``int`` or ``"auto"`` argument covering all + dimensions. Xarray favors using dimensions names rather than positions, and + this was one place in the API where dimension positions were used. + (:pull:`8341`) + By `Maximilian Roos `_. +- Following pandas, the frequency strings ``"A"``, ``"AS"``, ``"Q"``, ``"M"``, + ``"H"``, ``"T"``, ``"S"``, ``"L"``, ``"U"``, and ``"N"`` are deprecated in + favor of ``"Y"``, ``"YS"``, ``"QE"``, ``"ME"``, ``"h"``, ``"min"``, ``"s"``, + ``"ms"``, ``"us"``, and ``"ns"``, respectively. These strings are used, for + example, in :py:func:`date_range`, :py:func:`cftime_range`, + :py:meth:`DataArray.resample`, and :py:meth:`Dataset.resample` among others + (:issue:`8394`, :pull:`8415`). By `Spencer Clark + `_. +- Rename :py:meth:`Dataset.to_array` to :py:meth:`Dataset.to_dataarray` for + consistency with :py:meth:`DataArray.to_dataset` & + :py:func:`open_dataarray` functions. This is a "soft" deprecation — the + existing methods work and don't raise any warnings, given the relatively small + benefits of the change. + By `Maximilian Roos `_. +- Finally remove ``keep_attrs`` kwarg from :py:meth:`DataArray.resample` and + :py:meth:`Dataset.resample`. These were deprecated a long time ago. + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Port `bug fix from pandas `_ + to eliminate the adjustment of resample bin edges in the case that the + resampling frequency has units of days and is greater than one day + (e.g. ``"2D"``, ``"3D"`` etc.) and the ``closed`` argument is set to + ``"right"`` to xarray's implementation of resample for data indexed by a + :py:class:`CFTimeIndex` (:pull:`8393`). + By `Spencer Clark `_. +- Fix to once again support date offset strings as input to the loffset + parameter of resample and test this functionality (:pull:`8422`, :issue:`8399`). + By `Katelyn FitzGerald `_. +- Fix a bug where :py:meth:`DataArray.to_dataset` silently drops a variable + if a coordinate with the same name already exists (:pull:`8433`, :issue:`7823`). + By `András Gunyhó `_. +- Fix for :py:meth:`DataArray.to_zarr` & :py:meth:`Dataset.to_zarr` to close + the created zarr store when passing a path with `.zip` extension (:pull:`8425`). + By `Carl Andersson _`. + +Documentation +~~~~~~~~~~~~~ +- Small updates to documentation on distributed writes: See :ref:`io.zarr.appending` to Zarr. + By `Deepak Cherian `_. + +.. _whats-new.2023.10.1: + +v2023.10.1 (19 Oct, 2023) +------------------------- + +This release updates our minimum numpy version in ``pyproject.toml`` to 1.22, +consistent with our documentation below. + +.. _whats-new.2023.10.0: + +v2023.10.0 (19 Oct, 2023) +------------------------- + +This release brings performance enhancements to reading Zarr datasets, the ability to use `numbagg `_ for reductions, +an expansion in API for ``rolling_exp``, fixes two regressions with datetime decoding, +and many other bugfixes and improvements. Groupby reductions will also use ``numbagg`` if ``flox>=0.8.1`` and ``numbagg`` are both installed. + +Thanks to our 13 contributors: +Anderson Banihirwe, Bart Schilperoort, Deepak Cherian, Illviljan, Kai Mühlbauer, Mathias Hauser, Maximilian Roos, Michael Niklas, Pieter Eendebak, Simon Høxbro Hansen, Spencer Clark, Tom White, olimcc + +New Features +~~~~~~~~~~~~ +- Support high-performance reductions with `numbagg `_. + This is enabled by default if ``numbagg`` is installed. + By `Deepak Cherian `_. (:pull:`8316`) +- Add ``corr``, ``cov``, ``std`` & ``var`` to ``.rolling_exp``. + By `Maximilian Roos `_. (:pull:`8307`) +- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for + the ``other`` parameter, passing the object as the only argument. Previously, + this was only valid for the ``cond`` parameter. (:issue:`8255`) + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now take a ``min_weight`` parameter, to only + output values when there are sufficient recent non-nan values. + ``numbagg>=0.3.1`` is required. (:pull:`8285`) + By `Maximilian Roos `_. +- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for + the ``variables`` parameter, passing the object as the only argument. + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the + core dim has exactly one chunk. (:pull:`8284`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Made more arguments keyword-only (e.g. ``keep_attrs``, ``skipna``) for many :py:class:`xarray.DataArray` and + :py:class:`xarray.Dataset` methods (:pull:`6403`). By `Mathias Hauser `_. +- :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword + arguments after the initial 7 positional arguments. + By `Maximilian Roos `_. + + +Deprecations +~~~~~~~~~~~~ +- Rename :py:meth:`Dataset.reset_encoding` & :py:meth:`DataArray.reset_encoding` + to :py:meth:`Dataset.drop_encoding` & :py:meth:`DataArray.drop_encoding` for + consistency with other ``drop`` & ``reset`` methods — ``drop`` generally + removes something, while ``reset`` generally resets to some default or + standard value. (:pull:`8287`, :issue:`8259`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ + +- :py:meth:`DataArray.rename` & :py:meth:`Dataset.rename` would emit a warning + when the operation was a no-op. (:issue:`8266`) + By `Simon Hansen `_. +- Fixed a regression introduced in the previous release checking time-like units + when encoding/decoding masked data (:issue:`8269`, :pull:`8277`). + By `Kai Mühlbauer `_. + +- Fix datetime encoding precision loss regression introduced in the previous + release for datetimes encoded with units requiring floating point values, and + a reference date not equal to the first value of the datetime array + (:issue:`8271`, :pull:`8272`). By `Spencer Clark + `_. + +- Fix excess metadata requests when using a Zarr store. Prior to this, metadata + was re-read every time data was retrieved from the array, now metadata is retrieved only once + when they array is initialized. + (:issue:`8290`, :pull:`8297`). + By `Oliver McCormack `_. + +- Fix to_zarr ending in a ReadOnlyError when consolidated metadata was used and the + write_empty_chunks was provided. + (:issue:`8323`, :pull:`8326`) + By `Matthijs Amesz `_. + + +Documentation +~~~~~~~~~~~~~ + +- Added page on the interoperability of xarray objects. + (:pull:`7992`) By `Tom Nicholas `_. +- Added xarray-regrid to the list of xarray related projects (:pull:`8272`). + By `Bart Schilperoort `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- More improvements to support the Python `array API standard `_ + by using duck array ops in more places in the codebase. (:pull:`8267`) + By `Tom White `_. + + +.. _whats-new.2023.09.0: + +v2023.09.0 (Sep 26, 2023) +------------------------- + +This release continues work on the new :py:class:`xarray.Coordinates` object, allows to provide `preferred_chunks` when +reading from netcdf files, enables :py:func:`xarray.apply_ufunc` to handle missing core dimensions and fixes several bugs. + +Thanks to the 24 contributors to this release: Alexander Fischer, Amrest Chinkamol, Benoit Bovy, Darsh Ranjan, Deepak Cherian, +Gianfranco Costamagna, Gregorio L. Trevisan, Illviljan, Joe Hamman, JR, Justus Magin, Kai Mühlbauer, Kian-Meng Ang, Kyle Sunden, +Martin Raspaud, Mathias Hauser, Mattia Almansi, Maximilian Roos, András Gunyhó, Michael Niklas, Richard Kleijn, Riulinchen, +Tom Nicholas and Wiktor Kraśnicki. + +We welcome the following new contributors to Xarray!: Alexander Fischer, Amrest Chinkamol, Darsh Ranjan, Gianfranco Costamagna, Gregorio L. Trevisan, +Kian-Meng Ang, Riulinchen and Wiktor Kraśnicki. + +New Features +~~~~~~~~~~~~ + +- Added the :py:meth:`Coordinates.assign` method that can be used to combine + different collections of coordinates prior to assign them to a Dataset or + DataArray (:pull:`8102`) at once. + By `Benoît Bovy `_. +- Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`). + By `Martin Raspaud `_. +- Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or + dropping a :py:class:`Dataset`'s variables with missing core dimensions (:pull:`8138`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The :py:class:`Coordinates` constructor now creates a (pandas) index by + default for each dimension coordinate. To keep the previous behavior (no index + created), pass an empty dictionary to ``indexes``. The constructor now also + extracts and add the indexes from another :py:class:`Coordinates` object + passed via ``coords`` (:pull:`8107`). + By `Benoît Bovy `_. +- Static typing of ``xlim`` and ``ylim`` arguments in plotting functions now must + be ``tuple[float, float]`` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). + By `Michael Niklas `_. + +Deprecations +~~~~~~~~~~~~ + +- Deprecate passing a :py:class:`pandas.MultiIndex` object directly to the + :py:class:`Dataset` and :py:class:`DataArray` constructors as well as to + :py:meth:`Dataset.assign` and :py:meth:`Dataset.assign_coords`. + A new Xarray :py:class:`Coordinates` object has to be created first using + :py:meth:`Coordinates.from_pandas_multiindex` (:pull:`8094`). + By `Benoît Bovy `_. + +Bug fixes +~~~~~~~~~ + +- Improved static typing of reduction methods (:pull:`6746`). + By `Richard Kleijn `_. +- Fix bug where empty attrs would generate inconsistent tokens (:issue:`6970`, :pull:`8101`). + By `Mattia Almansi `_. +- Improved handling of multi-coordinate indexes when updating coordinates, including bug fixes + (and improved warnings for deprecated features) for pandas multi-indexes (:pull:`8094`). + By `Benoît Bovy `_. +- Fixed a bug in :py:func:`merge` with ``compat='minimal'`` where the coordinate + names were not updated properly internally (:issue:`7405`, :issue:`7588`, + :pull:`8104`). + By `Benoît Bovy `_. +- Fix bug where :py:class:`DataArray` instances on the right-hand side + of :py:meth:`DataArray.__setitem__` lose dimension names (:issue:`7030`, :pull:`8067`). + By `Darsh Ranjan `_. +- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and + special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar` + (:issue:`7928`, :pull:`8084`). + By `Kai Mühlbauer `_. +- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes. + (:issue:`7021`, :pull:`7578`). + By `Amrest Chinkamol `_ and `Michael Niklas `_. +- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments + (:issue:`7552`, :pull:`8174`). + By `Wiktor Kraśnicki `_. +- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, + :issue:`1064`, :pull:`7827`). + By `Kai Mühlbauer `_. +- Fixed a bug where inaccurate ``coordinates`` silently failed to decode variable (:issue:`1809`, :pull:`8195`). + By `Kai Mühlbauer `_ +- ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords + (:issue:`6528`, :pull:`8114`). + By `Maximilian Roos `_. +- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). + By `Kai Mühlbauer `_. +- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed. + Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`). + By `Michael Niklas `_. +- Fix type annotation for ``center`` argument of plotting methods (like :py:meth:`xarray.plot.dataarray_plot.pcolormesh`) (:pull:`8261`). + By `Pieter Eendebak `_. + +Documentation +~~~~~~~~~~~~~ + +- Make documentation of :py:meth:`DataArray.where` clearer (:issue:`7767`, :pull:`7955`). + By `Riulinchen `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Many error messages related to invalid dimensions or coordinates now always show the list of valid dims/coords (:pull:`8079`). + By `András Gunyhó `_. +- Refactor of encoding and decoding times/timedeltas to preserve nanosecond resolution in arrays that contain missing values (:pull:`7827`). + By `Kai Mühlbauer `_. +- Transition ``.rolling_exp`` functions to use `.apply_ufunc` internally rather + than `.reduce`, as the start of a broader effort to move non-reducing + functions away from ```.reduce``, (:pull:`8114`). + By `Maximilian Roos `_. +- Test range of fill_value's in test_interpolate_pd_compat (:issue:`8146`, :pull:`8189`). + By `Kai Mühlbauer `_. + +.. _whats-new.2023.08.0: + +v2023.08.0 (Aug 18, 2023) +------------------------- + +This release brings changes to minimum dependencies, allows reading of datasets where a dimension name is +associated with a multidimensional variable (e.g. finite volume ocean model output), and introduces +a new :py:class:`xarray.Coordinates` object. + +Thanks to the 16 contributors to this release: Anderson Banihirwe, Articoking, Benoit Bovy, Deepak Cherian, Harshitha, Ian Carroll, +Joe Hamman, Justus Magin, Peter Hill, Rachel Wegener, Riley Kuttruff, Thomas Nicholas, Tom Nicholas, ilgast, quantsnus, vallirep + +Announcements +~~~~~~~~~~~~~ + +The :py:class:`xarray.Variable` class is being refactored out to a new project title 'namedarray'. +See the `design doc `_ for more +details. Reach out to us on this [discussion topic](https://github.com/pydata/xarray/discussions/8080) if you have any thoughts. + +New Features +~~~~~~~~~~~~ + +- :py:class:`Coordinates` can now be constructed independently of any Dataset or + DataArray (it is also returned by the :py:attr:`Dataset.coords` and + :py:attr:`DataArray.coords` properties). ``Coordinates`` objects are useful for + passing both coordinate variables and indexes to new Dataset / DataArray objects, + e.g., via their constructor or via :py:meth:`Dataset.assign_coords`. We may also + wrap coordinate variables in a ``Coordinates`` object in order to skip + the automatic creation of (pandas) indexes for dimension coordinates. + The :py:class:`Coordinates.from_pandas_multiindex` constructor may be used to + create coordinates directly from a :py:class:`pandas.MultiIndex` object (it is + preferred over passing it directly as coordinate data, which may be deprecated soon). + Like Dataset and DataArray objects, ``Coordinates`` objects may now be used in + :py:func:`align` and :py:func:`merge`. + (:issue:`6392`, :pull:`7368`). + By `Benoît Bovy `_. +- Visually group together coordinates with the same indexes in the index section of the text repr (:pull:`7225`). + By `Justus Magin `_. +- Allow creating Xarray objects where a multidimensional variable shares its name + with a dimension. Examples include output from finite volume models like FVCOM. + (:issue:`2233`, :pull:`7989`) + By `Deepak Cherian `_ and `Benoit Bovy `_. +- When outputting :py:class:`Dataset` objects as Zarr via :py:meth:`Dataset.to_zarr`, + user can now specify that chunks that will contain no valid data will not be written. + Originally, this could be done by specifying ``"write_empty_chunks": True`` in the + ``encoding`` parameter; however, this setting would not carry over when appending new + data to an existing dataset. (:issue:`8009`) Requires ``zarr>=2.11``. + + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed (:pull:`8022`): + + ===================== ========= ======== + Package Old New + ===================== ========= ======== + boto3 1.20 1.24 + cftime 1.5 1.6 + dask-core 2022.1 2022.7 + distributed 2022.1 2022.7 + hfnetcdf 0.13 1.0 + iris 3.1 3.2 + lxml 4.7 4.9 + netcdf4 1.5.7 1.6.0 + numpy 1.21 1.22 + pint 0.18 0.19 + pydap 3.2 3.3 + rasterio 1.2 1.3 + scipy 1.7 1.8 + toolz 0.11 0.12 + typing_extensions 4.0 4.3 + zarr 2.10 2.12 + numbagg 0.1 0.2.1 + ===================== ========= ======== + +Documentation +~~~~~~~~~~~~~ + +- Added page on the internal design of xarray objects. + (:pull:`7991`) By `Tom Nicholas `_. +- Added examples to docstrings of :py:meth:`Dataset.assign_attrs`, :py:meth:`Dataset.broadcast_equals`, + :py:meth:`Dataset.equals`, :py:meth:`Dataset.identical`, :py:meth:`Dataset.expand_dims`,:py:meth:`Dataset.drop_vars` + (:issue:`6793`, :pull:`7937`) By `Harshitha `_. +- Add docstrings for the :py:class:`Index` base class and add some documentation on how to + create custom, Xarray-compatible indexes (:pull:`6975`) + By `Benoît Bovy `_. +- Added a page clarifying the role of Xarray core team members. + (:pull:`7999`) By `Tom Nicholas `_. +- Fixed broken links in "See also" section of :py:meth:`Dataset.count` (:issue:`8055`, :pull:`8057`) + By `Articoking `_. +- Extended the glossary by adding terms Aligning, Broadcasting, Merging, Concatenating, Combining, lazy, + labeled, serialization, indexing (:issue:`3355`, :pull:`7732`) + By `Harshitha `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- :py:func:`as_variable` now consistently includes the variable name in any exceptions + raised. (:pull:`7995`). By `Peter Hill `_ +- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to + `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`). + `By Ian Carroll `_. + +.. _whats-new.2023.07.0: + +v2023.07.0 (July 17, 2023) +-------------------------- + +This release brings improvements to the documentation on wrapping numpy-like arrays, improved docstrings, and bug fixes. + +Deprecations +~~~~~~~~~~~~ + +- `hue_style` is being deprecated for scatter plots. (:issue:`7907`, :pull:`7925`). + By `Jimmy Westling `_. + +Bug fixes +~~~~~~~~~ + +- Ensure no forward slashes in variable and dimension names for HDF5-based engines. + (:issue:`7943`, :pull:`7953`) By `Kai Mühlbauer `_. + +Documentation +~~~~~~~~~~~~~ + +- Added examples to docstrings of :py:meth:`Dataset.assign_attrs`, :py:meth:`Dataset.broadcast_equals`, + :py:meth:`Dataset.equals`, :py:meth:`Dataset.identical`, :py:meth:`Dataset.expand_dims`,:py:meth:`Dataset.drop_vars` + (:issue:`6793`, :pull:`7937`) By `Harshitha `_. +- Added page on wrapping chunked numpy-like arrays as alternatives to dask arrays. + (:pull:`7951`) By `Tom Nicholas `_. +- Expanded the page on wrapping numpy-like "duck" arrays. + (:pull:`7911`) By `Tom Nicholas `_. +- Added examples to docstrings of :py:meth:`Dataset.isel`, :py:meth:`Dataset.reduce`, :py:meth:`Dataset.argmin`, + :py:meth:`Dataset.argmax` (:issue:`6793`, :pull:`7881`) + By `Harshitha `_ . + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Allow chunked non-dask arrays (i.e. Cubed arrays) in groupby operations. (:pull:`7941`) + By `Tom Nicholas `_. + + +.. _whats-new.2023.06.0: + +v2023.06.0 (June 21, 2023) +-------------------------- + +This release adds features to ``curvefit``, improves the performance of concatenation, and fixes various bugs. + +Thank to our 13 contributors to this release: +Anderson Banihirwe, Deepak Cherian, dependabot[bot], Illviljan, Juniper Tyree, Justus Magin, Martin Fleischmann, +Mattia Almansi, mgunyho, Rutger van Haasteren, Thomas Nicholas, Tom Nicholas, Tom White. + + +New Features +~~~~~~~~~~~~ + +- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`). + By `András Gunyhó `_. +- Add an ``errors`` option to :py:meth:`Dataset.curve_fit` that allows + returning NaN for the parameters and covariances of failed fits, rather than + failing the whole series of fits (:issue:`6317`, :pull:`7891`). + By `Dominik Stańczak `_ and `András Gunyhó `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ +- Deprecate the `cdms2 `_ conversion methods (:pull:`7876`) + By `Justus Magin `_. + +Performance +~~~~~~~~~~~ +- Improve concatenation performance (:issue:`7833`, :pull:`7824`). + By `Jimmy Westling `_. + +Bug fixes +~~~~~~~~~ +- Fix bug where weighted ``polyfit`` were changing the original object (:issue:`5644`, :pull:`7900`). + By `Mattia Almansi `_. +- Don't call ``CachingFileManager.__del__`` on interpreter shutdown (:issue:`7814`, :pull:`7880`). + By `Justus Magin `_. +- Preserve vlen dtype for empty string arrays (:issue:`7328`, :pull:`7862`). + By `Tom White `_ and `Kai Mühlbauer `_. +- Ensure dtype of reindex result matches dtype of the original DataArray (:issue:`7299`, :pull:`7917`) + By `Anderson Banihirwe `_. +- Fix bug where a zero-length zarr ``chunk_store`` was ignored as if it was ``None`` (:pull:`7923`) + By `Juniper Tyree `_. + +Documentation +~~~~~~~~~~~~~ + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Minor improvements to support of the python `array api standard `_, + internally using the function ``xp.astype()`` instead of the method ``arr.astype()``, as the latter is not in the standard. + (:pull:`7847`) By `Tom Nicholas `_. +- Xarray now uploads nightly wheels to https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ (:issue:`7863`, :pull:`7865`). + By `Martin Fleischmann `_. +- Stop uploading development wheels to TestPyPI (:pull:`7889`) + By `Justus Magin `_. +- Added an exception catch for ``AttributeError`` along with ``ImportError`` when duck typing the dynamic imports in pycompat.py. This catches some name collisions between packages. (:issue:`7870`, :pull:`7874`) + +.. _whats-new.2023.05.0: + +v2023.05.0 (May 18, 2023) +------------------------- + +This release adds some new methods and operators, updates our deprecation policy for python versions, fixes some bugs with groupby, +and introduces experimental support for alternative chunked parallel array computation backends via a new plugin system! + +**Note:** If you are using a locally-installed development version of xarray then pulling the changes from this release may require you to re-install. +This avoids an error where xarray cannot detect dask via the new entrypoints system introduced in :pull:`7019`. See :issue:`7856` for details. + +Thanks to our 14 contributors: +Alan Brammer, crusaderky, David Stansby, dcherian, Deeksha, Deepak Cherian, Illviljan, James McCreight, +Joe Hamman, Justus Magin, Kyle Sunden, Max Hollmann, mgunyho, and Tom Nicholas + + +New Features +~~~~~~~~~~~~ +- Added new method :py:meth:`DataArray.to_dask_dataframe`, convert a dataarray into a dask dataframe (:issue:`7409`). + By `Deeksha `_. +- Add support for lshift and rshift binary operators (``<<``, ``>>``) on + :py:class:`xr.DataArray` of type :py:class:`int` (:issue:`7727` , :pull:`7741`). + By `Alan Brammer `_. +- Keyword argument `data='array'` to both :py:meth:`xarray.Dataset.to_dict` and + :py:meth:`xarray.DataArray.to_dict` will now return data as the underlying array type. + Python lists are returned for `data='list'` or `data=True`. Supplying `data=False` only returns the schema without data. + ``encoding=True`` returns the encoding dictionary for the underlying variable also. (:issue:`1599`, :pull:`7739`) . + By `James McCreight `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- adjust the deprecation policy for python to once again align with NEP-29 (:issue:`7765`, :pull:`7793`) + By `Justus Magin `_. + +Performance +~~~~~~~~~~~ +- Optimize ``.dt `` accessor performance with ``CFTimeIndex``. (:pull:`7796`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ +- Fix `as_compatible_data` for masked float arrays, now always creates a copy when mask is present (:issue:`2377`, :pull:`7788`). + By `Max Hollmann `_. +- Fix groupby binary ops when grouped array is subset relative to other. (:issue:`7797`). + By `Deepak Cherian `_. +- Fix groupby sum, prod for all-NaN groups with ``flox``. (:issue:`7808`). + By `Deepak Cherian `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Experimental support for wrapping chunked array libraries other than dask. + A new ABC is defined - :py:class:`xr.core.parallelcompat.ChunkManagerEntrypoint` - which can be subclassed and then + registered by alternative chunked array implementations. (:issue:`6807`, :pull:`7019`) + By `Tom Nicholas `_. + + +.. _whats-new.2023.04.2: + +v2023.04.2 (April 20, 2023) +--------------------------- + +This is a patch release to fix a bug with binning (:issue:`7766`) + +Bug fixes +~~~~~~~~~ + +- Fix binning when ``labels`` is specified. (:issue:`7766`). + By `Deepak Cherian `_. + + +Documentation +~~~~~~~~~~~~~ +- Added examples to docstrings for :py:meth:`xarray.core.accessor_str.StringAccessor` methods. + (:pull:`7669`) . + By `Mary Gathoni `_. + + +.. _whats-new.2023.04.1: + +v2023.04.1 (April 18, 2023) +--------------------------- + +This is a patch release to fix a bug with binning (:issue:`7759`) + +Bug fixes +~~~~~~~~~ + +- Fix binning by unsorted arrays. (:issue:`7759`) + + +.. _whats-new.2023.04.0: + +v2023.04.0 (April 14, 2023) +--------------------------- + +This release includes support for pandas v2, allows refreshing of backend engines in a session, and removes deprecated backends +for ``rasterio`` and ``cfgrib``. + +Thanks to our 19 contributors: +Chinemere, Tom Coleman, Deepak Cherian, Harshitha, Illviljan, Jessica Scheick, Joe Hamman, Justus Magin, Kai Mühlbauer, Kwonil-Kim, Mary Gathoni, Michael Niklas, Pierre, Scott Henderson, Shreyal Gupta, Spencer Clark, mccloskey, nishtha981, veenstrajelmer + +We welcome the following new contributors to Xarray!: +Mary Gathoni, Harshitha, veenstrajelmer, Chinemere, nishtha981, Shreyal Gupta, Kwonil-Kim, mccloskey. + +New Features +~~~~~~~~~~~~ +- New methods to reset an objects encoding (:py:meth:`Dataset.reset_encoding`, :py:meth:`DataArray.reset_encoding`). + (:issue:`7686`, :pull:`7689`). + By `Joe Hamman `_. +- Allow refreshing backend engines with :py:meth:`xarray.backends.refresh_engines` (:issue:`7478`, :pull:`7523`). + By `Michael Niklas `_. +- Added ability to save ``DataArray`` objects directly to Zarr using :py:meth:`~xarray.DataArray.to_zarr`. + (:issue:`7692`, :pull:`7693`) . + By `Joe Hamman `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- Remove deprecated rasterio backend in favor of rioxarray (:pull:`7392`). + By `Scott Henderson `_. + +Deprecations +~~~~~~~~~~~~ + +Performance +~~~~~~~~~~~ +- Optimize alignment with ``join="exact", copy=False`` by avoiding copies. (:pull:`7736`) + By `Deepak Cherian `_. +- Avoid unnecessary copies of ``CFTimeIndex``. (:pull:`7735`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Fix :py:meth:`xr.polyval` with non-system standard integer coeffs (:pull:`7619`). + By `Shreyal Gupta `_ and `Michael Niklas `_. +- Improve error message when trying to open a file which you do not have permission to read (:issue:`6523`, :pull:`7629`). + By `Thomas Coleman `_. +- Proper plotting when passing :py:class:`~matplotlib.colors.BoundaryNorm` type argument in :py:meth:`DataArray.plot`. (:issue:`4061`, :issue:`7014`,:pull:`7553`) + By `Jelmer Veenstra `_. +- Ensure the formatting of time encoding reference dates outside the range of + nanosecond-precision datetimes remains the same under pandas version 2.0.0 + (:issue:`7420`, :pull:`7441`). + By `Justus Magin `_ and + `Spencer Clark `_. +- Various `dtype` related fixes needed to support `pandas>=2.0` (:pull:`7724`) + By `Justus Magin `_. +- Preserve boolean dtype within encoding (:issue:`7652`, :pull:`7720`). + By `Kai Mühlbauer `_ + +Documentation +~~~~~~~~~~~~~ + +- Update FAQ page on how do I open format X file as an xarray dataset? (:issue:`1285`, :pull:`7638`) using :py:func:`~xarray.open_dataset` + By `Harshitha `_ , `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Don't assume that arrays read from disk will be Numpy arrays. This is a step toward + enabling reads from a Zarr store using the `Kvikio `_ + or `TensorStore `_ libraries. + (:pull:`6874`). By `Deepak Cherian `_. + +- Remove internal support for reading GRIB files through the ``cfgrib`` backend. ``cfgrib`` now uses the external + backend interface, so no existing code should break. + By `Deepak Cherian `_. +- Implement CF coding functions in ``VariableCoders`` (:pull:`7719`). + By `Kai Mühlbauer `_ + +- Added a config.yml file with messages for the welcome bot when a Github user creates their first ever issue or pull request or has their first PR merged. (:issue:`7685`, :pull:`7685`) + By `Nishtha P `_. + +- Ensure that only nanosecond-precision :py:class:`pd.Timestamp` objects + continue to be used internally under pandas version 2.0.0. This is mainly to + ease the transition to this latest version of pandas. It should be relaxed + when addressing :issue:`7493`. By `Spencer Clark + `_ (:issue:`7707`, :pull:`7731`). + +.. _whats-new.2023.03.0: + +v2023.03.0 (March 22, 2023) +--------------------------- + +This release brings many bug fixes, and some new features. The maximum pandas version is pinned to ``<2`` until we can support the new pandas datetime types. +Thanks to our 19 contributors: +Abel Aoun, Alex Goodman, Deepak Cherian, Illviljan, Jody Klymak, Joe Hamman, Justus Magin, Mary Gathoni, Mathias Hauser, Mattia Almansi, Mick, Oriol Abril-Pla, Patrick Hoefler, Paul Ockenfuß, Pierre, Shreyal Gupta, Spencer Clark, Tom Nicholas, Tom Vo + +New Features +~~~~~~~~~~~~ + +- Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`). + By `Michael Niklas `_. +- Allow indexing along unindexed dimensions with dask arrays + (:issue:`2511`, :issue:`4276`, :issue:`4663`, :pull:`5873`). + By `Abel Aoun `_ and `Deepak Cherian `_. +- Support dask arrays in ``first`` and ``last`` reductions. + By `Deepak Cherian `_. +- Improved performance in ``open_dataset`` for datasets with large object arrays (:issue:`7484`, :pull:`7494`). + By `Alex Goodman `_ and `Deepak Cherian `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ +- Following pandas, the ``base`` and ``loffset`` parameters of + :py:meth:`xr.DataArray.resample` and :py:meth:`xr.Dataset.resample` have been + deprecated and will be removed in a future version of xarray. Using the + ``origin`` or ``offset`` parameters is recommended as a replacement for using + the ``base`` parameter and using time offset arithmetic is recommended as a + replacement for using the ``loffset`` parameter (:pull:`8459`). By `Spencer + Clark `_. + + +Bug fixes +~~~~~~~~~ + +- Improve error message when using in :py:meth:`Dataset.drop_vars` to state which variables can't be dropped. (:pull:`7518`) + By `Tom Nicholas `_. +- Require to explicitly defining optional dimensions such as hue + and markersize for scatter plots. (:issue:`7314`, :pull:`7277`). + By `Jimmy Westling `_. +- Fix matplotlib raising a UserWarning when plotting a scatter plot + with an unfilled marker (:issue:`7313`, :pull:`7318`). + By `Jimmy Westling `_. +- Fix issue with ``max_gap`` in ``interpolate_na``, when applied to + multidimensional arrays. (:issue:`7597`, :pull:`7598`). + By `Paul Ockenfuß `_. +- Fix :py:meth:`DataArray.plot.pcolormesh` which now works if one of the coordinates has str dtype (:issue:`6775`, :pull:`7612`). + By `Michael Niklas `_. + +Documentation +~~~~~~~~~~~~~ + +- Clarify language in contributor's guide (:issue:`7495`, :pull:`7595`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Pin pandas to ``<2``. By `Deepak Cherian `_. + +.. _whats-new.2023.02.0: + +v2023.02.0 (Feb 7, 2023) +------------------------ + +This release brings a major upgrade to :py:func:`xarray.concat`, many bug fixes, +and a bump in supported dependency versions. Thanks to our 11 contributors: +Aron Gergely, Deepak Cherian, Illviljan, James Bourbeau, Joe Hamman, +Justus Magin, Hauke Schulz, Kai Mühlbauer, Ken Mankoff, Spencer Clark, Tom Nicholas. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Support for ``python 3.8`` has been dropped and the minimum versions of some + dependencies were changed (:pull:`7461`): + + ===================== ========= ======== + Package Old New + ===================== ========= ======== + python 3.8 3.9 + numpy 1.20 1.21 + pandas 1.3 1.4 + dask 2021.11 2022.1 + distributed 2021.11 2022.1 + h5netcdf 0.11 0.13 + lxml 4.6 4.7 + numba 5.4 5.5 + ===================== ========= ======== + +Deprecations +~~~~~~~~~~~~ +- Following pandas, the `closed` parameters of :py:func:`cftime_range` and + :py:func:`date_range` are deprecated in favor of the `inclusive` parameters, + and will be removed in a future version of xarray (:issue:`6985`:, + :pull:`7373`). By `Spencer Clark `_. + +Bug fixes +~~~~~~~~~ +- :py:func:`xarray.concat` can now concatenate variables present in some datasets but + not others (:issue:`508`, :pull:`7400`). + By `Kai Mühlbauer `_ and `Scott Chamberlin `_. +- Handle ``keep_attrs`` option in binary operators of :py:meth:`Dataset` (:issue:`7390`, :pull:`7391`). + By `Aron Gergely `_. +- Improve error message when using dask in :py:func:`apply_ufunc` with ``output_sizes`` not supplied. (:pull:`7509`) + By `Tom Nicholas `_. +- :py:func:`xarray.Dataset.to_zarr` now drops variable encodings that have been added by xarray during reading + a dataset. (:issue:`7129`, :pull:`7500`). + By `Hauke Schulz `_. + +Documentation +~~~~~~~~~~~~~ +- Mention the `flox package `_ in GroupBy documentation and docstrings. + By `Deepak Cherian `_. + + +.. _whats-new.2023.01.0: + +v2023.01.0 (Jan 17, 2023) +------------------------- + +This release includes a number of bug fixes. Thanks to the 14 contributors to this release: +Aron Gergely, Benoit Bovy, Deepak Cherian, Ian Carroll, Illviljan, Joe Hamman, Justus Magin, Mark Harfouche, +Matthew Roeschke, Paige Martin, Pierre, Sam Levang, Tom White, stefank0. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`CFTimeIndex.get_loc` has removed the ``method`` and ``tolerance`` keyword arguments. + Use ``.get_indexer([key], method=..., tolerance=...)`` instead (:pull:`7361`). + By `Matthew Roeschke `_. + +Bug fixes +~~~~~~~~~ + +- Avoid in-memory broadcasting when converting to a dask dataframe + using ``.to_dask_dataframe.`` (:issue:`6811`, :pull:`7472`). + By `Jimmy Westling `_. +- Accessing the property ``.nbytes`` of a DataArray, or Variable no longer + accidentally triggers loading the variable into memory. +- Allow numpy-only objects in :py:func:`where` when ``keep_attrs=True`` (:issue:`7362`, :pull:`7364`). + By `Sam Levang `_. +- add a ``keep_attrs`` parameter to :py:meth:`Dataset.pad`, :py:meth:`DataArray.pad`, + and :py:meth:`Variable.pad` (:pull:`7267`). + By `Justus Magin `_. +- Fixed performance regression in alignment between indexed and non-indexed objects + of the same shape (:pull:`7382`). + By `Benoît Bovy `_. +- Preserve original dtype on accessing MultiIndex levels (:issue:`7250`, + :pull:`7393`). By `Ian Carroll `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Add the pre-commit hook `absolufy-imports` to convert relative xarray imports to + absolute imports (:pull:`7204`, :pull:`7370`). + By `Jimmy Westling `_. + +.. _whats-new.2022.12.0: + +v2022.12.0 (2022 Dec 2) +----------------------- + +This release includes a number of bug fixes and experimental support for Zarr V3. +Thanks to the 16 contributors to this release: +Deepak Cherian, Francesco Zanetta, Gregory Lee, Illviljan, Joe Hamman, Justus Magin, Luke Conibear, Mark Harfouche, Mathias Hauser, +Mick, Mike Taves, Sam Levang, Spencer Clark, Tom Nicholas, Wei Ji, templiert + +New Features +~~~~~~~~~~~~ +- Enable using `offset` and `origin` arguments in :py:meth:`DataArray.resample` + and :py:meth:`Dataset.resample` (:issue:`7266`, :pull:`7284`). By `Spencer + Clark `_. +- Add experimental support for Zarr's in-progress V3 specification. (:pull:`6475`). + By `Gregory Lee `_ and `Joe Hamman `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed (:pull:`7300`): + + ========================== ========= ======== + Package Old New + ========================== ========= ======== + boto 1.18 1.20 + cartopy 0.19 0.20 + distributed 2021.09 2021.11 + dask 2021.09 2021.11 + h5py 3.1 3.6 + hdf5 1.10 1.12 + matplotlib-base 3.4 3.5 + nc-time-axis 1.3 1.4 + netcdf4 1.5.3 1.5.7 + packaging 20.3 21.3 + pint 0.17 0.18 + pseudonetcdf 3.1 3.2 + typing_extensions 3.10 4.0 + ========================== ========= ======== + +Deprecations +~~~~~~~~~~~~ +- The PyNIO backend has been deprecated (:issue:`4491`, :pull:`7301`). + By `Joe Hamman `_. + +Bug fixes +~~~~~~~~~ +- Fix handling of coordinate attributes in :py:func:`where`. (:issue:`7220`, :pull:`7229`) + By `Sam Levang `_. +- Import ``nc_time_axis`` when needed (:issue:`7275`, :pull:`7276`). + By `Michael Niklas `_. +- Fix static typing of :py:meth:`xr.polyval` (:issue:`7312`, :pull:`7315`). + By `Michael Niklas `_. +- Fix multiple reads on fsspec S3 files by resetting file pointer to 0 when reading file streams (:issue:`6813`, :pull:`7304`). + By `David Hoese `_ and `Wei Ji Leong `_. +- Fix :py:meth:`Dataset.assign_coords` resetting all dimension coordinates to default (pandas) index (:issue:`7346`, :pull:`7347`). + By `Benoît Bovy `_. + +Documentation +~~~~~~~~~~~~~ + +- Add example of reading and writing individual groups to a single netCDF file to I/O docs page. (:pull:`7338`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.2022.11.0: + +v2022.11.0 (Nov 4, 2022) +------------------------ + +This release brings a number of bugfixes and documentation improvements. Both text and HTML +reprs now have a new "Indexes" section, which we expect will help with development of new +Index objects. This release also features more support for the Python Array API. + +Many thanks to the 16 contributors to this release: Daniel Goman, Deepak Cherian, Illviljan, Jessica Scheick, Justus Magin, Mark Harfouche, Maximilian Roos, Mick, Patrick Naylor, Pierre, Spencer Clark, Stephan Hoyer, Tom Nicholas, Tom White + +New Features +~~~~~~~~~~~~ + +- Add static typing to plot accessors (:issue:`6949`, :pull:`7052`). + By `Michael Niklas `_. +- Display the indexes in a new section of the text and HTML reprs + (:pull:`6795`, :pull:`7183`, :pull:`7185`) + By `Justus Magin `_ and `Benoît Bovy `_. +- Added methods :py:meth:`DataArrayGroupBy.cumprod` and :py:meth:`DatasetGroupBy.cumprod`. + (:pull:`5816`) + By `Patrick Naylor `_ + +Breaking changes +~~~~~~~~~~~~~~~~ + +- ``repr(ds)`` may not show the same result because it doesn't load small, + lazy data anymore. Use ``ds.head().load()`` when wanting to see just a sample + of the data. (:issue:`6722`, :pull:`7203`). + By `Jimmy Westling `_. +- Many arguments of plotmethods have been made keyword-only. +- ``xarray.plot.plot`` module renamed to ``xarray.plot.dataarray_plot`` to prevent + shadowing of the ``plot`` method. (:issue:`6949`, :pull:`7052`). + By `Michael Niklas `_. + +Deprecations +~~~~~~~~~~~~ + +- Positional arguments for all plot methods have been deprecated (:issue:`6949`, :pull:`7052`). + By `Michael Niklas `_. +- ``xarray.plot.FacetGrid.axes`` has been renamed to ``xarray.plot.FacetGrid.axs`` + because it's not clear if ``axes`` refers to single or multiple ``Axes`` instances. + This aligns with ``matplotlib.pyplot.subplots``. (:pull:`7194`) + By `Jimmy Westling `_. + +Bug fixes +~~~~~~~~~ + +- Explicitly opening a file multiple times (e.g., after modifying it on disk) + now reopens the file from scratch for h5netcdf and scipy netCDF backends, + rather than reusing a cached version (:issue:`4240`, :issue:`4862`). + By `Stephan Hoyer `_. +- Fixed bug where :py:meth:`Dataset.coarsen.construct` would demote non-dimension coordinates to variables. (:pull:`7233`) + By `Tom Nicholas `_. +- Raise a TypeError when trying to plot empty data (:issue:`7156`, :pull:`7228`). + By `Michael Niklas `_. + +Documentation +~~~~~~~~~~~~~ + +- Improves overall documentation around available backends, including adding docstrings for :py:func:`xarray.backends.list_engines` + Add :py:meth:`__str__` to surface the new :py:class:`BackendEntrypoint` ``description`` + and ``url`` attributes. (:issue:`6577`, :pull:`7000`) + By `Jessica Scheick `_. +- Created docstring examples for :py:meth:`DataArray.cumsum`, :py:meth:`DataArray.cumprod`, :py:meth:`Dataset.cumsum`, :py:meth:`Dataset.cumprod`, :py:meth:`DatasetGroupBy.cumsum`, :py:meth:`DataArrayGroupBy.cumsum`. (:issue:`5816`, :pull:`7152`) + By `Patrick Naylor `_ +- Add example of using :py:meth:`DataArray.coarsen.construct` to User Guide. (:pull:`7192`) + By `Tom Nicholas `_. +- Rename ``axes`` to ``axs`` in plotting to align with ``matplotlib.pyplot.subplots``. (:pull:`7194`) + By `Jimmy Westling `_. +- Add documentation of specific BackendEntrypoints (:pull:`7200`). + By `Michael Niklas `_. +- Add examples to docstring for :py:meth:`DataArray.drop_vars`, :py:meth:`DataArray.reindex_like`, :py:meth:`DataArray.interp_like`. (:issue:`6793`, :pull:`7123`) + By `Daniel Goman `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Doctests fail on any warnings (:pull:`7166`) + By `Maximilian Roos `_. +- Improve import time by lazy loading ``dask.distributed`` (:pull: `7172`). +- Explicitly specify ``longdouble=False`` in :py:func:`cftime.date2num` when + encoding times to preserve existing behavior and prevent future errors when it + is eventually set to ``True`` by default in cftime (:pull:`7171`). By + `Spencer Clark `_. +- Improved import time by lazily importing backend modules, matplotlib, dask.array and flox. (:issue:`6726`, :pull:`7179`) + By `Michael Niklas `_. +- Emit a warning under the development version of pandas when we convert + non-nanosecond precision datetime or timedelta values to nanosecond precision. + This was required in the past, because pandas previously was not compatible + with non-nanosecond precision values. However pandas is currently working + towards removing this restriction. When things stabilize in pandas we will + likely consider relaxing this behavior in xarray as well (:issue:`7175`, + :pull:`7201`). By `Spencer Clark `_. + +.. _whats-new.2022.10.0: + +v2022.10.0 (Oct 14 2022) +------------------------ + +This release brings numerous bugfixes, a change in minimum supported versions, +and a new scatter plot method for DataArrays. + +Many thanks to 11 contributors to this release: Anderson Banihirwe, Benoit Bovy, +Dan Adriaansen, Illviljan, Justus Magin, Lukas Bindreiter, Mick, Patrick Naylor, +Spencer Clark, Thomas Nicholas + + +New Features +~~~~~~~~~~~~ + +- Add scatter plot for datarrays. Scatter plots now also supports 3d plots with + the z argument. (:pull:`6778`) + By `Jimmy Westling `_. +- Include the variable name in the error message when CF decoding fails to allow + for easier identification of problematic variables (:issue:`7145`, :pull:`7147`). + By `Spencer Clark `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed: + + ========================== ========= ======== + Package Old New + ========================== ========= ======== + cftime 1.4 1.5 + distributed 2021.08 2021.09 + dask 2021.08 2021.09 + iris 2.4 3.1 + nc-time-axis 1.2 1.3 + numba 0.53 0.54 + numpy 1.19 1.20 + pandas 1.2 1.3 + packaging 20.0 21.0 + scipy 1.6 1.7 + sparse 0.12 0.13 + typing_extensions 3.7 3.10 + zarr 2.8 2.10 + ========================== ========= ======== + + +Bug fixes +~~~~~~~~~ + +- Remove nested function from :py:func:`open_mfdataset` to allow Dataset objects to be pickled. (:issue:`7109`, :pull:`7116`) + By `Daniel Adriaansen `_. +- Support for recursively defined Arrays. Fixes repr and deepcopy. (:issue:`7111`, :pull:`7112`) + By `Michael Niklas `_. +- Fixed :py:meth:`Dataset.transpose` to raise a more informative error. (:issue:`6502`, :pull:`7120`) + By `Patrick Naylor `_ +- Fix groupby on a multi-index level coordinate and fix + :py:meth:`DataArray.to_index` for multi-index levels (convert to single index). + (:issue:`6836`, :pull:`7105`) + By `Benoît Bovy `_. +- Support for open_dataset backends that return datasets containing multi-indexes (:issue:`7139`, :pull:`7150`) + By `Lukas Bindreiter `_. + + +.. _whats-new.2022.09.0: + +v2022.09.0 (September 30, 2022) +------------------------------- + +This release brings a large number of bugfixes and documentation improvements, as well as an external interface for +setting custom indexes! + +Many thanks to our 40 contributors: + +Anderson Banihirwe, Andrew Ronald Friedman, Bane Sullivan, Benoit Bovy, ColemanTom, Deepak Cherian, +Dimitri Papadopoulos Orfanos, Emma Marshall, Fabian Hofmann, Francesco Nattino, ghislainp, Graham Inggs, Hauke Schulz, +Illviljan, James Bourbeau, Jody Klymak, Julia Signell, Justus Magin, Keewis, Ken Mankoff, Luke Conibear, Mathias Hauser, +Max Jones, mgunyho, Michael Delgado, Mick, Mike Taves, Oliver Lopez, Patrick Naylor, Paul Hockett, Pierre Manchon, +Ray Bell, Riley Brady, Sam Levang, Spencer Clark, Stefaan Lippens, Tom Nicholas, Tom White, Travis A. O'Brien, +and Zachary Moon. + +New Features +~~~~~~~~~~~~ + +- Add :py:meth:`Dataset.set_xindex` and :py:meth:`Dataset.drop_indexes` and + their DataArray counterpart for setting and dropping pandas or custom indexes + given a set of arbitrary coordinates. (:pull:`6971`) + By `Benoît Bovy `_ and `Justus Magin `_. +- Enable taking the mean of dask-backed :py:class:`cftime.datetime` arrays + (:pull:`6556`, :pull:`6940`). + By `Deepak Cherian `_ and `Spencer Clark `_. + +Bug fixes +~~~~~~~~~ + +- Allow reading netcdf files where the 'units' attribute is a number. (:pull:`7085`) + By `Ghislain Picard `_. +- Allow decoding of 0 sized datetimes. (:issue:`1329`, :pull:`6882`) + By `Deepak Cherian `_. +- Make sure DataArray.name is always a string when used as label for plotting. (:issue:`6826`, :pull:`6832`) + By `Jimmy Westling `_. +- :py:attr:`DataArray.nbytes` now uses the ``nbytes`` property of the underlying array if available. (:pull:`6797`) + By `Max Jones `_. +- Rely on the array backend for string formatting. (:pull:`6823`). + By `Jimmy Westling `_. +- Fix incompatibility with numpy 1.20. (:issue:`6818`, :pull:`6821`) + By `Michael Niklas `_. +- Fix side effects on index coordinate metadata after aligning objects. (:issue:`6852`, :pull:`6857`) + By `Benoît Bovy `_. +- Make FacetGrid.set_titles send kwargs correctly using `handle.update(kwargs)`. (:issue:`6839`, :pull:`6843`) + By `Oliver Lopez `_. +- Fix bug where index variables would be changed inplace. (:issue:`6931`, :pull:`6938`) + By `Michael Niklas `_. +- Allow taking the mean over non-time dimensions of datasets containing + dask-backed cftime arrays. (:issue:`5897`, :pull:`6950`) + By `Spencer Clark `_. +- Harmonize returned multi-indexed indexes when applying ``concat`` along new dimension. (:issue:`6881`, :pull:`6889`) + By `Fabian Hofmann `_. +- Fix step plots with ``hue`` arg. (:pull:`6944`) + By `András Gunyhó `_. +- Avoid use of random numbers in `test_weighted.test_weighted_operations_nonequal_coords`. (:issue:`6504`, :pull:`6961`) + By `Luke Conibear `_. +- Fix multiple regression issues with :py:meth:`Dataset.set_index` and + :py:meth:`Dataset.reset_index`. (:pull:`6992`) + By `Benoît Bovy `_. +- Raise a ``UserWarning`` when renaming a coordinate or a dimension creates a + non-indexed dimension coordinate, and suggest the user creating an index + either with ``swap_dims`` or ``set_index``. (:issue:`6607`, :pull:`6999`) + By `Benoît Bovy `_. +- Use ``keep_attrs=True`` in grouping and resampling operations by default. (:issue:`7012`) + This means :py:attr:`Dataset.attrs` and :py:attr:`DataArray.attrs` are now preserved by default. + By `Deepak Cherian `_. +- ``Dataset.encoding['source']`` now exists when reading from a Path object. (:issue:`5888`, :pull:`6974`) + By `Thomas Coleman `_. +- Better dtype consistency for ``rolling.mean()``. (:issue:`7062`, :pull:`7063`) + By `Sam Levang `_. +- Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler. (:issue:`7013`, :pull:`7040`) + By `Francesco Nattino `_. +- Fix deepcopy of attrs and encoding of DataArrays and Variables. (:issue:`2835`, :pull:`7089`) + By `Michael Niklas `_. +- Fix bug where subplot_kwargs were not working when plotting with figsize, size or aspect. (:issue:`7078`, :pull:`7080`) + By `Michael Niklas `_. + +Documentation +~~~~~~~~~~~~~ + +- Update merge docstrings. (:issue:`6935`, :pull:`7033`) + By `Zach Moon `_. +- Raise a more informative error when trying to open a non-existent zarr store. (:issue:`6484`, :pull:`7060`) + By `Sam Levang `_. +- Added examples to docstrings for :py:meth:`DataArray.expand_dims`, :py:meth:`DataArray.drop_duplicates`, :py:meth:`DataArray.reset_coords`, :py:meth:`DataArray.equals`, :py:meth:`DataArray.identical`, :py:meth:`DataArray.broadcast_equals`, :py:meth:`DataArray.bfill`, :py:meth:`DataArray.ffill`, :py:meth:`DataArray.fillna`, :py:meth:`DataArray.dropna`, :py:meth:`DataArray.drop_isel`, :py:meth:`DataArray.drop_sel`, :py:meth:`DataArray.head`, :py:meth:`DataArray.tail`. (:issue:`5816`, :pull:`7088`) + By `Patrick Naylor `_. +- Add missing docstrings to various array properties. (:pull:`7090`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Added test for DataArray attrs deepcopy recursion/nested attrs. (:issue:`2835`, :pull:`7086`) + By `Paul hockett `_. + +.. _whats-new.2022.06.0: + +v2022.06.0 (July 21, 2022) +-------------------------- + +This release brings a number of bug fixes and improvements, most notably a major internal +refactor of the indexing functionality, the use of `flox`_ in ``groupby`` operations, +and experimental support for the new Python `Array API standard `_. +It also stops testing support for the abandoned PyNIO. + +Much effort has been made to preserve backwards compatibility as part of the indexing refactor. +We are aware of one `unfixed issue `_. + +Please also see the `whats-new.2022.06.0rc0`_ for a full list of changes. + +Many thanks to our 18 contributors: +Bane Sullivan, Deepak Cherian, Dimitri Papadopoulos Orfanos, Emma Marshall, Hauke Schulz, Illviljan, +Julia Signell, Justus Magin, Keewis, Mathias Hauser, Michael Delgado, Mick, Pierre Manchon, Ray Bell, +Spencer Clark, Stefaan Lippens, Tom White, Travis A. O'Brien, + +New Features +~~~~~~~~~~~~ + +- Add :py:attr:`Dataset.dtypes`, :py:attr:`core.coordinates.DatasetCoordinates.dtypes`, + :py:attr:`core.coordinates.DataArrayCoordinates.dtypes` properties: Mapping from variable names to dtypes. + (:pull:`6706`) + By `Michael Niklas `_. +- Initial typing support for :py:meth:`groupby`, :py:meth:`rolling`, :py:meth:`rolling_exp`, + :py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`, + (:pull:`6702`) + By `Michael Niklas `_. +- Experimental support for wrapping any array type that conforms to the python + `array api standard `_. (:pull:`6804`) + By `Tom White `_. +- Allow string formatting of scalar DataArrays. (:pull:`5981`) + By `fmaussion `_. + +Bug fixes +~~~~~~~~~ + +- :py:meth:`save_mfdataset` now passes ``**kwargs`` on to :py:meth:`Dataset.to_netcdf`, + allowing the ``encoding`` and ``unlimited_dims`` options with :py:meth:`save_mfdataset`. + (:issue:`6684`) + By `Travis A. O'Brien `_. +- Fix backend support of pydap versions <3.3.0 (:issue:`6648`, :pull:`6656`). + By `Hauke Schulz `_. +- :py:meth:`Dataset.where` with ``drop=True`` now behaves correctly with mixed dimensions. + (:issue:`6227`, :pull:`6690`) + By `Michael Niklas `_. +- Accommodate newly raised ``OutOfBoundsTimedelta`` error in the development version of + pandas when decoding times outside the range that can be represented with + nanosecond-precision values (:issue:`6716`, :pull:`6717`). + By `Spencer Clark `_. +- :py:meth:`open_dataset` with dask and ``~`` in the path now resolves the home directory + instead of raising an error. (:issue:`6707`, :pull:`6710`) + By `Michael Niklas `_. +- :py:meth:`DataArrayRolling.__iter__` with ``center=True`` now works correctly. + (:issue:`6739`, :pull:`6744`) + By `Michael Niklas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- ``xarray.core.groupby``, ``xarray.core.rolling``, + ``xarray.core.rolling_exp``, ``xarray.core.weighted`` + and ``xarray.core.resample`` modules are no longer imported by default. + (:pull:`6702`) + +.. _whats-new.2022.06.0rc0: + +v2022.06.0rc0 (9 June 2022) +--------------------------- + +This pre-release brings a number of bug fixes and improvements, most notably a major internal +refactor of the indexing functionality and the use of `flox`_ in ``groupby`` operations. It also stops +testing support for the abandoned PyNIO. + +Install it using + +:: + + mamba create -n python=3.10 xarray + python -m pip install --pre --upgrade --no-deps xarray + + +Many thanks to the 39 contributors: + +Abel Soares Siqueira, Alex Santana, Anderson Banihirwe, Benoit Bovy, Blair Bonnett, Brewster +Malevich, brynjarmorka, Charles Stern, Christian Jauvin, Deepak Cherian, Emma Marshall, Fabien +Maussion, Greg Behm, Guelate Seyo, Illviljan, Joe Hamman, Joseph K Aicher, Justus Magin, Kevin Paul, +Louis Stenger, Mathias Hauser, Mattia Almansi, Maximilian Roos, Michael Bauer, Michael Delgado, +Mick, ngam, Oleh Khoma, Oriol Abril-Pla, Philippe Blain, PLSeuJ, Sam Levang, Spencer Clark, Stan +West, Thomas Nicholas, Thomas Vogt, Tom White, Xianxiang Li + +Known Regressions +~~~~~~~~~~~~~~~~~ + +- `reset_coords(drop=True)` does not create indexes (:issue:`6607`) + +New Features +~~~~~~~~~~~~ + +- The `zarr` backend is now able to read NCZarr. + By `Mattia Almansi `_. +- Add a weighted ``quantile`` method to :py:class:`~core.weighted.DatasetWeighted` and + :py:class:`~core.weighted.DataArrayWeighted` (:pull:`6059`). + By `Christian Jauvin `_ and `David Huard `_. +- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and + :py:meth:`DataArray.stack` so that the creation of multi-indexes is optional + (:pull:`5692`). + By `Benoît Bovy `_. +- Multi-index levels are now accessible through their own, regular coordinates + instead of virtual coordinates (:pull:`5692`). + By `Benoît Bovy `_. +- Add a ``display_values_threshold`` option to control the total number of array + elements which trigger summarization rather than full repr in (numpy) array + detailed views of the html repr (:pull:`6400`). + By `Benoît Bovy `_. +- Allow passing chunks in ``kwargs`` form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and + :py:meth:`Variable.chunk`. (:pull:`6471`) + By `Tom Nicholas `_. +- Add :py:meth:`core.groupby.DatasetGroupBy.cumsum` and :py:meth:`core.groupby.DataArrayGroupBy.cumsum`. + By `Vladislav Skripniuk `_ and `Deepak Cherian `_. (:pull:`3147`, :pull:`6525`, :issue:`3141`) +- Expose `inline_array` kwarg from `dask.array.from_array` in :py:func:`open_dataset`, :py:meth:`Dataset.chunk`, + :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) +- Expose the ``inline_array`` kwarg from :py:func:`dask.array.from_array` in :py:func:`open_dataset`, + :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) + By `Tom Nicholas `_. +- :py:func:`polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape, + is faster and requires less memory. (:pull:`6548`) + By `Michael Niklas `_. +- Improved overall typing. +- :py:meth:`Dataset.to_dict` and :py:meth:`DataArray.to_dict` may now optionally include encoding + attributes. (:pull:`6635`) + By `Joe Hamman `_. +- Upload development versions to `TestPyPI `_. + By `Justus Magin `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- PyNIO support is now untested. The minimum versions of some dependencies were changed: + + =============== ===== ==== + Package Old New + =============== ===== ==== + cftime 1.2 1.4 + dask 2.30 2021.4 + distributed 2.30 2021.4 + h5netcdf 0.8 0.11 + matplotlib-base 3.3 3.4 + numba 0.51 0.53 + numpy 1.18 1.19 + pandas 1.1 1.2 + pint 0.16 0.17 + rasterio 1.1 1.2 + scipy 1.5 1.6 + sparse 0.11 0.12 + zarr 2.5 2.8 + =============== ===== ==== + +- The Dataset and DataArray ``rename```` methods do not implicitly add or drop + indexes. (:pull:`5692`). + By `Benoît Bovy `_. +- Many arguments like ``keep_attrs``, ``axis``, and ``skipna`` are now keyword + only for all reduction operations like ``.mean``. + By `Deepak Cherian `_, `Jimmy Westling `_. +- Xarray's ufuncs have been removed, now that they can be replaced by numpy's ufuncs in all + supported versions of numpy. + By `Maximilian Roos `_. +- :py:meth:`xr.polyval` now uses the ``coord`` argument directly instead of its index coordinate. + (:pull:`6548`) + By `Michael Niklas `_. + +Bug fixes +~~~~~~~~~ + +- :py:meth:`Dataset.to_zarr` now allows to write all attribute types supported by `zarr-python`. + By `Mattia Almansi `_. +- Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and + ensure it skips missing values for float dtypes (consistent with other methods). This should + not change the behavior (:pull:`6303`). + By `Mathias Hauser `_. +- Many bugs fixed by the explicit indexes refactor, mainly related to multi-index (virtual) + coordinates. See the corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. +- Fixed "unhashable type" error trying to read NetCDF file with variable having its 'units' + attribute not ``str`` (e.g. ``numpy.ndarray``) (:issue:`6368`). + By `Oleh Khoma `_. +- Omit warning about specified dask chunks separating chunks on disk when the + underlying array is empty (e.g., because of an empty dimension) (:issue:`6401`). + By `Joseph K Aicher `_. +- Fixed the poor html repr performance on large multi-indexes (:pull:`6400`). + By `Benoît Bovy `_. +- Allow fancy indexing of duck dask arrays along multiple dimensions. (:pull:`6414`) + By `Justus Magin `_. +- In the API for backends, support dimensions that express their preferred chunk sizes + as a tuple of integers. (:issue:`6333`, :pull:`6334`) + By `Stan West `_. +- Fix bug in :py:func:`where` when passing non-xarray objects with ``keep_attrs=True``. (:issue:`6444`, :pull:`6461`) + By `Sam Levang `_. +- Allow passing both ``other`` and ``drop=True`` arguments to :py:meth:`DataArray.where` + and :py:meth:`Dataset.where` (:pull:`6466`, :pull:`6467`). + By `Michael Delgado `_. +- Ensure dtype encoding attributes are not added or modified on variables that contain datetime-like + values prior to being passed to :py:func:`xarray.conventions.decode_cf_variable` (:issue:`6453`, + :pull:`6489`). + By `Spencer Clark `_. +- Dark themes are now properly detected in Furo-themed Sphinx documents (:issue:`6500`, :pull:`6501`). + By `Kevin Paul `_. +- :py:meth:`Dataset.isel`, :py:meth:`DataArray.isel` with `drop=True` works as intended with scalar :py:class:`DataArray` indexers. + (:issue:`6554`, :pull:`6579`) + By `Michael Niklas `_. +- Fixed silent overflow issue when decoding times encoded with 32-bit and below + unsigned integer data types (:issue:`6589`, :pull:`6598`). + By `Spencer Clark `_. +- Fixed ``.chunks`` loading lazy data (:issue:`6538`). + By `Deepak Cherian `_. + +Documentation +~~~~~~~~~~~~~ + +- Revise the documentation for developers on specifying a backend's preferred chunk + sizes. In particular, correct the syntax and replace lists with tuples in the + examples. (:issue:`6333`, :pull:`6334`) + By `Stan West `_. +- Mention that :py:meth:`DataArray.rename` can rename coordinates. + (:issue:`5458`, :pull:`6665`) + By `Michael Niklas `_. +- Added examples to :py:meth:`Dataset.thin` and :py:meth:`DataArray.thin` + By `Emma Marshall `_. + +Performance +~~~~~~~~~~~ + +- GroupBy binary operations are now vectorized. + Previously this involved looping over all groups. (:issue:`5804`, :pull:`6160`) + By `Deepak Cherian `_. +- Substantially improved GroupBy operations using `flox `_. + This is auto-enabled when ``flox`` is installed. Use ``xr.set_options(use_flox=False)`` to use + the old algorithm. (:issue:`4473`, :issue:`4498`, :issue:`659`, :issue:`2237`, :pull:`271`). + By `Deepak Cherian `_, `Anderson Banihirwe `_, `Jimmy Westling `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Many internal changes due to the explicit indexes refactor. See the + corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. + +.. _whats-new.2022.03.0: + +v2022.03.0 (2 March 2022) +------------------------- + +This release brings a number of small improvements, as well as a move to `calendar versioning `_ (:issue:`6176`). + +Many thanks to the 16 contributors to the v2022.02.0 release! + +Aaron Spring, Alan D. Snow, Anderson Banihirwe, crusaderky, Illviljan, Joe Hamman, Jonas Gliß, +Lukas Pilz, Martin Bergemann, Mathias Hauser, Maximilian Roos, Romain Caneill, Stan West, Stijn Van Hoey, +Tobias Kölling, and Tom Nicholas. + + +New Features +~~~~~~~~~~~~ + +- Enabled multiplying tick offsets by floats. Allows ``float`` ``n`` in + :py:meth:`CFTimeIndex.shift` if ``shift_freq`` is between ``Day`` + and ``Microsecond``. (:issue:`6134`, :pull:`6135`). + By `Aaron Spring `_. +- Enable providing more keyword arguments to the `pydap` backend when reading + OpenDAP datasets (:issue:`6274`). + By `Jonas Gliß `. +- Allow :py:meth:`DataArray.drop_duplicates` to drop duplicates along multiple dimensions at once, + and add :py:meth:`Dataset.drop_duplicates`. (:pull:`6307`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Renamed the ``interpolation`` keyword of all ``quantile`` methods (e.g. :py:meth:`DataArray.quantile`) + to ``method`` for consistency with numpy v1.22.0 (:pull:`6108`). + By `Mathias Hauser `_. + +Deprecations +~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ + +- Variables which are chunked using dask in larger (but aligned) chunks than the target zarr chunk size + can now be stored using `to_zarr()` (:pull:`6258`) By `Tobias Kölling `_. +- Multi-file datasets containing encoded :py:class:`cftime.datetime` objects can be read in parallel again (:issue:`6226`, :pull:`6249`, :pull:`6305`). By `Martin Bergemann `_ and `Stan West `_. + +Documentation +~~~~~~~~~~~~~ + +- Delete files of datasets saved to disk while building the documentation and enable + building on Windows via `sphinx-build` (:pull:`6237`). + By `Stan West `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.0.21.1: + +v0.21.1 (31 January 2022) +------------------------- + +This is a bugfix release to resolve (:issue:`6216`, :pull:`6207`). + +Bug fixes +~~~~~~~~~ +- Add `packaging` as a dependency to Xarray (:issue:`6216`, :pull:`6207`). + By `Sebastian Weigand `_ and `Joe Hamman `_. + + +.. _whats-new.0.21.0: + +v0.21.0 (27 January 2022) +------------------------- + +Many thanks to the 20 contributors to the v0.21.0 release! + +Abel Aoun, Anderson Banihirwe, Ant Gib, Chris Roat, Cindy Chiao, +Deepak Cherian, Dominik Stańczak, Fabian Hofmann, Illviljan, Jody Klymak, Joseph +K Aicher, Mark Harfouche, Mathias Hauser, Matthew Roeschke, Maximilian Roos, +Michael Delgado, Pascal Bourgault, Pierre, Ray Bell, Romain Caneill, Tim Heap, +Tom Nicholas, Zeb Nicholls, joseph nowak, keewis. + + +New Features +~~~~~~~~~~~~ +- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`). + By `Jimmy Westling `_. +- ``keep_attrs`` support for :py:func:`where` (:issue:`4141`, :issue:`4682`, :pull:`4687`). + By `Justus Magin `_. +- Enable the limit option for dask array in the following methods :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` (:issue:`6112`) + By `Joseph Nowak `_. + + +Breaking changes +~~~~~~~~~~~~~~~~ +- Rely on matplotlib's default datetime converters instead of pandas' (:issue:`6102`, :pull:`6109`). + By `Jimmy Westling `_. +- Improve repr readability when there are a large number of dimensions in datasets or dataarrays by + wrapping the text once the maximum display width has been exceeded. (:issue:`5546`, :pull:`5662`) + By `Jimmy Westling `_. + + +Deprecations +~~~~~~~~~~~~ +- Removed the lock kwarg from the zarr and pydap backends, completing the deprecation cycle started in :issue:`5256`. + By `Tom Nicholas `_. +- Support for ``python 3.7`` has been dropped. (:pull:`5892`) + By `Jimmy Westling `_. + + +Bug fixes +~~~~~~~~~ +- Preserve chunks when creating a :py:class:`DataArray` from another :py:class:`DataArray` + (:pull:`5984`). By `Fabian Hofmann `_. +- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` along chunked dimensions (:issue:`6112`). + By `Joseph Nowak `_. + +- Subclasses of ``byte`` and ``str`` (e.g. ``np.str_`` and ``np.bytes_``) will now serialise to disk rather than raising a ``ValueError: unsupported dtype for netCDF4 variable: object`` as they did previously (:pull:`5264`). + By `Zeb Nicholls `_. + +- Fix applying function with non-xarray arguments using :py:func:`xr.map_blocks`. + By `Cindy Chiao `_. + +- No longer raise an error for an all-nan-but-one argument to + :py:meth:`DataArray.interpolate_na` when using `method='nearest'` (:issue:`5994`, :pull:`6144`). + By `Michael Delgado `_. +- `dt.season `_ can now handle NaN and NaT. (:pull:`5876`). + By `Pierre Loicq `_. +- Determination of zarr chunks handles empty lists for encoding chunks or variable chunks that occurs in certain circumstances (:pull:`5526`). By `Chris Roat `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Replace ``distutils.version`` with ``packaging.version`` (:issue:`6092`). + By `Mathias Hauser `_. + +- Removed internal checks for ``pd.Panel`` (:issue:`6145`). + By `Matthew Roeschke `_. + +- Add ``pyupgrade`` pre-commit hook (:pull:`6152`). + By `Maximilian Roos `_. + +.. _whats-new.0.20.2: + +v0.20.2 (9 December 2021) +------------------------- + +This is a bugfix release to resolve (:issue:`3391`, :issue:`5715`). It also +includes performance improvements in unstacking to a ``sparse`` array and a +number of documentation improvements. + +Many thanks to the 20 contributors: + +Aaron Spring, Alexandre Poux, Deepak Cherian, Enrico Minack, Fabien Maussion, +Giacomo Caria, Gijom, Guillaume Maze, Illviljan, Joe Hamman, Joseph Hardin, Kai +Mühlbauer, Matt Henderson, Maximilian Roos, Michael Delgado, Robert Gieseke, +Sebastian Weigand and Stephan Hoyer. + + +Breaking changes +~~~~~~~~~~~~~~~~ +- Use complex nan when interpolating complex values out of bounds by default (instead of real nan) (:pull:`6019`). + By `Alexandre Poux `_. + +Performance +~~~~~~~~~~~ + +- Significantly faster unstacking to a ``sparse`` array. :pull:`5577` + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ +- :py:func:`xr.map_blocks` and :py:func:`xr.corr` now work when dask is not installed (:issue:`3391`, :issue:`5715`, :pull:`5731`). + By `Gijom `_. +- Fix plot.line crash for data of shape ``(1, N)`` in _title_for_slice on format_item (:pull:`5948`). + By `Sebastian Weigand `_. +- Fix a regression in the removal of duplicate backend entrypoints (:issue:`5944`, :pull:`5959`) + By `Kai Mühlbauer `_. +- Fix an issue that datasets from being saved when time variables with units that ``cftime`` can parse but pandas can not were present (:pull:`6049`). + By `Tim Heap `_. + +Documentation +~~~~~~~~~~~~~ + +- Better examples in docstrings for groupby and resampling reductions (:pull:`5871`). + By `Deepak Cherian `_, + `Maximilian Roos `_, + `Jimmy Westling `_ . +- Add list-like possibility for tolerance parameter in the reindex functions. + By `Antoine Gibek `_, + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Use ``importlib`` to replace functionality of ``pkg_resources`` in + backend plugins tests. (:pull:`5959`). + By `Kai Mühlbauer `_. + + +.. _whats-new.0.20.1: + +v0.20.1 (5 November 2021) +------------------------- + +This is a bugfix release to fix :issue:`5930`. + +Bug fixes +~~~~~~~~~ +- Fix a regression in the detection of the backend entrypoints (:issue:`5930`, :pull:`5931`) + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +- Significant improvements to :ref:`api`. By `Deepak Cherian `_. + +.. _whats-new.0.20.0: + +v0.20.0 (1 November 2021) +------------------------- + +This release brings improved support for pint arrays, methods for weighted standard deviation, variance, +and sum of squares, the option to disable the use of the bottleneck library, significantly improved performance of +unstack, as well as many bugfixes and internal changes. + +Many thanks to the 40 contributors to this release!: + +Aaron Spring, Akio Taniguchi, Alan D. Snow, arfy slowy, Benoit Bovy, Christian Jauvin, crusaderky, Deepak Cherian, +Giacomo Caria, Illviljan, James Bourbeau, Joe Hamman, Joseph K Aicher, Julien Herzen, Kai Mühlbauer, +keewis, lusewell, Martin K. Scherer, Mathias Hauser, Max Grover, Maxime Liquet, Maximilian Roos, Mike Taves, Nathan Lis, +pmav99, Pushkar Kopparla, Ray Bell, Rio McMahon, Scott Staniewicz, Spencer Clark, Stefan Bender, Taher Chegini, +Thomas Nicholas, Tomas Chor, Tom Augspurger, Victor Negîrneac, Zachary Blackwood, Zachary Moon, and Zeb Nicholls. + +New Features +~~~~~~~~~~~~ +- Add ``std``, ``var``, ``sum_of_squares`` to :py:class:`~core.weighted.DatasetWeighted` and :py:class:`~core.weighted.DataArrayWeighted`. + By `Christian Jauvin `_. +- Added a :py:func:`get_options` method to xarray's root namespace (:issue:`5698`, :pull:`5716`) + By `Pushkar Kopparla `_. +- Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`). + By `Tomas Chor `_. +- Add an option (``"use_bottleneck"``) to disable the use of ``bottleneck`` using :py:func:`set_options` (:pull:`5560`) + By `Justus Magin `_. +- Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`). + By `Pushkar Kopparla `_. +- Added ``storage_options`` argument to :py:meth:`to_zarr` (:issue:`5601`, :pull:`5615`). + By `Ray Bell `_, `Zachary Blackwood `_ and + `Nathan Lis `_. +- Added calendar utilities :py:func:`DataArray.convert_calendar`, :py:func:`DataArray.interp_calendar`, :py:func:`date_range`, :py:func:`date_range_like` and :py:attr:`DataArray.dt.calendar` (:issue:`5155`, :pull:`5233`). + By `Pascal Bourgault `_. +- Histogram plots are set with a title displaying the scalar coords if any, similarly to the other plots (:issue:`5791`, :pull:`5792`). + By `Maxime Liquet `_. +- Slice plots display the coords units in the same way as x/y/colorbar labels (:pull:`5847`). + By `Victor Negîrneac `_. +- Added a new :py:attr:`Dataset.chunksizes`, :py:attr:`DataArray.chunksizes`, and :py:attr:`Variable.chunksizes` + property, which will always return a mapping from dimension names to chunking pattern along that dimension, + regardless of whether the object is a Dataset, DataArray, or Variable. (:issue:`5846`, :pull:`5900`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- The minimum versions of some dependencies were changed: + + =============== ====== ==== + Package Old New + =============== ====== ==== + cftime 1.1 1.2 + dask 2.15 2.30 + distributed 2.15 2.30 + lxml 4.5 4.6 + matplotlib-base 3.2 3.3 + numba 0.49 0.51 + numpy 1.17 1.18 + pandas 1.0 1.1 + pint 0.15 0.16 + scipy 1.4 1.5 + seaborn 0.10 0.11 + sparse 0.8 0.11 + toolz 0.10 0.11 + zarr 2.4 2.5 + =============== ====== ==== + +- The ``__repr__`` of a :py:class:`xarray.Dataset`'s ``coords`` and ``data_vars`` + ignore ``xarray.set_option(display_max_rows=...)`` and show the full output + when called directly as, e.g., ``ds.data_vars`` or ``print(ds.data_vars)`` + (:issue:`5545`, :pull:`5580`). + By `Stefan Bender `_. + +Deprecations +~~~~~~~~~~~~ + +- Deprecate :py:func:`open_rasterio` (:issue:`4697`, :pull:`5808`). + By `Alan Snow `_. +- Set the default argument for `roll_coords` to `False` for :py:meth:`DataArray.roll` + and :py:meth:`Dataset.roll`. (:pull:`5653`) + By `Tom Nicholas `_. +- :py:meth:`xarray.open_mfdataset` will now error instead of warn when a value for ``concat_dim`` is + passed alongside ``combine='by_coords'``. + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Fix ZeroDivisionError from saving dask array with empty dimension (:issue: `5741`). + By `Joseph K Aicher `_. +- Fixed performance bug where ``cftime`` import attempted within various core operations if ``cftime`` not + installed (:pull:`5640`). + By `Luke Sewell `_ +- Fixed bug when combining named DataArrays using :py:func:`combine_by_coords`. (:pull:`5834`). + By `Tom Nicholas `_. +- When a custom engine was used in :py:func:`~xarray.open_dataset` the engine + wasn't initialized properly, causing missing argument errors or inconsistent + method signatures. (:pull:`5684`) + By `Jimmy Westling `_. +- Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`). + By `Maxime Liquet `_. +- Faceted plots will no longer raise a `pint.UnitStrippedWarning` when a `pint.Quantity` array is plotted, + and will correctly display the units of the data in the colorbar (if there is one) (:pull:`5886`). + By `Tom Nicholas `_. +- With backends, check for path-like objects rather than ``pathlib.Path`` + type, use ``os.fspath`` (:pull:`5879`). + By `Mike Taves `_. +- ``open_mfdataset()`` now accepts a single ``pathlib.Path`` object (:issue: `5881`). + By `Panos Mavrogiorgos `_. +- Improved performance of :py:meth:`Dataset.unstack` (:pull:`5906`). By `Tom Augspurger `_. + +Documentation +~~~~~~~~~~~~~ + +- Users are instructed to try ``use_cftime=True`` if a ``TypeError`` occurs when combining datasets and one of the types involved is a subclass of ``cftime.datetime`` (:pull:`5776`). + By `Zeb Nicholls `_. +- A clearer error is now raised if a user attempts to assign a Dataset to a single key of + another Dataset. (:pull:`5839`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Explicit indexes refactor: avoid ``len(index)`` in ``map_blocks`` (:pull:`5670`). + By `Deepak Cherian `_. +- Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`). + By `Benoit Bovy `_. +- Fix ``Mapping`` argument typing to allow mypy to pass on ``str`` keys (:pull:`5690`). + By `Maximilian Roos `_. +- Annotate many of our tests, and fix some of the resulting typing errors. This will + also mean our typing annotations are tested as part of CI. (:pull:`5728`). + By `Maximilian Roos `_. +- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) + By `Jimmy Westling `_. +- Use isort's `float_to_top` config. (:pull:`5695`). + By `Maximilian Roos `_. +- Remove use of the deprecated ``kind`` argument in + :py:meth:`pandas.Index.get_slice_bound` inside :py:class:`xarray.CFTimeIndex` + tests (:pull:`5723`). By `Spencer Clark `_. +- Refactor `xarray.core.duck_array_ops` to no longer special-case dispatching to + dask versions of functions when acting on dask arrays, instead relying numpy + and dask's adherence to NEP-18 to dispatch automatically. (:pull:`5571`) + By `Tom Nicholas `_. +- Add an ASV benchmark CI and improve performance of the benchmarks (:pull:`5796`) + By `Jimmy Westling `_. +- Use ``importlib`` to replace functionality of ``pkg_resources`` such + as version setting and loading of resources. (:pull:`5845`). + By `Martin K. Scherer `_. + + +.. _whats-new.0.19.0: + +v0.19.0 (23 July 2021) +---------------------- + +This release brings improvements to plotting of categorical data, the ability to specify how attributes +are combined in xarray operations, a new high-level :py:func:`unify_chunks` function, as well as various +deprecations, bug fixes, and minor improvements. + + +Many thanks to the 29 contributors to this release!: + +Andrew Williams, Augustus, Aureliana Barghini, Benoit Bovy, crusaderky, Deepak Cherian, ellesmith88, +Elliott Sales de Andrade, Giacomo Caria, github-actions[bot], Illviljan, Joeperdefloep, joooeey, Julia Kent, +Julius Busecke, keewis, Mathias Hauser, Matthias Göbel, Mattia Almansi, Maximilian Roos, Peter Andreas Entschev, +Ray Bell, Sander, Santiago Soler, Sebastian, Spencer Clark, Stephan Hoyer, Thomas Hirtz, Thomas Nicholas. + +New Features +~~~~~~~~~~~~ +- Allow passing argument ``missing_dims`` to :py:meth:`Variable.transpose` and :py:meth:`Dataset.transpose` + (:issue:`5550`, :pull:`5586`) + By `Giacomo Caria `_. +- Allow passing a dictionary as coords to a :py:class:`DataArray` (:issue:`5527`, + reverts :pull:`1539`, which had deprecated this due to python's inconsistent ordering in earlier versions). + By `Sander van Rijn `_. +- Added :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` (:issue:`5454`, :pull:`5475`). + By `Deepak Cherian `_. +- Xarray now uses consolidated metadata by default when writing and reading Zarr + stores (:issue:`5251`). + By `Stephan Hoyer `_. +- New top-level function :py:func:`unify_chunks`. + By `Mattia Almansi `_. +- Allow assigning values to a subset of a dataset using positional or label-based + indexing (:issue:`3015`, :pull:`5362`). + By `Matthias Göbel `_. +- Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`). + By `Mattia Almansi `_. +- Add ``.sum`` to :py:meth:`~xarray.DataArray.rolling_exp` and + :py:meth:`~xarray.Dataset.rolling_exp` for exponentially weighted rolling + sums. These require numbagg 0.2.1; + (:pull:`5178`). + By `Maximilian Roos `_. +- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing + values if inputs are dask arrays (:issue:`4804`, :pull:`5284`). + By `Andrew Williams `_. +- Attempting to ``concat`` list of elements that are not all ``Dataset`` or all ``DataArray`` now raises an error (:issue:`5051`, :pull:`5425`). + By `Thomas Hirtz `_. +- allow passing a function to ``combine_attrs`` (:pull:`4896`). + By `Justus Magin `_. +- Allow plotting categorical data (:pull:`5464`). + By `Jimmy Westling `_. +- Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None`` + (:issue:`5510`). + By `Elle Smith `_. +- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`). + By `Tom Nicholas `_. +- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`). + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The default ``mode`` for :py:meth:`Dataset.to_zarr` when ``region`` is set + has changed to the new ``mode="r+"``, which only allows for overriding + pre-existing array values. This is a safer default than the prior ``mode="a"``, + and allows for higher performance writes (:pull:`5252`). + By `Stephan Hoyer `_. +- The main parameter to :py:func:`combine_by_coords` is renamed to `data_objects` instead + of `datasets` so anyone calling this method using a named parameter will need to update + the name accordingly (:issue:`3248`, :pull:`4696`). + By `Augustus Ijams `_. + +Deprecations +~~~~~~~~~~~~ + +- Removed the deprecated ``dim`` kwarg to :py:func:`DataArray.integrate` (:pull:`5630`) +- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.rolling` (:pull:`5630`) +- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.coarsen` (:pull:`5630`) +- Completed deprecation of passing an ``xarray.DataArray`` to :py:func:`Variable` - will now raise a ``TypeError`` (:pull:`5630`) + +Bug fixes +~~~~~~~~~ +- Fix a minor incompatibility between partial datetime string indexing with a + :py:class:`CFTimeIndex` and upcoming pandas version 1.3.0 (:issue:`5356`, + :pull:`5359`). + By `Spencer Clark `_. +- Fix 1-level multi-index incorrectly converted to single index (:issue:`5384`, + :pull:`5385`). + By `Benoit Bovy `_. +- Don't cast a duck array in a coordinate to :py:class:`numpy.ndarray` in + :py:meth:`DataArray.differentiate` (:pull:`5408`) + By `Justus Magin `_. +- Fix the ``repr`` of :py:class:`Variable` objects with ``display_expand_data=True`` + (:pull:`5406`) + By `Justus Magin `_. +- Plotting a pcolormesh with ``xscale="log"`` and/or ``yscale="log"`` works as + expected after improving the way the interval breaks are generated (:issue:`5333`). + By `Santiago Soler `_ +- :py:func:`combine_by_coords` can now handle combining a list of unnamed + ``DataArray`` as input (:issue:`3248`, :pull:`4696`). + By `Augustus Ijams `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ +- Run CI on the first & last python versions supported only; currently 3.7 & 3.9. + (:pull:`5433`) + By `Maximilian Roos `_. +- Publish test results & timings on each PR. + (:pull:`5537`) + By `Maximilian Roos `_. +- Explicit indexes refactor: add a ``xarray.Index.query()`` method in which + one may eventually provide a custom implementation of label-based data + selection (not ready yet for public use). Also refactor the internal, + pandas-specific implementation into ``PandasIndex.query()`` and + ``PandasMultiIndex.query()`` (:pull:`5322`). + By `Benoit Bovy `_. + +.. _whats-new.0.18.2: + +v0.18.2 (19 May 2021) +--------------------- + +This release reverts a regression in xarray's unstacking of dask-backed arrays. + +.. _whats-new.0.18.1: + +v0.18.1 (18 May 2021) +--------------------- + +This release is intended as a small patch release to be compatible with the new +2021.5.0 ``dask.distributed`` release. It also includes a new +``drop_duplicates`` method, some documentation improvements, the beginnings of +our internal Index refactoring, and some bug fixes. + +Thank you to all 16 contributors! + +Anderson Banihirwe, Andrew, Benoit Bovy, Brewster Malevich, Giacomo Caria, +Illviljan, James Bourbeau, Keewis, Maximilian Roos, Ravin Kumar, Stephan Hoyer, +Thomas Nicholas, Tom Nicholas, Zachary Moon. + +New Features +~~~~~~~~~~~~ +- Implement :py:meth:`DataArray.drop_duplicates` + to remove duplicate dimension values (:pull:`5239`). + By `Andrew Huang `_. +- Allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of + :py:func:`apply_ufunc` (:pull:`5041`) + By `Justus Magin `_. +- :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes, + such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`). + By `Jimmy Westling `_. +- Raise more informative error when decoding time variables with invalid reference dates. + (:issue:`5199`, :pull:`5288`). By `Giacomo Caria `_. + + +Bug fixes +~~~~~~~~~ +- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying + an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in + 0.18.0. + By `Stephan Hoyer `_ + +Documentation +~~~~~~~~~~~~~ +- Clean up and enhance docstrings for the :py:class:`DataArray.plot` and ``Dataset.plot.*`` + families of methods (:pull:`5285`). + By `Zach Moon `_. + +- Explanation of deprecation cycles and how to implement them added to contributors + guide. (:pull:`5289`) + By `Tom Nicholas `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Explicit indexes refactor: add an ``xarray.Index`` base class and + ``Dataset.xindexes`` / ``DataArray.xindexes`` properties. Also rename + ``PandasIndexAdapter`` to ``PandasIndex``, which now inherits from + ``xarray.Index`` (:pull:`5102`). + By `Benoit Bovy `_. +- Replace ``SortedKeysDict`` with python's ``dict``, given dicts are now ordered. + By `Maximilian Roos `_. +- Updated the release guide for developers. Now accounts for actions that are automated via github + actions. (:pull:`5274`). + By `Tom Nicholas `_. + +.. _whats-new.0.18.0: + +v0.18.0 (6 May 2021) +-------------------- + +This release brings a few important performance improvements, a wide range of +usability upgrades, lots of bug fixes, and some new features. These include +a plugin API to add backend engines, a new theme for the documentation, +curve fitting methods, and several new plotting functions. + +Many thanks to the 38 contributors to this release: Aaron Spring, Alessandro Amici, +Alex Marandon, Alistair Miles, Ana Paula Krelling, Anderson Banihirwe, Aureliana Barghini, +Baudouin Raoult, Benoit Bovy, Blair Bonnett, David Trémouilles, Deepak Cherian, +Gabriel Medeiros Abrahão, Giacomo Caria, Hauke Schulz, Illviljan, Mathias Hauser, Matthias Bussonnier, +Mattia Almansi, Maximilian Roos, Ray Bell, Richard Kleijn, Ryan Abernathey, Sam Levang, Spencer Clark, +Spencer Jones, Tammas Loughran, Tobias Kölling, Todd, Tom Nicholas, Tom White, Victor Negîrneac, +Xianxiang Li, Zeb Nicholls, crusaderky, dschwoerer, johnomotani, keewis + + +New Features +~~~~~~~~~~~~ + +- apply ``combine_attrs`` on data variables and coordinate variables when concatenating + and merging datasets and dataarrays (:pull:`4902`). + By `Justus Magin `_. +- Add :py:meth:`Dataset.to_pandas` (:pull:`5247`) + By `Giacomo Caria `_. +- Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make + surface plots (:issue:`2235` :issue:`5084` :pull:`5101`). + By `John Omotani `_. +- Allow passing multiple arrays to :py:meth:`Dataset.__setitem__` (:pull:`5216`). + By `Giacomo Caria `_. +- Add 'cumulative' option to :py:meth:`Dataset.integrate` and + :py:meth:`DataArray.integrate` so that result is a cumulative integral, like + :py:func:`scipy.integrate.cumulative_trapezoidal` (:pull:`5153`). + By `John Omotani `_. +- Add ``safe_chunks`` option to :py:meth:`Dataset.to_zarr` which allows overriding + checks made to ensure Dask and Zarr chunk compatibility (:issue:`5056`). + By `Ryan Abernathey `_ +- Add :py:meth:`Dataset.query` and :py:meth:`DataArray.query` which enable indexing + of datasets and data arrays by evaluating query expressions against the values of the + data variables (:pull:`4984`). + By `Alistair Miles `_. +- Allow passing ``combine_attrs`` to :py:meth:`Dataset.merge` (:pull:`4895`). + By `Justus Magin `_. +- Support for `dask.graph_manipulation + `_ (requires dask >=2021.3) + By `Guido Imperiale `_ +- Add :py:meth:`Dataset.plot.streamplot` for streamplot plots with :py:class:`Dataset` + variables (:pull:`5003`). + By `John Omotani `_. +- Many of the arguments for the :py:attr:`DataArray.str` methods now support + providing an array-like input. In this case, the array provided to the + arguments is broadcast against the original array and applied elementwise. +- :py:attr:`DataArray.str` now supports ``+``, ``*``, and ``%`` operators. These + behave the same as they do for :py:class:`str`, except that they follow + array broadcasting rules. +- A large number of new :py:attr:`DataArray.str` methods were implemented, + :py:meth:`DataArray.str.casefold`, :py:meth:`DataArray.str.cat`, + :py:meth:`DataArray.str.extract`, :py:meth:`DataArray.str.extractall`, + :py:meth:`DataArray.str.findall`, :py:meth:`DataArray.str.format`, + :py:meth:`DataArray.str.get_dummies`, :py:meth:`DataArray.str.islower`, + :py:meth:`DataArray.str.join`, :py:meth:`DataArray.str.normalize`, + :py:meth:`DataArray.str.partition`, :py:meth:`DataArray.str.rpartition`, + :py:meth:`DataArray.str.rsplit`, and :py:meth:`DataArray.str.split`. + A number of these methods allow for splitting or joining the strings in an + array. (:issue:`4622`) + By `Todd Jennings `_ +- Thanks to the new pluggable backend infrastructure external packages may now + use the ``xarray.backends`` entry point to register additional engines to be used in + :py:func:`open_dataset`, see the documentation in :ref:`add_a_backend` + (:issue:`4309`, :issue:`4803`, :pull:`4989`, :pull:`4810` and many others). + The backend refactor has been sponsored with the "Essential Open Source Software for Science" + grant from the `Chan Zuckerberg Initiative `_ and + developed by `B-Open `_. + By `Aureliana Barghini `_ and `Alessandro Amici `_. +- :py:attr:`~core.accessor_dt.DatetimeAccessor.date` added (:issue:`4983`, :pull:`4994`). + By `Hauke Schulz `_. +- Implement ``__getitem__`` for both :py:class:`~core.groupby.DatasetGroupBy` and + :py:class:`~core.groupby.DataArrayGroupBy`, inspired by pandas' + :py:meth:`~pandas.core.groupby.GroupBy.get_group`. + By `Deepak Cherian `_. +- Switch the tutorial functions to use `pooch `_ + (which is now a optional dependency) and add :py:func:`tutorial.open_rasterio` as a + way to open example rasterio files (:issue:`3986`, :pull:`4102`, :pull:`5074`). + By `Justus Magin `_. +- Add typing information to unary and binary arithmetic operators operating on + :py:class:`Dataset`, :py:class:`DataArray`, :py:class:`Variable`, + :py:class:`~core.groupby.DatasetGroupBy` or + :py:class:`~core.groupby.DataArrayGroupBy` (:pull:`4904`). + By `Richard Kleijn `_. +- Add a ``combine_attrs`` parameter to :py:func:`open_mfdataset` (:pull:`4971`). + By `Justus Magin `_. +- Enable passing arrays with a subset of dimensions to + :py:meth:`DataArray.clip` & :py:meth:`Dataset.clip`; these methods now use + :py:func:`xarray.apply_ufunc`; (:pull:`5184`). + By `Maximilian Roos `_. +- Disable the `cfgrib` backend if the `eccodes` library is not installed (:pull:`5083`). + By `Baudouin Raoult `_. +- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general curve fitting applications. (:issue:`4300`, :pull:`4849`) + By `Sam Levang `_. +- Add options to control expand/collapse of sections in display of Dataset and + DataArray. The function :py:func:`set_options` now takes keyword arguments + ``display_expand_attrs``, ``display_expand_coords``, ``display_expand_data``, + ``display_expand_data_vars``, all of which can be one of ``True`` to always + expand, ``False`` to always collapse, or ``default`` to expand unless over a + pre-defined limit (:pull:`5126`). + By `Tom White `_. +- Significant speedups in :py:meth:`Dataset.interp` and :py:meth:`DataArray.interp`. + (:issue:`4739`, :pull:`4740`). + By `Deepak Cherian `_. +- Prevent passing `concat_dim` to :py:func:`xarray.open_mfdataset` when + `combine='by_coords'` is specified, which should never have been possible (as + :py:func:`xarray.combine_by_coords` has no `concat_dim` argument to pass to). + Also removes unneeded internal reordering of datasets in + :py:func:`xarray.open_mfdataset` when `combine='by_coords'` is specified. + Fixes (:issue:`5230`). + By `Tom Nicholas `_. +- Implement ``__setitem__`` for ``xarray.core.indexing.DaskIndexingAdapter`` if + dask version supports item assignment. (:issue:`5171`, :pull:`5174`) + By `Tammas Loughran `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- The minimum versions of some dependencies were changed: + + ============ ====== ==== + Package Old New + ============ ====== ==== + boto3 1.12 1.13 + cftime 1.0 1.1 + dask 2.11 2.15 + distributed 2.11 2.15 + matplotlib 3.1 3.2 + numba 0.48 0.49 + ============ ====== ==== + +- :py:func:`open_dataset` and :py:func:`open_dataarray` now accept only the first argument + as positional, all others need to be passed are keyword arguments. This is part of the + refactor to support external backends (:issue:`4309`, :pull:`4989`). + By `Alessandro Amici `_. +- Functions that are identities for 0d data return the unchanged data + if axis is empty. This ensures that Datasets where some variables do + not have the averaged dimensions are not accidentally changed + (:issue:`4885`, :pull:`5207`). + By `David Schwörer `_. +- :py:attr:`DataArray.coarsen` and :py:attr:`Dataset.coarsen` no longer support passing ``keep_attrs`` + via its constructor. Pass ``keep_attrs`` via the applied function, i.e. use + ``ds.coarsen(...).mean(keep_attrs=False)`` instead of ``ds.coarsen(..., keep_attrs=False).mean()``. + Further, coarsen now keeps attributes per default (:pull:`5227`). + By `Mathias Hauser `_. +- switch the default of the :py:func:`merge` ``combine_attrs`` parameter to + ``"override"``. This will keep the current behavior for merging the ``attrs`` of + variables but stop dropping the ``attrs`` of the main objects (:pull:`4902`). + By `Justus Magin `_. + +Deprecations +~~~~~~~~~~~~ + +- Warn when passing `concat_dim` to :py:func:`xarray.open_mfdataset` when + `combine='by_coords'` is specified, which should never have been possible (as + :py:func:`xarray.combine_by_coords` has no `concat_dim` argument to pass to). + Also removes unneeded internal reordering of datasets in + :py:func:`xarray.open_mfdataset` when `combine='by_coords'` is specified. + Fixes (:issue:`5230`), via (:pull:`5231`, :pull:`5255`). + By `Tom Nicholas `_. +- The `lock` keyword argument to :py:func:`open_dataset` and :py:func:`open_dataarray` is now + a backend specific option. It will give a warning if passed to a backend that doesn't support it + instead of being silently ignored. From the next version it will raise an error. + This is part of the refactor to support external backends (:issue:`5073`). + By `Tom Nicholas `_ and `Alessandro Amici `_. + + +Bug fixes +~~~~~~~~~ +- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill` along chunked dimensions. + (:issue:`2699`). + By `Deepak Cherian `_. +- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is + 2d (:issue:`5097`, :pull:`5099`). + By `John Omotani `_. +- Ensure standard calendar times encoded with large values (i.e. greater than + approximately 292 years), can be decoded correctly without silently overflowing + (:pull:`5050`). This was a regression in xarray 0.17.0. + By `Zeb Nicholls `_. +- Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`). + By `Victor Negîrneac `_. +- Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`). + By `Justus Magin `_. +- Decode values as signed if attribute `_Unsigned = "false"` (:issue:`4954`) + By `Tobias Kölling `_. +- Keep coords attributes when interpolating when the indexer is not a Variable. (:issue:`4239`, :issue:`4839` :pull:`5031`) + By `Jimmy Westling `_. +- Ensure standard calendar dates encoded with a calendar attribute with some or + all uppercase letters can be decoded or encoded to or from + ``np.datetime64[ns]`` dates with or without ``cftime`` installed + (:issue:`5093`, :pull:`5180`). + By `Spencer Clark `_. +- Warn on passing ``keep_attrs`` to ``resample`` and ``rolling_exp`` as they are ignored, pass ``keep_attrs`` + to the applied function instead (:pull:`5265`). + By `Mathias Hauser `_. + +Documentation +~~~~~~~~~~~~~ +- New section on :ref:`add_a_backend` in the "Internals" chapter aimed to backend developers + (:issue:`4803`, :pull:`4810`). + By `Aureliana Barghini `_. +- Add :py:meth:`Dataset.polyfit` and :py:meth:`DataArray.polyfit` under "See also" in + the docstrings of :py:meth:`Dataset.polyfit` and :py:meth:`DataArray.polyfit` + (:issue:`5016`, :pull:`5020`). + By `Aaron Spring `_. +- New sphinx theme & rearrangement of the docs (:pull:`4835`). + By `Anderson Banihirwe `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Enable displaying mypy error codes and ignore only specific error codes using + ``# type: ignore[error-code]`` (:pull:`5096`). + By `Mathias Hauser `_. +- Replace uses of ``raises_regex`` with the more standard + ``pytest.raises(Exception, match="foo")``; + (:pull:`5188`), (:pull:`5191`). + By `Maximilian Roos `_. + +.. _whats-new.0.17.0: + +v0.17.0 (24 Feb 2021) +--------------------- + +This release brings a few important performance improvements, a wide range of +usability upgrades, lots of bug fixes, and some new features. These include +better ``cftime`` support, a new quiver plot, better ``unstack`` performance, +more efficient memory use in rolling operations, and some python packaging +improvements. We also have a few documentation improvements (and more planned!). + +Many thanks to the 36 contributors to this release: Alessandro Amici, Anderson +Banihirwe, Aureliana Barghini, Ayrton Bourn, Benjamin Bean, Blair Bonnett, Chun +Ho Chow, DWesl, Daniel Mesejo-León, Deepak Cherian, Eric Keenan, Illviljan, Jens +Hedegaard Nielsen, Jody Klymak, Julien Seguinot, Julius Busecke, Kai Mühlbauer, +Leif Denby, Martin Durant, Mathias Hauser, Maximilian Roos, Michael Mann, Ray +Bell, RichardScottOZ, Spencer Clark, Tim Gates, Tom Nicholas, Yunus Sevinchan, +alexamici, aurghs, crusaderky, dcherian, ghislainp, keewis, rhkleijn + +Breaking changes +~~~~~~~~~~~~~~~~ +- xarray no longer supports python 3.6 + + The minimum version policy was changed to also apply to projects with irregular + releases. As a result, the minimum versions of some dependencies have changed: + + ============ ====== ==== + Package Old New + ============ ====== ==== + Python 3.6 3.7 + setuptools 38.4 40.4 + numpy 1.15 1.17 + pandas 0.25 1.0 + dask 2.9 2.11 + distributed 2.9 2.11 + bottleneck 1.2 1.3 + h5netcdf 0.7 0.8 + iris 2.2 2.4 + netcdf4 1.4 1.5 + pseudonetcdf 3.0 3.1 + rasterio 1.0 1.1 + scipy 1.3 1.4 + seaborn 0.9 0.10 + zarr 2.3 2.4 + ============ ====== ==== + + (:issue:`4688`, :pull:`4720`, :pull:`4907`, :pull:`4942`) +- As a result of :pull:`4684` the default units encoding for + datetime-like values (``np.datetime64[ns]`` or ``cftime.datetime``) will now + always be set such that ``int64`` values can be used. In the past, no units + finer than "seconds" were chosen, which would sometimes mean that ``float64`` + values were required, which would lead to inaccurate I/O round-trips. +- Variables referred to in attributes like ``bounds`` and ``grid_mapping`` + can be set as coordinate variables. These attributes are moved to + :py:attr:`DataArray.encoding` from :py:attr:`DataArray.attrs`. This behaviour + is controlled by the ``decode_coords`` kwarg to :py:func:`open_dataset` and + :py:func:`open_mfdataset`. The full list of decoded attributes is in + :ref:`weather-climate` (:pull:`2844`, :issue:`3689`) +- As a result of :pull:`4911` the output from calling :py:meth:`DataArray.sum` + or :py:meth:`DataArray.prod` on an integer array with ``skipna=True`` and a + non-None value for ``min_count`` will now be a float array rather than an + integer array. + +Deprecations +~~~~~~~~~~~~ + +- ``dim`` argument to :py:meth:`DataArray.integrate` is being deprecated in + favour of a ``coord`` argument, for consistency with :py:meth:`Dataset.integrate`. + For now using ``dim`` issues a ``FutureWarning``. It will be removed in + version 0.19.0 (:pull:`3993`). + By `Tom Nicholas `_. +- Deprecated ``autoclose`` kwargs from :py:func:`open_dataset` are removed (:pull:`4725`). + By `Aureliana Barghini `_. +- the return value of :py:meth:`Dataset.update` is being deprecated to make it work more + like :py:meth:`dict.update`. It will be removed in version 0.19.0 (:pull:`4932`). + By `Justus Magin `_. + +New Features +~~~~~~~~~~~~ +- :py:meth:`~xarray.cftime_range` and :py:meth:`DataArray.resample` now support + millisecond (``"L"`` or ``"ms"``) and microsecond (``"U"`` or ``"us"``) frequencies + for ``cftime.datetime`` coordinates (:issue:`4097`, :pull:`4758`). + By `Spencer Clark `_. +- Significantly higher ``unstack`` performance on numpy-backed arrays which + contain missing values; 8x faster than previous versions in our benchmark, and + now 2x faster than pandas (:pull:`4746`). + By `Maximilian Roos `_. +- Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables. + By `Deepak Cherian `_. +- Add ``"drop_conflicts"`` to the strategies supported by the ``combine_attrs`` kwarg + (:issue:`4749`, :pull:`4827`). + By `Justus Magin `_. +- Allow installing from git archives (:pull:`4897`). + By `Justus Magin `_. +- :py:class:`~core.rolling.DataArrayCoarsen` and :py:class:`~core.rolling.DatasetCoarsen` + now implement a ``reduce`` method, enabling coarsening operations with custom + reduction functions (:issue:`3741`, :pull:`4939`). + By `Spencer Clark `_. +- Most rolling operations use significantly less memory. (:issue:`4325`). + By `Deepak Cherian `_. +- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` + (:issue:`4658`, :pull:`4819`). + By `Daniel Mesejo `_. +- Xarray now leverages updates as of cftime version 1.4.1, which enable exact I/O + roundtripping of ``cftime.datetime`` objects (:pull:`4758`). + By `Spencer Clark `_. +- :py:func:`open_dataset` and :py:func:`open_mfdataset` now accept ``fsspec`` URLs + (including globs for the latter) for ``engine="zarr"``, and so allow reading from + many remote and other file systems (:pull:`4461`) + By `Martin Durant `_ +- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims + in the form of kwargs as well as a dict, like most similar methods. + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ +- Use specific type checks in ``xarray.core.variable.as_compatible_data`` instead of + blanket access to ``values`` attribute (:issue:`2097`) + By `Yunus Sevinchan `_. +- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` do not trigger + computations anymore if :py:meth:`Dataset.weighted` or + :py:meth:`DataArray.weighted` are applied (:issue:`4625`, :pull:`4668`). By + `Julius Busecke `_. +- :py:func:`merge` with ``combine_attrs='override'`` makes a copy of the attrs + (:issue:`4627`). +- By default, when possible, xarray will now always use values of + type ``int64`` when encoding and decoding ``numpy.datetime64[ns]`` datetimes. This + ensures that maximum precision and accuracy are maintained in the round-tripping + process (:issue:`4045`, :pull:`4684`). It also enables encoding and decoding standard + calendar dates with time units of nanoseconds (:pull:`4400`). + By `Spencer Clark `_ and `Mark Harfouche + `_. +- :py:meth:`DataArray.astype`, :py:meth:`Dataset.astype` and :py:meth:`Variable.astype` support + the ``order`` and ``subok`` parameters again. This fixes a regression introduced in version 0.16.1 + (:issue:`4644`, :pull:`4683`). + By `Richard Kleijn `_ . +- Remove dictionary unpacking when using ``.loc`` to avoid collision with ``.sel`` parameters (:pull:`4695`). + By `Anderson Banihirwe `_. +- Fix the legend created by :py:meth:`Dataset.plot.scatter` (:issue:`4641`, :pull:`4723`). + By `Justus Magin `_. +- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` + (:issue:`4733` :pull:`4737`). + By `Alessandro Amici `_. +- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations, + e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype + (:issue:`2658` and :issue:`4543`). + By `Mathias Hauser `_. +- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). + By `Jimmy Westling `_. +- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). + By `Daniel Mesejo `_. +- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`). + By `Justus Magin `_. +- Fix regression when decoding a variable with a ``scale_factor`` and ``add_offset`` given + as a list of length one (:issue:`4631`). + By `Mathias Hauser `_. +- Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and + :py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`). + By `Julien Seguinot `_. +- Raise DeprecationWarning when trying to typecast a tuple containing a :py:class:`DataArray`. + User now prompted to first call `.data` on it (:issue:`4483`). + By `Chun Ho Chow `_. +- Ensure that :py:meth:`Dataset.interp` raises ``ValueError`` when interpolating + outside coordinate range and ``bounds_error=True`` (:issue:`4854`, + :pull:`4855`). + By `Leif Denby `_. +- Fix time encoding bug associated with using cftime versions greater than + 1.4.0 with xarray (:issue:`4870`, :pull:`4871`). + By `Spencer Clark `_. +- Stop :py:meth:`DataArray.sum` and :py:meth:`DataArray.prod` computing lazy + arrays when called with a ``min_count`` parameter (:issue:`4898`, :pull:`4911`). + By `Blair Bonnett `_. +- Fix bug preventing the ``min_count`` parameter to :py:meth:`DataArray.sum` and + :py:meth:`DataArray.prod` working correctly when calculating over all axes of + a float64 array (:issue:`4898`, :pull:`4911`). + By `Blair Bonnett `_. +- Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`). + By `Kai Mühlbauer `_. +- Allow converting :py:class:`Dataset` or :py:class:`DataArray` objects with a ``MultiIndex`` + and at least one other dimension to a ``pandas`` object (:issue:`3008`, :pull:`4442`). + By `ghislainp `_. + +Documentation +~~~~~~~~~~~~~ +- Add information about requirements for accessor classes (:issue:`2788`, :pull:`4657`). + By `Justus Magin `_. +- Start a list of external I/O integrating with ``xarray`` (:issue:`683`, :pull:`4566`). + By `Justus Magin `_. +- Add concat examples and improve combining documentation (:issue:`4620`, :pull:`4645`). + By `Ray Bell `_ and + `Justus Magin `_. +- explicitly mention that :py:meth:`Dataset.update` updates inplace (:issue:`2951`, :pull:`4932`). + By `Justus Magin `_. +- Added docs on vectorized indexing (:pull:`4711`). + By `Eric Keenan `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Speed up of the continuous integration tests on azure. + + - Switched to mamba and use matplotlib-base for a faster installation of all dependencies (:pull:`4672`). + - Use ``pytest.mark.skip`` instead of ``pytest.mark.xfail`` for some tests that can currently not + succeed (:pull:`4685`). + - Run the tests in parallel using pytest-xdist (:pull:`4694`). + + By `Justus Magin `_ and `Mathias Hauser `_. +- Use ``pyproject.toml`` instead of the ``setup_requires`` option for + ``setuptools`` (:pull:`4897`). + By `Justus Magin `_. +- Replace all usages of ``assert x.identical(y)`` with ``assert_identical(x, y)`` + for clearer error messages (:pull:`4752`). + By `Maximilian Roos `_. +- Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and + tab completion in IPython (:issue:`4741`, :pull:`4742`). + By `Richard Kleijn `_. +- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends + to specify how to voluntary release all resources. (:pull:`#4809`) + By `Alessandro Amici `_. +- Update type hints to work with numpy v1.20 (:pull:`4878`). + By `Mathias Hauser `_. +- Ensure warnings cannot be turned into exceptions in :py:func:`testing.assert_equal` and + the other ``assert_*`` functions (:pull:`4864`). + By `Mathias Hauser `_. +- Performance improvement when constructing DataArrays. Significantly speeds up + repr for Datasets with large number of variables. + By `Deepak Cherian `_. + +.. _whats-new.0.16.2: + +v0.16.2 (30 Nov 2020) +--------------------- + +This release brings the ability to write to limited regions of ``zarr`` files, +open zarr files with :py:func:`open_dataset` and :py:func:`open_mfdataset`, +increased support for propagating ``attrs`` using the ``keep_attrs`` flag, as +well as numerous bugfixes and documentation improvements. + +Many thanks to the 31 contributors who contributed to this release: Aaron +Spring, Akio Taniguchi, Aleksandar Jelenak, alexamici, Alexandre Poux, Anderson +Banihirwe, Andrew Pauling, Ashwin Vishnu, aurghs, Brian Ward, Caleb, crusaderky, +Dan Nowacki, darikg, David Brochart, David Huard, Deepak Cherian, Dion Häfner, +Gerardo Rivera, Gerrit Holl, Illviljan, inakleinbottle, Jacob Tomlinson, James +A. Bednar, jenssss, Joe Hamman, johnomotani, Joris Van den Bossche, Julia Kent, +Julius Busecke, Kai Mühlbauer, keewis, Keisuke Fujii, Kyle Cranmer, Luke +Volpatti, Mathias Hauser, Maximilian Roos, Michaël Defferrard, Michal +Baumgartner, Nick R. Papior, Pascal Bourgault, Peter Hausamann, PGijsbers, Ray +Bell, Romain Martinez, rpgoldman, Russell Manser, Sahid Velji, Samnan Rahee, +Sander, Spencer Clark, Stephan Hoyer, Thomas Zilio, Tobias Kölling, Tom +Augspurger, Wei Ji, Yash Saboo, Zeb Nicholls, + +Deprecations +~~~~~~~~~~~~ + +- :py:attr:`~core.accessor_dt.DatetimeAccessor.weekofyear` and :py:attr:`~core.accessor_dt.DatetimeAccessor.week` + have been deprecated. Use ``DataArray.dt.isocalendar().week`` + instead (:pull:`4534`). By `Mathias Hauser `_. + `Maximilian Roos `_, and `Spencer Clark `_. +- :py:attr:`DataArray.rolling` and :py:attr:`Dataset.rolling` no longer support passing ``keep_attrs`` + via its constructor. Pass ``keep_attrs`` via the applied function, i.e. use + ``ds.rolling(...).mean(keep_attrs=False)`` instead of ``ds.rolling(..., keep_attrs=False).mean()`` + Rolling operations now keep their attributes per default (:pull:`4510`). + By `Mathias Hauser `_. + +New Features +~~~~~~~~~~~~ + +- :py:func:`open_dataset` and :py:func:`open_mfdataset` + now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`). + By `Miguel Jimenez `_ and `Wei Ji Leong `_. +- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`). + By `Deepak Cherian `_. +- Added :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()` that returns a Dataset + with year, week, and weekday calculated according to the ISO 8601 calendar. Requires + pandas version 1.1.0 or greater (:pull:`4534`). By `Mathias Hauser `_, + `Maximilian Roos `_, and `Spencer Clark `_. +- :py:meth:`Dataset.to_zarr` now supports a ``region`` keyword for writing to + limited regions of existing Zarr stores (:pull:`4035`). + See :ref:`io.zarr.appending` for full details. + By `Stephan Hoyer `_. +- Added typehints in :py:func:`align` to reflect that the same type received in ``objects`` arg will be returned (:pull:`4522`). + By `Michal Baumgartner `_. +- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). + By `Julius Busecke `_. +- Added the ``keep_attrs`` keyword to ``rolling_exp.mean()``; it now keeps attributes + per default. By `Mathias Hauser `_ (:pull:`4592`). +- Added ``freq`` as property to :py:class:`CFTimeIndex` and into the + ``CFTimeIndex.repr``. (:issue:`2416`, :pull:`4597`) + By `Aaron Spring `_. + +Bug fixes +~~~~~~~~~ + +- Fix bug where reference times without padded years (e.g. ``since 1-1-1``) would lose their units when + being passed by ``encode_cf_datetime`` (:issue:`4422`, :pull:`4506`). Such units are ambiguous + about which digit represents the years (is it YMD or DMY?). Now, if such formatting is encountered, + it is assumed that the first digit is the years, they are padded appropriately (to e.g. ``since 0001-1-1``) + and a warning that this assumption is being made is issued. Previously, without ``cftime``, such times + would be silently parsed incorrectly (at least based on the CF conventions) e.g. "since 1-1-1" would + be parsed (via ``pandas`` and ``dateutil``) to ``since 2001-1-1``. + By `Zeb Nicholls `_. +- Fix :py:meth:`DataArray.plot.step`. By `Deepak Cherian `_. +- Fix bug where reading a scalar value from a NetCDF file opened with the ``h5netcdf`` backend would raise a ``ValueError`` when ``decode_cf=True`` (:issue:`4471`, :pull:`4485`). + By `Gerrit Holl `_. +- Fix bug where datetime64 times are silently changed to incorrect values if they are outside the valid date range for ns precision when provided in some other units (:issue:`4427`, :pull:`4454`). + By `Andrew Pauling `_ +- Fix silently overwriting the ``engine`` key when passing :py:func:`open_dataset` a file object + to an incompatible netCDF (:issue:`4457`). Now incompatible combinations of files and engines raise + an exception instead. By `Alessandro Amici `_. +- The ``min_count`` argument to :py:meth:`DataArray.sum()` and :py:meth:`DataArray.prod()` + is now ignored when not applicable, i.e. when ``skipna=False`` or when ``skipna=None`` + and the dtype does not have a missing value (:issue:`4352`). + By `Mathias Hauser `_. +- :py:func:`combine_by_coords` now raises an informative error when passing coordinates + with differing calendars (:issue:`4495`). By `Mathias Hauser `_. +- :py:attr:`DataArray.rolling` and :py:attr:`Dataset.rolling` now also keep the attributes and names of of (wrapped) + ``DataArray`` objects, previously only the global attributes were retained (:issue:`4497`, :pull:`4510`). + By `Mathias Hauser `_. +- Improve performance where reading small slices from huge dimensions was slower than necessary (:pull:`4560`). By `Dion Häfner `_. +- Fix bug where ``dask_gufunc_kwargs`` was silently changed in :py:func:`apply_ufunc` (:pull:`4576`). By `Kai Mühlbauer `_. + +Documentation +~~~~~~~~~~~~~ +- document the API not supported with duck arrays (:pull:`4530`). + By `Justus Magin `_. +- Mention the possibility to pass functions to :py:meth:`Dataset.where` or + :py:meth:`DataArray.where` in the parameter documentation (:issue:`4223`, :pull:`4613`). + By `Justus Magin `_. +- Update the docstring of :py:class:`DataArray` and :py:class:`Dataset`. + (:pull:`4532`); + By `Jimmy Westling `_. +- Raise a more informative error when :py:meth:`DataArray.to_dataframe` is + is called on a scalar, (:issue:`4228`); + By `Pieter Gijsbers `_. +- Fix grammar and typos in the :doc:`contributing` guide (:pull:`4545`). + By `Sahid Velji `_. +- Fix grammar and typos in the :doc:`user-guide/io` guide (:pull:`4553`). + By `Sahid Velji `_. +- Update link to NumPy docstring standard in the :doc:`contributing` guide (:pull:`4558`). + By `Sahid Velji `_. +- Add docstrings to ``isnull`` and ``notnull``, and fix the displayed signature + (:issue:`2760`, :pull:`4618`). + By `Justus Magin `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Optional dependencies can be installed along with xarray by specifying + extras as ``pip install "xarray[extra]"`` where ``extra`` can be one of ``io``, + ``accel``, ``parallel``, ``viz`` and ``complete``. See docs for updated + :ref:`installation instructions `. + (:issue:`2888`, :pull:`4480`). + By `Ashwin Vishnu `_, `Justus Magin + `_ and `Mathias Hauser + `_. +- Removed stray spaces that stem from black removing new lines (:pull:`4504`). + By `Mathias Hauser `_. +- Ensure tests are not skipped in the ``py38-all-but-dask`` test environment + (:issue:`4509`). By `Mathias Hauser `_. +- Ignore select numpy warnings around missing values, where xarray handles + the values appropriately, (:pull:`4536`); + By `Maximilian Roos `_. +- Replace the internal use of ``pd.Index.__or__`` and ``pd.Index.__and__`` with ``pd.Index.union`` + and ``pd.Index.intersection`` as they will stop working as set operations in the future + (:issue:`4565`). By `Mathias Hauser `_. +- Add GitHub action for running nightly tests against upstream dependencies (:pull:`4583`). + By `Anderson Banihirwe `_. +- Ensure all figures are closed properly in plot tests (:pull:`4600`). + By `Yash Saboo `_, `Nirupam K N + `_ and `Mathias Hauser + `_. + +.. _whats-new.0.16.1: + +v0.16.1 (2020-09-20) +--------------------- + +This patch release fixes an incompatibility with a recent pandas change, which +was causing an issue indexing with a ``datetime64``. It also includes +improvements to ``rolling``, ``to_dataframe``, ``cov`` & ``corr`` methods and +bug fixes. Our documentation has a number of improvements, including fixing all +doctests and confirming their accuracy on every commit. + +Many thanks to the 36 contributors who contributed to this release: + +Aaron Spring, Akio Taniguchi, Aleksandar Jelenak, Alexandre Poux, +Caleb, Dan Nowacki, Deepak Cherian, Gerardo Rivera, Jacob Tomlinson, James A. +Bednar, Joe Hamman, Julia Kent, Kai Mühlbauer, Keisuke Fujii, Mathias Hauser, +Maximilian Roos, Nick R. Papior, Pascal Bourgault, Peter Hausamann, Romain +Martinez, Russell Manser, Samnan Rahee, Sander, Spencer Clark, Stephan Hoyer, +Thomas Zilio, Tobias Kölling, Tom Augspurger, alexamici, crusaderky, darikg, +inakleinbottle, jenssss, johnomotani, keewis, and rpgoldman. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`DataArray.astype` and :py:meth:`Dataset.astype` now preserve attributes. Keep the + old behavior by passing `keep_attrs=False` (:issue:`2049`, :pull:`4314`). + By `Dan Nowacki `_ and `Gabriel Joel Mitchell `_. + +New Features +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling` + now accept more than 1 dimension. (:pull:`4219`) + By `Keisuke Fujii `_. +- :py:meth:`~xarray.DataArray.to_dataframe` and :py:meth:`~xarray.Dataset.to_dataframe` + now accept a ``dim_order`` parameter allowing to specify the resulting dataframe's + dimensions order (:issue:`4331`, :pull:`4333`). + By `Thomas Zilio `_. +- Support multiple outputs in :py:func:`xarray.apply_ufunc` when using + ``dask='parallelized'``. (:issue:`1815`, :pull:`4060`). + By `Kai Mühlbauer `_. +- ``min_count`` can be supplied to reductions such as ``.sum`` when specifying + multiple dimension to reduce over; (:pull:`4356`). + By `Maximilian Roos `_. +- :py:func:`xarray.cov` and :py:func:`xarray.corr` now handle missing values; (:pull:`4351`). + By `Maximilian Roos `_. +- Add support for parsing datetime strings formatted following the default + string representation of cftime objects, i.e. YYYY-MM-DD hh:mm:ss, in + partial datetime string indexing, as well as :py:meth:`~xarray.cftime_range` + (:issue:`4337`). By `Spencer Clark `_. +- Build ``CFTimeIndex.__repr__`` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new + property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in + ``CFTimeIndex.__repr__`` (:issue:`2416`, :pull:`4092`) + By `Aaron Spring `_. +- Use a wrapped array's ``_repr_inline_`` method to construct the collapsed ``repr`` + of :py:class:`DataArray` and :py:class:`Dataset` objects and + document the new method in :doc:`internals/index`. (:pull:`4248`). + By `Justus Magin `_. +- Allow per-variable fill values in most functions. (:pull:`4237`). + By `Justus Magin `_. +- Expose ``use_cftime`` option in :py:func:`~xarray.open_zarr` (:issue:`2886`, :pull:`3229`) + By `Samnan Rahee `_ and `Anderson Banihirwe `_. + +Bug fixes +~~~~~~~~~ + +- Fix indexing with datetime64 scalars with pandas 1.1 (:issue:`4283`). + By `Stephan Hoyer `_ and + `Justus Magin `_. +- Variables which are chunked using dask only along some dimensions can be chunked while storing with zarr along previously + unchunked dimensions (:pull:`4312`) By `Tobias Kölling `_. +- Fixed a bug in backend caused by basic installation of Dask (:issue:`4164`, :pull:`4318`) + `Sam Morley `_. +- Fixed a few bugs with :py:meth:`Dataset.polyfit` when encountering deficient matrix ranks (:issue:`4190`, :pull:`4193`). By `Pascal Bourgault `_. +- Fixed inconsistencies between docstring and functionality for :py:meth:`DataArray.str.get` + and :py:meth:`DataArray.str.wrap` (:issue:`4334`). By `Mathias Hauser `_. +- Fixed overflow issue causing incorrect results in computing means of :py:class:`cftime.datetime` + arrays (:issue:`4341`). By `Spencer Clark `_. +- Fixed :py:meth:`Dataset.coarsen`, :py:meth:`DataArray.coarsen` dropping attributes on original object (:issue:`4120`, :pull:`4360`). By `Julia Kent `_. +- fix the signature of the plot methods. (:pull:`4359`) By `Justus Magin `_. +- Fix :py:func:`xarray.apply_ufunc` with ``vectorize=True`` and ``exclude_dims`` (:issue:`3890`). + By `Mathias Hauser `_. +- Fix `KeyError` when doing linear interpolation to an nd `DataArray` + that contains NaNs (:pull:`4233`). + By `Jens Svensmark `_ +- Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`). + By `Peter Hausamann `_. +- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) + By `Tom Augspurger `_ +- Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source + directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH. + By `Guido Imperiale `_ +- Preserve dimension and coordinate order during :py:func:`xarray.concat` (:issue:`2811`, :issue:`4072`, :pull:`4419`). + By `Kai Mühlbauer `_. +- Avoid relying on :py:class:`set` objects for the ordering of the coordinates (:pull:`4409`) + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +- Update the docstring of :py:meth:`DataArray.copy` to remove incorrect mention of 'dataset' (:issue:`3606`) + By `Sander van Rijn `_. +- Removed skipna argument from :py:meth:`DataArray.count`, :py:meth:`DataArray.any`, :py:meth:`DataArray.all`. (:issue:`755`) + By `Sander van Rijn `_ +- Update the contributing guide to use merges instead of rebasing and state + that we squash-merge. (:pull:`4355`). By `Justus Magin `_. +- Make sure the examples from the docstrings actually work (:pull:`4408`). + By `Justus Magin `_. +- Updated Vectorized Indexing to a clearer example. + By `Maximilian Roos `_ + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Fixed all doctests and enabled their running in CI. + By `Justus Magin `_. +- Relaxed the :ref:`mindeps_policy` to support: + + - all versions of setuptools released in the last 42 months (but no older than 38.4) + - all versions of dask and dask.distributed released in the last 12 months (but no + older than 2.9) + - all versions of other packages released in the last 12 months + + All are up from 6 months (:issue:`4295`) + `Guido Imperiale `_. +- Use :py:func:`dask.array.apply_gufunc ` instead of + :py:func:`dask.array.blockwise` in :py:func:`xarray.apply_ufunc` when using + ``dask='parallelized'``. (:pull:`4060`, :pull:`4391`, :pull:`4392`) + By `Kai Mühlbauer `_. +- Align ``mypy`` versions to ``0.782`` across ``requirements`` and + ``.pre-commit-config.yml`` files. (:pull:`4390`) + By `Maximilian Roos `_ +- Only load resource files when running inside a Jupyter Notebook + (:issue:`4294`) By `Guido Imperiale `_ +- Silenced most ``numpy`` warnings such as ``Mean of empty slice``. (:pull:`4369`) + By `Maximilian Roos `_ +- Enable type checking for :py:func:`concat` (:issue:`4238`) + By `Mathias Hauser `_. +- Updated plot functions for matplotlib version 3.3 and silenced warnings in the + plot tests (:pull:`4365`). By `Mathias Hauser `_. +- Versions in ``pre-commit.yaml`` are now pinned, to reduce the chances of + conflicting versions. (:pull:`4388`) + By `Maximilian Roos `_ + + + +.. _whats-new.0.16.0: + +v0.16.0 (2020-07-11) +--------------------- + +This release adds `xarray.cov` & `xarray.corr` for covariance & correlation +respectively; the `idxmax` & `idxmin` methods, the `polyfit` method & +`xarray.polyval` for fitting polynomials, as well as a number of documentation +improvements, other features, and bug fixes. Many thanks to all 44 contributors +who contributed to this release: + +Akio Taniguchi, Andrew Williams, Aurélien Ponte, Benoit Bovy, Dave Cole, David +Brochart, Deepak Cherian, Elliott Sales de Andrade, Etienne Combrisson, Hossein +Madadi, Huite, Joe Hamman, Kai Mühlbauer, Keisuke Fujii, Maik Riechert, Marek +Jacob, Mathias Hauser, Matthieu Ancellin, Maximilian Roos, Noah D Brenowitz, +Oriol Abril, Pascal Bourgault, Phillip Butcher, Prajjwal Nijhara, Ray Bell, Ryan +Abernathey, Ryan May, Spencer Clark, Spencer Hill, Srijan Saurav, Stephan Hoyer, +Taher Chegini, Todd, Tom Nicholas, Yohai Bar Sinai, Yunus Sevinchan, +arabidopsis, aurghs, clausmichele, dmey, johnomotani, keewis, raphael dussin, +risebell + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Minimum supported versions for the following packages have changed: ``dask >=2.9``, + ``distributed>=2.9``. + By `Deepak Cherian `_ +- ``groupby`` operations will restore coord dimension order. Pass ``restore_coord_dims=False`` + to revert to previous behavior. +- :meth:`DataArray.transpose` will now transpose coordinates by default. + Pass ``transpose_coords=False`` to revert to previous behaviour. + By `Maximilian Roos `_ +- Alternate draw styles for :py:meth:`plot.step` must be passed using the + ``drawstyle`` (or ``ds``) keyword argument, instead of the ``linestyle`` (or + ``ls``) keyword argument, in line with the `upstream change in Matplotlib + `_. + (:pull:`3274`) + By `Elliott Sales de Andrade `_ +- The old ``auto_combine`` function has now been removed in + favour of the :py:func:`combine_by_coords` and + :py:func:`combine_nested` functions. This also means that + the default behaviour of :py:func:`open_mfdataset` has changed to use + ``combine='by_coords'`` as the default argument value. (:issue:`2616`, :pull:`3926`) + By `Tom Nicholas `_. +- The ``DataArray`` and ``Variable`` HTML reprs now expand the data section by + default (:issue:`4176`) + By `Stephan Hoyer `_. + +New Features +~~~~~~~~~~~~ +- :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support + sequences of 'dim' arguments, and if a sequence is passed return a dict + (which can be passed to :py:meth:`DataArray.isel` to get the value of the minimum) of + the indices for each dimension of the minimum or maximum of a DataArray. + (:pull:`3936`) + By `John Omotani `_, thanks to `Keisuke Fujii + `_ for work in :pull:`1469`. +- Added :py:func:`xarray.cov` and :py:func:`xarray.corr` (:issue:`3784`, :pull:`3550`, :pull:`4089`). + By `Andrew Williams `_ and `Robin Beer `_. +- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) + By `Todd Jennings `_ +- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting + polynomials. (:issue:`3349`, :pull:`3733`, :pull:`4099`) + By `Pascal Bourgault `_. +- Added :py:meth:`xarray.infer_freq` for extending frequency inferring to CFTime indexes and data (:pull:`4033`). + By `Pascal Bourgault `_. +- ``chunks='auto'`` is now supported in the ``chunks`` argument of + :py:meth:`Dataset.chunk`. (:issue:`4055`) + By `Andrew Williams `_ +- Control over attributes of result in :py:func:`merge`, :py:func:`concat`, + :py:func:`combine_by_coords` and :py:func:`combine_nested` using + combine_attrs keyword argument. (:issue:`3865`, :pull:`3877`) + By `John Omotani `_ +- `missing_dims` argument to :py:meth:`Dataset.isel`, + :py:meth:`DataArray.isel` and :py:meth:`Variable.isel` to allow replacing + the exception when a dimension passed to ``isel`` is not present with a + warning, or just ignore the dimension. (:issue:`3866`, :pull:`3923`) + By `John Omotani `_ +- Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`, :pull:`4135`) + By `Kai Mühlbauer `_ and `Pascal Bourgault `_. +- More support for unit aware arrays with pint (:pull:`3643`, :pull:`3975`, :pull:`4163`) + By `Justus Magin `_. +- Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even + without ``append_dim``, as long as dimension sizes do not change. + By `Stephan Hoyer `_. +- Allow plotting of boolean arrays. (:pull:`3766`) + By `Marek Jacob `_ +- Enable using MultiIndex levels as coordinates in 1D and 2D plots (:issue:`3927`). + By `Mathias Hauser `_. +- A ``days_in_month`` accessor for :py:class:`xarray.CFTimeIndex`, analogous to + the ``days_in_month`` accessor for a :py:class:`pandas.DatetimeIndex`, which + returns the days in the month each datetime in the index. Now days in month + weights for both standard and non-standard calendars can be obtained using + the :py:class:`~core.accessor_dt.DatetimeAccessor` (:pull:`3935`). This + feature requires cftime version 1.1.0 or greater. By + `Spencer Clark `_. +- For the netCDF3 backend, added dtype coercions for unsigned integer types. + (:issue:`4014`, :pull:`4018`) + By `Yunus Sevinchan `_ +- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases + where the result of a computation could not be inferred automatically. + By `Deepak Cherian `_ +- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`) + By `Deepak Cherian `_ +- Add keyword ``decode_timedelta`` to :py:func:`xarray.open_dataset`, + (:py:func:`xarray.open_dataarray`, :py:func:`xarray.open_dataarray`, + :py:func:`xarray.decode_cf`) that allows to disable/enable the decoding of timedeltas + independently of time decoding (:issue:`1621`) + `Aureliana Barghini `_ + +Enhancements +~~~~~~~~~~~~ +- Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp` + We performs independent interpolation sequentially rather than interpolating in + one large multidimensional space. (:issue:`2223`) + By `Keisuke Fujii `_. +- :py:meth:`DataArray.interp` now support interpolations over chunked dimensions (:pull:`4155`). By `Alexandre Poux `_. +- Major performance improvement for :py:meth:`Dataset.from_dataframe` when the + dataframe has a MultiIndex (:pull:`4184`). + By `Stephan Hoyer `_. + - :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep + coordinate attributes (:pull:`4103`). By `Oriol Abril `_. +- Axes kwargs such as ``facecolor`` can now be passed to :py:meth:`DataArray.plot` in ``subplot_kws``. + This works for both single axes plots and FacetGrid plots. + By `Raphael Dussin `_. +- Array items with long string reprs are now limited to a + reasonable width (:pull:`3900`) + By `Maximilian Roos `_ +- Large arrays whose numpy reprs would have greater than 40 lines are now + limited to a reasonable length. + (:pull:`3905`) + By `Maximilian Roos `_ + +Bug fixes +~~~~~~~~~ +- Fix errors combining attrs in :py:func:`open_mfdataset` (:issue:`4009`, :pull:`4173`) + By `John Omotani `_ +- If groupby receives a ``DataArray`` with name=None, assign a default name (:issue:`158`) + By `Phil Butcher `_. +- Support dark mode in VS code (:issue:`4024`) + By `Keisuke Fujii `_. +- Fix bug when converting multiindexed pandas objects to sparse xarray objects. (:issue:`4019`) + By `Deepak Cherian `_. +- ``ValueError`` is raised when ``fill_value`` is not a scalar in :py:meth:`full_like`. (:issue:`3977`) + By `Huite Bootsma `_. +- Fix wrong order in converting a ``pd.Series`` with a MultiIndex to ``DataArray``. + (:issue:`3951`, :issue:`4186`) + By `Keisuke Fujii `_ and `Stephan Hoyer `_. +- Fix renaming of coords when one or more stacked coords is not in + sorted order during stack+groupby+apply operations. (:issue:`3287`, + :pull:`3906`) By `Spencer Hill `_ +- Fix a regression where deleting a coordinate from a copied :py:class:`DataArray` + can affect the original :py:class:`DataArray`. (:issue:`3899`, :pull:`3871`) + By `Todd Jennings `_ +- Fix :py:class:`~xarray.plot.FacetGrid` plots with a single contour. (:issue:`3569`, :pull:`3915`). + By `Deepak Cherian `_ +- Use divergent colormap if ``levels`` spans 0. (:issue:`3524`) + By `Deepak Cherian `_ +- Fix :py:class:`~xarray.plot.FacetGrid` when ``vmin == vmax``. (:issue:`3734`) + By `Deepak Cherian `_ +- Fix plotting when ``levels`` is a scalar and ``norm`` is provided. (:issue:`3735`) + By `Deepak Cherian `_ +- Fix bug where plotting line plots with 2D coordinates depended on dimension + order. (:issue:`3933`) + By `Tom Nicholas `_. +- Fix ``RasterioDeprecationWarning`` when using a ``vrt`` in ``open_rasterio``. (:issue:`3964`) + By `Taher Chegini `_. +- Fix ``AttributeError`` on displaying a :py:class:`Variable` + in a notebook context. (:issue:`3972`, :pull:`3973`) + By `Ian Castleden `_. +- Fix bug causing :py:meth:`DataArray.interpolate_na` to always drop attributes, + and added `keep_attrs` argument. (:issue:`3968`) + By `Tom Nicholas `_. +- Fix bug in time parsing failing to fall back to cftime. This was causing time + variables with a time unit of `'msecs'` to fail to parse. (:pull:`3998`) + By `Ryan May `_. +- Fix weighted mean when passing boolean weights (:issue:`4074`). + By `Mathias Hauser `_. +- Fix html repr in untrusted notebooks: fallback to plain text repr. (:pull:`4053`) + By `Benoit Bovy `_. +- Fix :py:meth:`DataArray.to_unstacked_dataset` for single-dimension variables. (:issue:`4049`) + By `Deepak Cherian `_ +- Fix :py:func:`open_rasterio` for ``WarpedVRT`` with specified ``src_crs``. (:pull:`4104`) + By `Dave Cole `_. + +Documentation +~~~~~~~~~~~~~ +- update the docstring of :py:meth:`DataArray.assign_coords` : clarify how to + add a new coordinate to an existing dimension and illustrative example + (:issue:`3952`, :pull:`3958`) By + `Etienne Combrisson `_. +- update the docstring of :py:meth:`Dataset.diff` and + :py:meth:`DataArray.diff` so it does document the ``dim`` + parameter as required. (:issue:`1040`, :pull:`3909`) + By `Justus Magin `_. +- Updated :doc:`Calculating Seasonal Averages from Timeseries of Monthly Means + ` example notebook to take advantage of the new + ``days_in_month`` accessor for :py:class:`xarray.CFTimeIndex` + (:pull:`3935`). By `Spencer Clark `_. +- Updated the list of current core developers. (:issue:`3892`) + By `Tom Nicholas `_. +- Add example for multi-dimensional extrapolation and note different behavior + of ``kwargs`` in :py:meth:`Dataset.interp` and :py:meth:`DataArray.interp` + for 1-d and n-d interpolation (:pull:`3956`). + By `Matthias Riße `_. +- Apply ``black`` to all the code in the documentation (:pull:`4012`) + By `Justus Magin `_. +- Narrative documentation now describes :py:meth:`map_blocks`: :ref:`dask.automatic-parallelization`. + By `Deepak Cherian `_. +- Document ``.plot``, ``.dt``, ``.str`` accessors the way they are called. (:issue:`3625`, :pull:`3988`) + By `Justus Magin `_. +- Add documentation for the parameters and return values of :py:meth:`DataArray.sel`. + By `Justus Magin `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Raise more informative error messages for chunk size conflicts when writing to zarr files. + By `Deepak Cherian `_. +- Run the ``isort`` pre-commit hook only on python source files + and update the ``flake8`` version. (:issue:`3750`, :pull:`3711`) + By `Justus Magin `_. +- Add `blackdoc `_ to the list of + checkers for development. (:pull:`4177`) + By `Justus Magin `_. +- Add a CI job that runs the tests with every optional dependency + except ``dask``. (:issue:`3794`, :pull:`3919`) + By `Justus Magin `_. +- Use ``async`` / ``await`` for the asynchronous distributed + tests. (:issue:`3987`, :pull:`3989`) + By `Justus Magin `_. +- Various internal code clean-ups (:pull:`4026`, :pull:`4038`). + By `Prajjwal Nijhara `_. + +.. _whats-new.0.15.1: + +v0.15.1 (23 Mar 2020) +--------------------- + +This release brings many new features such as :py:meth:`Dataset.weighted` methods for weighted array +reductions, a new jupyter repr by default, and the start of units integration with pint. There's also +the usual batch of usability improvements, documentation additions, and bug fixes. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Raise an error when assigning to the ``.values`` or ``.data`` attribute of + dimension coordinates i.e. ``IndexVariable`` objects. This has been broken since + v0.12.0. Please use :py:meth:`DataArray.assign_coords` or :py:meth:`Dataset.assign_coords` + instead. (:issue:`3470`, :pull:`3862`) + By `Deepak Cherian `_ + +New Features +~~~~~~~~~~~~ + +- Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` + and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`). + By `Mathias Hauser `_. +- The new jupyter notebook repr (``Dataset._repr_html_`` and + ``DataArray._repr_html_``) (introduced in 0.14.1) is now on by default. To + disable, use ``xarray.set_options(display_style="text")``. + By `Julia Signell `_. +- Added support for :py:class:`pandas.DatetimeIndex`-style rounding of + ``cftime.datetime`` objects directly via a :py:class:`CFTimeIndex` or via the + :py:class:`~core.accessor_dt.DatetimeAccessor`. + By `Spencer Clark `_ +- Support new h5netcdf backend keyword `phony_dims` (available from h5netcdf + v0.8.0 for :py:class:`~xarray.backends.H5NetCDFStore`. + By `Kai Mühlbauer `_. +- Add partial support for unit aware arrays with pint. (:pull:`3706`, :pull:`3611`) + By `Justus Magin `_. +- :py:meth:`Dataset.groupby` and :py:meth:`DataArray.groupby` now raise a + `TypeError` on multiple string arguments. Receiving multiple string arguments + often means a user is attempting to pass multiple dimensions as separate + arguments and should instead pass a single list of dimensions. + (:pull:`3802`) + By `Maximilian Roos `_ +- :py:func:`map_blocks` can now apply functions that add new unindexed dimensions. + By `Deepak Cherian `_ +- An ellipsis (``...``) is now supported in the ``dims`` argument of + :py:meth:`Dataset.stack` and :py:meth:`DataArray.stack`, meaning all + unlisted dimensions, similar to its meaning in :py:meth:`DataArray.transpose`. + (:pull:`3826`) + By `Maximilian Roos `_ +- :py:meth:`Dataset.where` and :py:meth:`DataArray.where` accept a lambda as a + first argument, which is then called on the input; replicating pandas' behavior. + By `Maximilian Roos `_. +- ``skipna`` is available in :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile`, + :py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile` + (:issue:`3843`, :pull:`3844`) + By `Aaron Spring `_. +- Add a diff summary for `testing.assert_allclose`. (:issue:`3617`, :pull:`3847`) + By `Justus Magin `_. + +Bug fixes +~~~~~~~~~ + +- Fix :py:meth:`Dataset.interp` when indexing array shares coordinates with the + indexed variable (:issue:`3252`). + By `David Huard `_. +- Fix recombination of groups in :py:meth:`Dataset.groupby` and + :py:meth:`DataArray.groupby` when performing an operation that changes the + size of the groups along the grouped dimension. By `Eric Jansen + `_. +- Fix use of multi-index with categorical values (:issue:`3674`). + By `Matthieu Ancellin `_. +- Fix alignment with ``join="override"`` when some dimensions are unindexed. (:issue:`3681`). + By `Deepak Cherian `_. +- Fix :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` producing + index with name reflecting the previous dimension name instead of the new one + (:issue:`3748`, :pull:`3752`). By `Joseph K Aicher + `_. +- Use ``dask_array_type`` instead of ``dask_array.Array`` for type + checking. (:issue:`3779`, :pull:`3787`) + By `Justus Magin `_. +- :py:func:`concat` can now handle coordinate variables only present in one of + the objects to be concatenated when ``coords="different"``. + By `Deepak Cherian `_. +- xarray now respects the over, under and bad colors if set on a provided colormap. + (:issue:`3590`, :pull:`3601`) + By `johnomotani `_. +- ``coarsen`` and ``rolling`` now respect ``xr.set_options(keep_attrs=True)`` + to preserve attributes. :py:meth:`Dataset.coarsen` accepts a keyword + argument ``keep_attrs`` to change this setting. (:issue:`3376`, + :pull:`3801`) By `Andrew Thomas `_. +- Delete associated indexes when deleting coordinate variables. (:issue:`3746`). + By `Deepak Cherian `_. +- Fix :py:meth:`Dataset.to_zarr` when using ``append_dim`` and ``group`` + simultaneously. (:issue:`3170`). By `Matthias Meyer `_. +- Fix html repr on :py:class:`Dataset` with non-string keys (:pull:`3807`). + By `Maximilian Roos `_. + +Documentation +~~~~~~~~~~~~~ + +- Fix documentation of :py:class:`DataArray` removing the deprecated mention + that when omitted, `dims` are inferred from a `coords`-dict. (:pull:`3821`) + By `Sander van Rijn `_. +- Improve the :py:func:`where` docstring. + By `Maximilian Roos `_ +- Update the installation instructions: only explicitly list recommended dependencies + (:issue:`3756`). + By `Mathias Hauser `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Remove the internal ``import_seaborn`` function which handled the deprecation of + the ``seaborn.apionly`` entry point (:issue:`3747`). + By `Mathias Hauser `_. +- Don't test pint integration in combination with datetime objects. (:issue:`3778`, :pull:`3788`) + By `Justus Magin `_. +- Change test_open_mfdataset_list_attr to only run with dask installed + (:issue:`3777`, :pull:`3780`). + By `Bruno Pagani `_. +- Preserve the ability to index with ``method="nearest"`` with a + :py:class:`CFTimeIndex` with pandas versions greater than 1.0.1 + (:issue:`3751`). By `Spencer Clark `_. +- Greater flexibility and improved test coverage of subtracting various types + of objects from a :py:class:`CFTimeIndex`. By `Spencer Clark + `_. +- Update Azure CI MacOS image, given pending removal. + By `Maximilian Roos `_ +- Remove xfails for scipy 1.0.1 for tests that append to netCDF files (:pull:`3805`). + By `Mathias Hauser `_. +- Remove conversion to ``pandas.Panel``, given its removal in pandas + in favor of xarray's objects. + By `Maximilian Roos `_ + +.. _whats-new.0.15.0: + + +v0.15.0 (30 Jan 2020) +--------------------- + +This release brings many improvements to xarray's documentation: our examples are now binderized notebooks (`click here `_) +and we have new example notebooks from our SciPy 2019 sprint (many thanks to our contributors!). + +This release also features many API improvements such as a new +:py:class:`~core.accessor_dt.TimedeltaAccessor` and support for :py:class:`CFTimeIndex` in +:py:meth:`~DataArray.interpolate_na`); as well as many bug fixes. + +Breaking changes +~~~~~~~~~~~~~~~~ +- Bumped minimum tested versions for dependencies: + + - numpy 1.15 + - pandas 0.25 + - dask 2.2 + - distributed 2.2 + - scipy 1.3 + +- Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which + have been deprecated since 0.12. (:pull:`3650`). + Instead, specify the ``encoding`` kwarg when writing to disk or set + the :py:attr:`DataArray.encoding` attribute directly. + By `Maximilian Roos `_. +- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now + use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``; + :issue:`3694`) by `Mathias Hauser `_. + +New Features +~~~~~~~~~~~~ +- Implement :py:meth:`DataArray.pad` and :py:meth:`Dataset.pad`. (:issue:`2605`, :pull:`3596`). + By `Mark Boer `_. +- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`) + By `Keisuke Fujii `_. +- Support using an existing, opened h5netcdf ``File`` with + :py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an + :py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened + using other means (:issue:`3618`). + By `Kai Mühlbauer `_. +- Implement ``median`` and ``nanmedian`` for dask arrays. This works by rechunking + to a single chunk along all reduction axes. (:issue:`2999`). + By `Deepak Cherian `_. +- :py:func:`~xarray.concat` now preserves attributes from the first Variable. + (:issue:`2575`, :issue:`2060`, :issue:`1614`) + By `Deepak Cherian `_. +- :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` + now work with dask Variables. + By `Deepak Cherian `_. +- Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` + and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) + By `Deepak Cherian `_ +- Add ``meta`` kwarg to :py:func:`~xarray.apply_ufunc`; + this is passed on to :py:func:`dask.array.blockwise`. (:pull:`3660`) + By `Deepak Cherian `_. +- Add ``attrs_file`` option in :py:func:`~xarray.open_mfdataset` to choose the + source file for global attributes in a multi-file dataset (:issue:`2382`, + :pull:`3498`). By `Julien Seguinot `_. +- :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` + now allow swapping to dimension names that don't exist yet. (:pull:`3636`) + By `Justus Magin `_. +- Extend :py:class:`~core.accessor_dt.DatetimeAccessor` properties + and support ``.dt`` accessor for timedeltas + via :py:class:`~core.accessor_dt.TimedeltaAccessor` (:pull:`3612`) + By `Anderson Banihirwe `_. +- Improvements to interpolating along time axes (:issue:`3641`, :pull:`3631`). + By `David Huard `_. + + - Support :py:class:`CFTimeIndex` in :py:meth:`DataArray.interpolate_na` + - define 1970-01-01 as the default offset for the interpolation index for both + :py:class:`pandas.DatetimeIndex` and :py:class:`CFTimeIndex`, + - use microseconds in the conversion from timedelta objects to floats to avoid + overflow errors. + +Bug fixes +~~~~~~~~~ +- Applying a user-defined function that adds new dimensions using :py:func:`apply_ufunc` + and ``vectorize=True`` now works with ``dask > 2.0``. (:issue:`3574`, :pull:`3660`). + By `Deepak Cherian `_. +- Fix :py:meth:`~xarray.combine_by_coords` to allow for combining incomplete + hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger + `_. +- Fix :py:func:`~xarray.combine_by_coords` when combining cftime coordinates + which span long time intervals (:issue:`3535`). By `Spencer Clark + `_. +- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`) + By `Deepak Cherian `_. +- :py:meth:`plot.FacetGrid.set_titles` can now replace existing row titles of a + :py:class:`~xarray.plot.FacetGrid` plot. In addition :py:class:`~xarray.plot.FacetGrid` gained + two new attributes: :py:attr:`~xarray.plot.FacetGrid.col_labels` and + :py:attr:`~xarray.plot.FacetGrid.row_labels` contain :py:class:`matplotlib.text.Text` handles for both column and + row labels. These can be used to manually change the labels. + By `Deepak Cherian `_. +- Fix issue with Dask-backed datasets raising a ``KeyError`` on some computations involving :py:func:`map_blocks` (:pull:`3598`). + By `Tom Augspurger `_. +- Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error + when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- Fix regression in xarray 0.14.1 that prevented encoding times with certain + ``dtype``, ``_FillValue``, and ``missing_value`` encodings (:issue:`3624`). + By `Spencer Clark `_ +- Raise an error when trying to use :py:meth:`Dataset.rename_dims` to + rename to an existing name (:issue:`3438`, :pull:`3645`) + By `Justus Magin `_. +- :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with + MultiIndex level names. +- :py:meth:`Dataset.merge` no longer fails when passed a :py:class:`DataArray` instead of a :py:class:`Dataset`. + By `Tom Nicholas `_. +- Fix a regression in :py:meth:`Dataset.drop`: allow passing any + iterable when dropping variables (:issue:`3552`, :pull:`3693`) + By `Justus Magin `_. +- Fixed errors emitted by ``mypy --strict`` in modules that import xarray. + (:issue:`3695`) by `Guido Imperiale `_. +- Allow plotting of binned coordinates on the y axis in :py:meth:`plot.line` + and :py:meth:`plot.step` plots (:issue:`3571`, + :pull:`3685`) by `Julien Seguinot `_. +- setuptools is now marked as a dependency of xarray + (:pull:`3628`) by `Richard Höchenberger `_. + +Documentation +~~~~~~~~~~~~~ +- Switch doc examples to use `nbsphinx `_ and replace + ``sphinx_gallery`` scripts with Jupyter notebooks. (:pull:`3105`, :pull:`3106`, :pull:`3121`) + By `Ryan Abernathey `_. +- Added :doc:`example notebook ` demonstrating use of xarray with + Regional Ocean Modeling System (ROMS) ocean hydrodynamic model output. (:pull:`3116`) + By `Robert Hetland `_. +- Added :doc:`example notebook ` demonstrating the visualization of + ERA5 GRIB data. (:pull:`3199`) + By `Zach Bruick `_ and + `Stephan Siemen `_. +- Added examples for :py:meth:`DataArray.quantile`, :py:meth:`Dataset.quantile` and + ``GroupBy.quantile``. (:pull:`3576`) + By `Justus Magin `_. +- Add new :doc:`example notebook ` example notebook demonstrating + vectorization of a 1D function using :py:func:`apply_ufunc` , dask and numba. + By `Deepak Cherian `_. +- Added example for :py:func:`~xarray.map_blocks`. (:pull:`3667`) + By `Riley X. Brady `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Make sure dask names change when rechunking by different chunk sizes. Conversely, make sure they + stay the same when rechunking by the same chunk size. (:issue:`3350`) + By `Deepak Cherian `_. +- 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`, + :py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int, + slice, list of int, scalar ndarray, or 1-dimensional ndarray. + (:pull:`3533`) by `Guido Imperiale `_. +- Removed internal method ``Dataset._from_vars_and_coord_names``, + which was dominated by ``Dataset._construct_direct``. (:pull:`3565`) + By `Maximilian Roos `_. +- Replaced versioneer with setuptools-scm. Moved contents of setup.py to setup.cfg. + Removed pytest-runner from setup.py, as per deprecation notice on the pytest-runner + project. (:pull:`3714`) by `Guido Imperiale `_. +- Use of isort is now enforced by CI. + (:pull:`3721`) by `Guido Imperiale `_ + + +.. _whats-new.0.14.1: + +v0.14.1 (19 Nov 2019) +--------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Broken compatibility with ``cftime < 1.0.3`` . By `Deepak Cherian `_. + + .. warning:: + + cftime version 1.0.4 is broken + (`cftime/126 `_); + please use version 1.0.4.2 instead. + +- All leftover support for dates from non-standard calendars through ``netcdftime``, the + module included in versions of netCDF4 prior to 1.4 that eventually became the + `cftime `_ package, has been removed in favor of relying solely on + the standalone ``cftime`` package (:pull:`3450`). + By `Spencer Clark `_. + +New Features +~~~~~~~~~~~~ +- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, + :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, + :py:meth:`~xarray.Dataset.reindex` (:issue:`3518`). + By `Keisuke Fujii `_. +- Added the ``fill_value`` option to :py:meth:`DataArray.unstack` and + :py:meth:`Dataset.unstack` (:issue:`3518`, :pull:`3541`). + By `Keisuke Fujii `_. +- Added the ``max_gap`` kwarg to :py:meth:`~xarray.DataArray.interpolate_na` and + :py:meth:`~xarray.Dataset.interpolate_na`. This controls the maximum size of the data + gap that will be filled by interpolation. By `Deepak Cherian `_. +- Added :py:meth:`Dataset.drop_sel` & :py:meth:`DataArray.drop_sel` for dropping labels. + :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for + dropping variables (including coordinates). The existing :py:meth:`Dataset.drop` & + :py:meth:`DataArray.drop` methods remain as a backward compatible + option for dropping either labels or variables, but using the more specific methods is encouraged. + (:pull:`3475`) + By `Maximilian Roos `_ +- Added :py:meth:`Dataset.map` & ``GroupBy.map`` & ``Resample.map`` for + mapping / applying a function over each item in the collection, reflecting the widely used + and least surprising name for this operation. + The existing ``apply`` methods remain for backward compatibility, though using the ``map`` + methods is encouraged. + (:pull:`3459`) + By `Maximilian Roos `_ +- :py:meth:`Dataset.transpose` and :py:meth:`DataArray.transpose` now support an ellipsis (``...``) + to represent all 'other' dimensions. For example, to move one dimension to the front, + use ``.transpose('x', ...)``. (:pull:`3421`) + By `Maximilian Roos `_ +- Changed ``xr.ALL_DIMS`` to equal python's ``Ellipsis`` (``...``), and changed internal usages to use + ``...`` directly. As before, you can use this to instruct a ``groupby`` operation + to reduce over all dimensions. While we have no plans to remove ``xr.ALL_DIMS``, we suggest + using ``...``. (:pull:`3418`) + By `Maximilian Roos `_ +- :py:func:`xarray.dot`, and :py:meth:`DataArray.dot` now support the + ``dims=...`` option to sum over the union of dimensions of all input arrays + (:issue:`3423`) by `Mathias Hauser `_. +- Added new ``Dataset._repr_html_`` and ``DataArray._repr_html_`` to improve + representation of objects in Jupyter. By default this feature is turned off + for now. Enable it with ``xarray.set_options(display_style="html")``. + (:pull:`3425`) by `Benoit Bovy `_ and + `Julia Signell `_. +- Implement `dask deterministic hashing + `_ + for xarray objects. Note that xarray objects with a dask.array backend already used + deterministic hashing in previous releases; this change implements it when whole + xarray objects are embedded in a dask graph, e.g. when :py:meth:`DataArray.map_blocks` is + invoked. (:issue:`3378`, :pull:`3446`, :pull:`3515`) + By `Deepak Cherian `_ and + `Guido Imperiale `_. +- Add the documented-but-missing :py:meth:`~core.groupby.DatasetGroupBy.quantile`. +- xarray now respects the ``DataArray.encoding["coordinates"]`` attribute when writing to disk. + See :ref:`io.coordinates` for more. (:issue:`3351`, :pull:`3487`) + By `Deepak Cherian `_. +- Add the documented-but-missing :py:meth:`~core.groupby.DatasetGroupBy.quantile`. + (:issue:`3525`, :pull:`3527`). By `Justus Magin `_. + +Bug fixes +~~~~~~~~~ +- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when + calling :py:meth:`Dataset.rename`, :py:meth:`Dataset.rename_dims` and :py:meth:`Dataset.rename_vars`. + By `Mathias Hauser `_. (:issue:`3522`). +- Fix a bug in :py:meth:`DataArray.set_index` in case that an existing dimension becomes a level + variable of MultiIndex. (:pull:`3520`). By `Keisuke Fujii `_. +- Harmonize ``_FillValue``, ``missing_value`` during encoding and decoding steps. (:pull:`3502`) + By `Anderson Banihirwe `_. +- Fix regression introduced in v0.14.0 that would cause a crash if dask is installed + but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle `_ +- Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`). + By `Deepak Cherian `_. +- Make alignment and concatenation significantly more efficient by using dask names to compare dask + objects prior to comparing values after computation. This change makes it more convenient to carry + around large non-dimensional coordinate variables backed by dask arrays. Existing workarounds involving + ``reset_coords(drop=True)`` should now be unnecessary in most cases. + (:issue:`3068`, :issue:`3311`, :issue:`3454`, :pull:`3453`). + By `Deepak Cherian `_. +- Add support for cftime>=1.0.4. By `Anderson Banihirwe `_. +- Rolling reduction operations no longer compute dask arrays by default. (:issue:`3161`). + In addition, the ``allow_lazy`` kwarg to ``reduce`` is deprecated. + By `Deepak Cherian `_. +- Fix ``GroupBy.reduce`` when reducing over multiple dimensions. + (:issue:`3402`). By `Deepak Cherian `_ +- Allow appending datetime and bool data variables to zarr stores. + (:issue:`3480`). By `Akihiro Matsukawa `_. +- Add support for numpy >=1.18 (); bugfix mean() on datetime64 arrays on dask backend + (:issue:`3409`, :pull:`3537`). By `Guido Imperiale `_. +- Add support for pandas >=0.26 (:issue:`3440`). + By `Deepak Cherian `_. +- Add support for pseudonetcdf >=3.1 (:pull:`3485`). + By `Barron Henderson `_. + +Documentation +~~~~~~~~~~~~~ +- Fix leap year condition in `monthly means example `_. + By `Mickaël Lalande `_. +- Fix the documentation of :py:meth:`DataArray.resample` and + :py:meth:`Dataset.resample`, explicitly stating that a + datetime-like dimension is required. (:pull:`3400`) + By `Justus Magin `_. +- Update the :ref:`terminology` page to address multidimensional coordinates. (:pull:`3410`) + By `Jon Thielen `_. +- Fix the documentation of :py:meth:`Dataset.integrate` and + :py:meth:`DataArray.integrate` and add an example to + :py:meth:`Dataset.integrate`. (:pull:`3469`) + By `Justus Magin `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Added integration tests against `pint `_. + (:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`) + by `Justus Magin `_. + + .. note:: + + At the moment of writing, these tests *as well as the ability to use pint in general* + require `a highly experimental version of pint + `_ (install with + ``pip install git+https://github.com/andrewgsavage/pint.git@refs/pull/6/head)``. + Even with it, interaction with non-numpy array libraries, e.g. dask or sparse, is broken. + +- Use Python 3.6 idioms throughout the codebase. (:pull:`3419`) + By `Maximilian Roos `_ + +- Run basic CI tests on Python 3.8. (:pull:`3477`) + By `Maximilian Roos `_ + +- Enable type checking on default sentinel values (:pull:`3472`) + By `Maximilian Roos `_ + +- Add ``Variable._replace`` for simpler replacing of a subset of attributes (:pull:`3472`) + By `Maximilian Roos `_ + +.. _whats-new.0.14.0: + +v0.14.0 (14 Oct 2019) +--------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ +- This release introduces a rolling policy for minimum dependency versions: + :ref:`mindeps_policy`. + + Several minimum versions have been increased: + + ============ ================== ==== + Package Old New + ============ ================== ==== + Python 3.5.3 3.6 + numpy 1.12 1.14 + pandas 0.19.2 0.24 + dask 0.16 (tested: 2.4) 1.2 + bottleneck 1.1 (tested: 1.2) 1.2 + matplotlib 1.5 (tested: 3.1) 3.1 + ============ ================== ==== + + Obsolete patch versions (x.y.Z) are not tested anymore. + The oldest supported versions of all optional dependencies are now covered by + automated tests (before, only the very latest versions were tested). + + (:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`). + By `Guido Imperiale `_. + +- Dropped the ``drop=False`` optional parameter from :py:meth:`Variable.isel`. + It was unused and doesn't make sense for a Variable. (:pull:`3375`). + By `Guido Imperiale `_. + +- Remove internal usage of :py:class:`collections.OrderedDict`. After dropping support for + Python <=3.5, most uses of ``OrderedDict`` in xarray were no longer necessary. We + have removed the internal use of the ``OrderedDict`` in favor of Python's builtin + ``dict`` object which is now ordered itself. This change will be most obvious when + interacting with the ``attrs`` property on Dataset and DataArray objects. + (:issue:`3380`, :pull:`3389`). By `Joe Hamman `_. + +New functions/methods +~~~~~~~~~~~~~~~~~~~~~ + +- Added :py:func:`~xarray.map_blocks`, modeled after :py:func:`dask.array.map_blocks`. + Also added :py:meth:`Dataset.unify_chunks`, :py:meth:`DataArray.unify_chunks` and + :py:meth:`testing.assert_chunks_equal`. (:pull:`3276`). + By `Deepak Cherian `_ and + `Guido Imperiale `_. + +Enhancements +~~~~~~~~~~~~ + +- ``core.groupby.GroupBy`` enhancements. By `Deepak Cherian `_. + + - Added a repr (:pull:`3344`). Example:: + + >>> da.groupby("time.season") + DataArrayGroupBy, grouped over 'season' + 4 groups with labels 'DJF', 'JJA', 'MAM', 'SON' + + - Added a ``GroupBy.dims`` property that mirrors the dimensions + of each group (:issue:`3344`). + +- Speed up :py:meth:`Dataset.isel` up to 33% and :py:meth:`DataArray.isel` up to 25% for small + arrays (:issue:`2799`, :pull:`3375`). By + `Guido Imperiale `_. + +Bug fixes +~~~~~~~~~ +- Reintroduce support for :mod:`weakref` (broken in v0.13.0). Support has been + reinstated for :py:class:`~xarray.DataArray` and :py:class:`~xarray.Dataset` objects only. + Internal xarray objects remain unaddressable by weakref in order to save memory + (:issue:`3317`). By `Guido Imperiale `_. +- Line plots with the ``x`` or ``y`` argument set to a 1D non-dimensional coord + now plot the correct data for 2D DataArrays + (:issue:`3334`). By `Tom Nicholas `_. +- Make :py:func:`~xarray.concat` more robust when merging variables present in some datasets but + not others (:issue:`508`). By `Deepak Cherian `_. +- The default behaviour of reducing across all dimensions for + :py:class:`~xarray.core.groupby.DataArrayGroupBy` objects has now been properly removed + as was done for :py:class:`~xarray.core.groupby.DatasetGroupBy` in 0.13.0 (:issue:`3337`). + Use ``xarray.ALL_DIMS`` if you need to replicate previous behaviour. + Also raise nicer error message when no groups are created (:issue:`1764`). + By `Deepak Cherian `_. +- Fix error in concatenating unlabeled dimensions (:pull:`3362`). + By `Deepak Cherian `_. +- Warn if the ``dim`` kwarg is passed to rolling operations. This is redundant since a dimension is + specified when the :py:class:`~core.rolling.DatasetRolling` or :py:class:`~core.rolling.DataArrayRolling` object is created. + (:pull:`3362`). By `Deepak Cherian `_. + +Documentation +~~~~~~~~~~~~~ + +- Created a glossary of important xarray terms (:issue:`2410`, :pull:`3352`). + By `Gregory Gundersen `_. +- Created a "How do I..." section (:ref:`howdoi`) for solutions to common questions. (:pull:`3357`). + By `Deepak Cherian `_. +- Add examples for :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` + (:pull:`3331`, :pull:`3331`). By `Justus Magin `_. +- Add examples for :py:meth:`align`, :py:meth:`merge`, :py:meth:`combine_by_coords`, + :py:meth:`full_like`, :py:meth:`zeros_like`, :py:meth:`ones_like`, :py:meth:`Dataset.pipe`, + :py:meth:`Dataset.assign`, :py:meth:`Dataset.reindex`, :py:meth:`Dataset.fillna` (:pull:`3328`). + By `Anderson Banihirwe `_. +- Fixed documentation to clean up an unwanted file created in ``ipython`` example + (:pull:`3353`). By `Gregory Gundersen `_. + +.. _whats-new.0.13.0: + +v0.13.0 (17 Sep 2019) +--------------------- + +This release includes many exciting changes: wrapping of +`NEP18 `_ compliant +numpy-like arrays; new :py:meth:`~Dataset.plot.scatter` plotting method that can scatter +two ``DataArrays`` in a ``Dataset`` against each other; support for converting pandas +DataFrames to xarray objects that wrap ``pydata/sparse``; and more! + +Breaking changes +~~~~~~~~~~~~~~~~ + +- This release increases the minimum required Python version from 3.5.0 to 3.5.3 + (:issue:`3089`). By `Guido Imperiale `_. +- The ``isel_points`` and ``sel_points`` methods are removed, having been deprecated + since v0.10.0. These are redundant with the ``isel`` / ``sel`` methods. + See :ref:`vectorized_indexing` for the details + By `Maximilian Roos `_ +- The ``inplace`` kwarg for public methods now raises an error, having been deprecated + since v0.11.0. + By `Maximilian Roos `_ +- :py:func:`~xarray.concat` now requires the ``dim`` argument. Its ``indexers``, ``mode`` + and ``concat_over`` kwargs have now been removed. + By `Deepak Cherian `_ +- Passing a list of colors in ``cmap`` will now raise an error, having been deprecated since + v0.6.1. +- Most xarray objects now define ``__slots__``. This reduces overall RAM usage by ~22% + (not counting the underlying numpy buffers); on CPython 3.7/x64, a trivial DataArray + has gone down from 1.9kB to 1.5kB. + + Caveats: + + - Pickle streams produced by older versions of xarray can't be loaded using this + release, and vice versa. + - Any user code that was accessing the ``__dict__`` attribute of + xarray objects will break. The best practice to attach custom metadata to xarray + objects is to use the ``attrs`` dictionary. + - Any user code that defines custom subclasses of xarray classes must now explicitly + define ``__slots__`` itself. Subclasses that don't add any attributes must state so + by defining ``__slots__ = ()`` right after the class header. + Omitting ``__slots__`` will now cause a ``FutureWarning`` to be logged, and will raise an + error in a later release. + + (:issue:`3250`) by `Guido Imperiale `_. +- The default dimension for :py:meth:`Dataset.groupby`, :py:meth:`Dataset.resample`, + :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` reductions is now the + grouping or resampling dimension. +- :py:meth:`DataArray.to_dataset` requires ``name`` to be passed as a kwarg (previously ambiguous + positional arguments were deprecated) +- Reindexing with variables of a different dimension now raise an error (previously deprecated) +- ``xarray.broadcast_array`` is removed (previously deprecated in favor of + :py:func:`~xarray.broadcast`) +- ``Variable.expand_dims`` is removed (previously deprecated in favor of + :py:meth:`Variable.set_dims`) + +New functions/methods +~~~~~~~~~~~~~~~~~~~~~ + +- xarray can now wrap around any + `NEP18 `_ compliant + numpy-like library (important: read notes about ``NUMPY_EXPERIMENTAL_ARRAY_FUNCTION`` in + the above link). Added explicit test coverage for + `sparse `_. (:issue:`3117`, :issue:`3202`). + This requires `sparse>=0.8.0`. By `Nezar Abdennur `_ + and `Guido Imperiale `_. + +- :py:meth:`~Dataset.from_dataframe` and :py:meth:`~DataArray.from_series` now + support ``sparse=True`` for converting pandas objects into xarray objects + wrapping sparse arrays. This is particularly useful with sparsely populated + hierarchical indexes. (:issue:`3206`) + By `Stephan Hoyer `_. + +- The xarray package is now discoverable by mypy (although typing hints coverage is not + complete yet). mypy type checking is now enforced by CI. Libraries that depend on + xarray and use mypy can now remove from their setup.cfg the lines:: + + [mypy-xarray] + ignore_missing_imports = True + + (:issue:`2877`, :issue:`3088`, :issue:`3090`, :issue:`3112`, :issue:`3117`, + :issue:`3207`) + By `Guido Imperiale `_ + and `Maximilian Roos `_. + +- Added :py:meth:`DataArray.broadcast_like` and :py:meth:`Dataset.broadcast_like`. + By `Deepak Cherian `_ and `David Mertz + `_. + +- Dataset plotting API for visualizing dependencies between two DataArrays! + Currently only :py:meth:`Dataset.plot.scatter` is implemented. + By `Yohai Bar Sinai `_ and `Deepak Cherian `_ + +- Added :py:meth:`DataArray.head`, :py:meth:`DataArray.tail` and :py:meth:`DataArray.thin`; + as well as :py:meth:`Dataset.head`, :py:meth:`Dataset.tail` and :py:meth:`Dataset.thin` methods. + (:issue:`319`) By `Gerardo Rivera `_. + +Enhancements +~~~~~~~~~~~~ + +- Multiple enhancements to :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset`. + By `Deepak Cherian `_ + + - Added ``compat='override'``. When merging, this option picks the variable from the first dataset + and skips all comparisons. + + - Added ``join='override'``. When aligning, this only checks that index sizes are equal among objects + and skips checking indexes for equality. + + - :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset` now support the ``join`` kwarg. + It is passed down to :py:func:`~xarray.align`. + + - :py:func:`~xarray.concat` now calls :py:func:`~xarray.merge` on variables that are not concatenated + (i.e. variables without ``concat_dim`` when ``data_vars`` or ``coords`` are ``"minimal"``). + :py:func:`~xarray.concat` passes its new ``compat`` kwarg down to :py:func:`~xarray.merge`. + (:issue:`2064`) + + Users can avoid a common bottleneck when using :py:func:`~xarray.open_mfdataset` on a large number of + files with variables that are known to be aligned and some of which need not be concatenated. + Slow equality comparisons can now be avoided, for e.g.:: + + data = xr.open_mfdataset(files, concat_dim='time', data_vars='minimal', + coords='minimal', compat='override', join='override') + +- In :py:meth:`~xarray.Dataset.to_zarr`, passing ``mode`` is not mandatory if + ``append_dim`` is set, as it will automatically be set to ``'a'`` internally. + By `David Brochart `_. + +- Added the ability to initialize an empty or full DataArray + with a single value. (:issue:`277`) + By `Gerardo Rivera `_. + +- :py:func:`~xarray.Dataset.to_netcdf()` now supports the ``invalid_netcdf`` kwarg when used + with ``engine="h5netcdf"``. It is passed to ``h5netcdf.File``. + By `Ulrich Herter `_. + +- ``xarray.Dataset.drop`` now supports keyword arguments; dropping index + labels by using both ``dim`` and ``labels`` or using a + :py:class:`~core.coordinates.DataArrayCoordinates` object are deprecated (:issue:`2910`). + By `Gregory Gundersen `_. + +- Added examples of :py:meth:`Dataset.set_index` and + :py:meth:`DataArray.set_index`, as well are more specific error messages + when the user passes invalid arguments (:issue:`3176`). + By `Gregory Gundersen `_. + +- :py:meth:`Dataset.filter_by_attrs` now filters the coordinates as well as the variables. + By `Spencer Jones `_. + +Bug fixes +~~~~~~~~~ + +- Improve "missing dimensions" error message for :py:func:`~xarray.apply_ufunc` + (:issue:`2078`). + By `Rick Russotto `_. +- :py:meth:`~xarray.DataArray.assign_coords` now supports dictionary arguments + (:issue:`3231`). + By `Gregory Gundersen `_. +- Fix regression introduced in v0.12.2 where ``copy(deep=True)`` would convert + unicode indices to dtype=object (:issue:`3094`). + By `Guido Imperiale `_. +- Improved error handling and documentation for `.expand_dims()` + read-only view. +- Fix tests for big-endian systems (:issue:`3125`). + By `Graham Inggs `_. +- XFAIL several tests which are expected to fail on ARM systems + due to a ``datetime`` issue in NumPy (:issue:`2334`). + By `Graham Inggs `_. +- Fix KeyError that arises when using .sel method with float values + different from coords float type (:issue:`3137`). + By `Hasan Ahmad `_. +- Fixed bug in ``combine_by_coords()`` causing a `ValueError` if the input had + an unused dimension with coordinates which were not monotonic (:issue:`3150`). + By `Tom Nicholas `_. +- Fixed crash when applying ``distributed.Client.compute()`` to a DataArray + (:issue:`3171`). By `Guido Imperiale `_. +- Better error message when using groupby on an empty DataArray (:issue:`3037`). + By `Hasan Ahmad `_. +- Fix error that arises when using open_mfdataset on a series of netcdf files + having differing values for a variable attribute of type list. (:issue:`3034`) + By `Hasan Ahmad `_. +- Prevent :py:meth:`~xarray.DataArray.argmax` and :py:meth:`~xarray.DataArray.argmin` from calling + dask compute (:issue:`3237`). By `Ulrich Herter `_. +- Plots in 2 dimensions (pcolormesh, contour) now allow to specify levels as numpy + array (:issue:`3284`). By `Mathias Hauser `_. +- Fixed bug in :meth:`DataArray.quantile` failing to keep attributes when + `keep_attrs` was True (:issue:`3304`). By `David Huard `_. + +Documentation +~~~~~~~~~~~~~ + +- Created a `PR checklist `_ + as a quick reference for tasks before creating a new PR + or pushing new commits. + By `Gregory Gundersen `_. + +- Fixed documentation to clean up unwanted files created in ``ipython`` examples + (:issue:`3227`). + By `Gregory Gundersen `_. + +.. _whats-new.0.12.3: + +v0.12.3 (10 July 2019) +---------------------- + +New functions/methods +~~~~~~~~~~~~~~~~~~~~~ + +- New methods :py:meth:`Dataset.to_stacked_array` and + :py:meth:`DataArray.to_unstacked_dataset` for reshaping Datasets of variables + with different dimensions + (:issue:`1317`). + This is useful for feeding data from xarray into machine learning models, + as described in :ref:`reshape.stacking_different`. + By `Noah Brenowitz `_. + +Enhancements +~~~~~~~~~~~~ + +- Support for renaming ``Dataset`` variables and dimensions independently + with :py:meth:`~Dataset.rename_vars` and :py:meth:`~Dataset.rename_dims` + (:issue:`3026`). + By `Julia Kent `_. + +- Add ``scales``, ``offsets``, ``units`` and ``descriptions`` + attributes to :py:class:`~xarray.DataArray` returned by + :py:func:`~xarray.open_rasterio`. (:issue:`3013`) + By `Erle Carrara `_. + +Bug fixes +~~~~~~~~~ + +- Resolved deprecation warnings from newer versions of matplotlib and dask. +- Compatibility fixes for the upcoming pandas 0.25 and NumPy 1.17 releases. + By `Stephan Hoyer `_. +- Fix summaries for multiindex coordinates (:issue:`3079`). + By `Jonas Hörsch `_. +- Fix HDF5 error that could arise when reading multiple groups from a file at + once (:issue:`2954`). + By `Stephan Hoyer `_. + +.. _whats-new.0.12.2: + +v0.12.2 (29 June 2019) +---------------------- + +New functions/methods +~~~~~~~~~~~~~~~~~~~~~ + +- Two new functions, :py:func:`~xarray.combine_nested` and + :py:func:`~xarray.combine_by_coords`, allow for combining datasets along any + number of dimensions, instead of the one-dimensional list of datasets + supported by :py:func:`~xarray.concat`. + + The new ``combine_nested`` will accept the datasets as a nested + list-of-lists, and combine by applying a series of concat and merge + operations. The new ``combine_by_coords`` instead uses the dimension + coordinates of datasets to order them. + + :py:func:`~xarray.open_mfdataset` can use either ``combine_nested`` or + ``combine_by_coords`` to combine datasets along multiple dimensions, by + specifying the argument ``combine='nested'`` or ``combine='by_coords'``. + + The older function ``auto_combine`` has been deprecated, + because its functionality has been subsumed by the new functions. + To avoid FutureWarnings switch to using ``combine_nested`` or + ``combine_by_coords``, (or set the ``combine`` argument in + ``open_mfdataset``). (:issue:`2159`) + By `Tom Nicholas `_. + +- :py:meth:`~xarray.DataArray.rolling_exp` and + :py:meth:`~xarray.Dataset.rolling_exp` added, similar to pandas' + ``pd.DataFrame.ewm`` method. Calling ``.mean`` on the resulting object + will return an exponentially weighted moving average. + By `Maximilian Roos `_. + +- New :py:func:`DataArray.str ` for string + related manipulations, based on ``pandas.Series.str``. + By `0x0L `_. + +- Added ``strftime`` method to ``.dt`` accessor, making it simpler to hand a + datetime ``DataArray`` to other code expecting formatted dates and times. + (:issue:`2090`). :py:meth:`~xarray.CFTimeIndex.strftime` is also now + available on :py:class:`CFTimeIndex`. + By `Alan Brammer `_ and + `Ryan May `_. + +- ``GroupBy.quantile`` is now a method of ``GroupBy`` + objects (:issue:`3018`). + By `David Huard `_. + +- Argument and return types are added to most methods on ``DataArray`` and + ``Dataset``, allowing static type checking both within xarray and external + libraries. Type checking with `mypy `_ is enabled in + CI (though not required yet). + By `Guido Imperiale `_ + and `Maximilian Roos `_. + +Enhancements to existing functionality +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Add ``keepdims`` argument for reduce operations (:issue:`2170`) + By `Scott Wales `_. +- Enable ``@`` operator for DataArray. This is equivalent to :py:meth:`DataArray.dot` + By `Maximilian Roos `_. +- Add ``fill_value`` argument for reindex, align, and merge operations + to enable custom fill values. (:issue:`2876`) + By `Zach Griffith `_. +- :py:meth:`DataArray.transpose` now accepts a keyword argument + ``transpose_coords`` which enables transposition of coordinates in the + same way as :py:meth:`Dataset.transpose`. :py:meth:`DataArray.groupby` + :py:meth:`DataArray.groupby_bins`, and :py:meth:`DataArray.resample` now + accept a keyword argument ``restore_coord_dims`` which keeps the order + of the dimensions of multi-dimensional coordinates intact (:issue:`1856`). + By `Peter Hausamann `_. +- Clean up Python 2 compatibility in code (:issue:`2950`) + By `Guido Imperiale `_. +- Better warning message when supplying invalid objects to ``xr.merge`` + (:issue:`2948`). By `Mathias Hauser `_. +- Add ``errors`` keyword argument to ``Dataset.drop`` and :py:meth:`Dataset.drop_dims` + that allows ignoring errors if a passed label or dimension is not in the dataset + (:issue:`2994`). + By `Andrew Ross `_. + +IO related enhancements +~~~~~~~~~~~~~~~~~~~~~~~ + +- Implement :py:func:`~xarray.load_dataset` and + :py:func:`~xarray.load_dataarray` as alternatives to + :py:func:`~xarray.open_dataset` and :py:func:`~xarray.open_dataarray` to + open, load into memory, and close files, returning the Dataset or DataArray. + These functions are helpful for avoiding file-lock errors when trying to + write to files opened using ``open_dataset()`` or ``open_dataarray()``. + (:issue:`2887`) + By `Dan Nowacki `_. +- It is now possible to extend existing :ref:`io.zarr` datasets, by using + ``mode='a'`` and the new ``append_dim`` argument in + :py:meth:`~xarray.Dataset.to_zarr`. + By `Jendrik Jördening `_, + `David Brochart `_, + `Ryan Abernathey `_ and + `Shikhar Goenka `_. +- ``xr.open_zarr`` now accepts manually specified chunks with the ``chunks=`` + parameter. ``auto_chunk=True`` is equivalent to ``chunks='auto'`` for + backwards compatibility. The ``overwrite_encoded_chunks`` parameter is + added to remove the original zarr chunk encoding. + By `Lily Wang `_. +- netCDF chunksizes are now only dropped when original_shape is different, + not when it isn't found. (:issue:`2207`) + By `Karel van de Plassche `_. +- Character arrays' character dimension name decoding and encoding handled by + ``var.encoding['char_dim_name']`` (:issue:`2895`) + By `James McCreight `_. +- open_rasterio() now supports rasterio.vrt.WarpedVRT with custom transform, + width and height (:issue:`2864`). + By `Julien Michel `_. + +Bug fixes +~~~~~~~~~ + +- Rolling operations on xarray objects containing dask arrays could silently + compute the incorrect result or use large amounts of memory (:issue:`2940`). + By `Stephan Hoyer `_. +- Don't set encoding attributes on bounds variables when writing to netCDF. + (:issue:`2921`) + By `Deepak Cherian `_. +- NetCDF4 output: variables with unlimited dimensions must be chunked (not + contiguous) on output. (:issue:`1849`) + By `James McCreight `_. +- indexing with an empty list creates an object with zero-length axis (:issue:`2882`) + By `Mayeul d'Avezac `_. +- Return correct count for scalar datetime64 arrays (:issue:`2770`) + By `Dan Nowacki `_. +- Fixed max, min exception when applied to a multiIndex (:issue:`2923`) + By `Ian Castleden `_ +- A deep copy deep-copies the coords (:issue:`1463`) + By `Martin Pletcher `_. +- Increased support for `missing_value` (:issue:`2871`) + By `Deepak Cherian `_. +- Removed usages of `pytest.config`, which is deprecated (:issue:`2988`) + By `Maximilian Roos `_. +- Fixed performance issues with cftime installed (:issue:`3000`) + By `0x0L `_. +- Replace incorrect usages of `message` in pytest assertions + with `match` (:issue:`3011`) + By `Maximilian Roos `_. +- Add explicit pytest markers, now required by pytest + (:issue:`3032`). + By `Maximilian Roos `_. +- Test suite fixes for newer versions of pytest (:issue:`3011`, :issue:`3032`). + By `Maximilian Roos `_ + and `Stephan Hoyer `_. + +.. _whats-new.0.12.1: + +v0.12.1 (4 April 2019) +---------------------- + +Enhancements +~~~~~~~~~~~~ + +- Allow ``expand_dims`` method to support inserting/broadcasting dimensions + with size > 1. (:issue:`2710`) + By `Martin Pletcher `_. + +Bug fixes +~~~~~~~~~ + +- Dataset.copy(deep=True) now creates a deep copy of the attrs (:issue:`2835`). + By `Andras Gefferth `_. +- Fix incorrect ``indexes`` resulting from various ``Dataset`` operations + (e.g., ``swap_dims``, ``isel``, ``reindex``, ``[]``) (:issue:`2842`, + :issue:`2856`). + By `Stephan Hoyer `_. + +.. _whats-new.0.12.0: + +v0.12.0 (15 March 2019) +----------------------- + +Highlights include: + +- Removed support for Python 2. This is the first version of xarray that is + Python 3 only! +- New :py:meth:`~xarray.DataArray.coarsen` and + :py:meth:`~xarray.DataArray.integrate` methods. See :ref:`compute.coarsen` + and :ref:`compute.using_coordinates` for details. +- Many improvements to cftime support. See below for details. + +Deprecations +~~~~~~~~~~~~ + +- The ``compat`` argument to ``Dataset`` and the ``encoding`` argument to + ``DataArray`` are deprecated and will be removed in a future release. + (:issue:`1188`) + By `Maximilian Roos `_. + +cftime related enhancements +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Resampling of standard and non-standard calendars indexed by + :py:class:`~xarray.CFTimeIndex` is now possible. (:issue:`2191`). + By `Jwen Fai Low `_ and + `Spencer Clark `_. + +- Taking the mean of arrays of :py:class:`cftime.datetime` objects, and + by extension, use of :py:meth:`~xarray.DataArray.coarsen` with + :py:class:`cftime.datetime` coordinates is now possible. By `Spencer Clark + `_. + +- Internal plotting now supports ``cftime.datetime`` objects as time series. + (:issue:`2164`) + By `Julius Busecke `_ and + `Spencer Clark `_. + +- :py:meth:`~xarray.cftime_range` now supports QuarterBegin and QuarterEnd offsets (:issue:`2663`). + By `Jwen Fai Low `_ + +- :py:meth:`~xarray.open_dataset` now accepts a ``use_cftime`` argument, which + can be used to require that ``cftime.datetime`` objects are always used, or + never used when decoding dates encoded with a standard calendar. This can be + used to ensure consistent date types are returned when using + :py:meth:`~xarray.open_mfdataset` (:issue:`1263`) and/or to silence + serialization warnings raised if dates from a standard calendar are found to + be outside the :py:class:`pandas.Timestamp`-valid range (:issue:`2754`). By + `Spencer Clark `_. + +- :py:meth:`pandas.Series.dropna` is now supported for a + :py:class:`pandas.Series` indexed by a :py:class:`~xarray.CFTimeIndex` + (:issue:`2688`). By `Spencer Clark `_. + +Other enhancements +~~~~~~~~~~~~~~~~~~ + +- Added ability to open netcdf4/hdf5 file-like objects with ``open_dataset``. + Requires (h5netcdf>0.7 and h5py>2.9.0). (:issue:`2781`) + By `Scott Henderson `_ +- Add ``data=False`` option to ``to_dict()`` methods. (:issue:`2656`) + By `Ryan Abernathey `_ +- :py:meth:`DataArray.coarsen` and + :py:meth:`Dataset.coarsen` are newly added. + See :ref:`compute.coarsen` for details. + (:issue:`2525`) + By `Keisuke Fujii `_. +- Upsampling an array via interpolation with resample is now dask-compatible, + as long as the array is not chunked along the resampling dimension. + By `Spencer Clark `_. +- :py:func:`xarray.testing.assert_equal` and + :py:func:`xarray.testing.assert_identical` now provide a more detailed + report showing what exactly differs between the two objects (dimensions / + coordinates / variables / attributes) (:issue:`1507`). + By `Benoit Bovy `_. +- Add ``tolerance`` option to ``resample()`` methods ``bfill``, ``pad``, + ``nearest``. (:issue:`2695`) + By `Hauke Schulz `_. +- :py:meth:`DataArray.integrate` and + :py:meth:`Dataset.integrate` are newly added. + See :ref:`compute.using_coordinates` for the detail. + (:issue:`1332`) + By `Keisuke Fujii `_. +- Added :py:meth:`~xarray.Dataset.drop_dims` (:issue:`1949`). + By `Kevin Squire `_. + +Bug fixes +~~~~~~~~~ + +- Silenced warnings that appear when using pandas 0.24. + By `Stephan Hoyer `_ +- Interpolating via resample now internally specifies ``bounds_error=False`` + as an argument to ``scipy.interpolate.interp1d``, allowing for interpolation + from higher frequencies to lower frequencies. Datapoints outside the bounds + of the original time coordinate are now filled with NaN (:issue:`2197`). By + `Spencer Clark `_. +- Line plots with the ``x`` argument set to a non-dimensional coord now plot + the correct data for 1D DataArrays. + (:issue:`2725`). By `Tom Nicholas `_. +- Subtracting a scalar ``cftime.datetime`` object from a + :py:class:`CFTimeIndex` now results in a :py:class:`pandas.TimedeltaIndex` + instead of raising a ``TypeError`` (:issue:`2671`). By `Spencer Clark + `_. +- backend_kwargs are no longer ignored when using open_dataset with pynio engine + (:issue:'2380') + By `Jonathan Joyce `_. +- Fix ``open_rasterio`` creating a WKT CRS instead of PROJ.4 with + ``rasterio`` 1.0.14+ (:issue:`2715`). + By `David Hoese `_. +- Masking data arrays with :py:meth:`xarray.DataArray.where` now returns an + array with the name of the original masked array (:issue:`2748` and :issue:`2457`). + By `Yohai Bar-Sinai `_. +- Fixed error when trying to reduce a DataArray using a function which does not + require an axis argument. (:issue:`2768`) + By `Tom Nicholas `_. +- Concatenating a sequence of :py:class:`~xarray.DataArray` with varying names + sets the name of the output array to ``None``, instead of the name of the + first input array. If the names are the same it sets the name to that, + instead to the name of the first DataArray in the list as it did before. + (:issue:`2775`). By `Tom Nicholas `_. + +- Per the `CF conventions section on calendars + `_, + specifying ``'standard'`` as the calendar type in + :py:meth:`~xarray.cftime_range` now correctly refers to the ``'gregorian'`` + calendar instead of the ``'proleptic_gregorian'`` calendar (:issue:`2761`). + +.. _whats-new.0.11.3: + +v0.11.3 (26 January 2019) +------------------------- + +Bug fixes +~~~~~~~~~ + +- Saving files with times encoded with reference dates with timezones + (e.g. '2000-01-01T00:00:00-05:00') no longer raises an error + (:issue:`2649`). By `Spencer Clark `_. +- Fixed performance regression with ``open_mfdataset`` (:issue:`2662`). + By `Tom Nicholas `_. +- Fixed supplying an explicit dimension in the ``concat_dim`` argument to + to ``open_mfdataset`` (:issue:`2647`). + By `Ben Root `_. + +.. _whats-new.0.11.2: + +v0.11.2 (2 January 2019) +------------------------ + +Removes inadvertently introduced setup dependency on pytest-runner +(:issue:`2641`). Otherwise, this release is exactly equivalent to 0.11.1. + +.. warning:: + + This is the last xarray release that will support Python 2.7. Future releases + will be Python 3 only, but older versions of xarray will always be available + for Python 2.7 users. For the more details, see: + + - :issue:`Xarray Github issue discussing dropping Python 2 <1829>` + - `Python 3 Statement `__ + - `Tips on porting to Python 3 `__ + +.. _whats-new.0.11.1: + +v0.11.1 (29 December 2018) +-------------------------- + +This minor release includes a number of enhancements and bug fixes, and two +(slightly) breaking changes. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Minimum rasterio version increased from 0.36 to 1.0 (for ``open_rasterio``) +- Time bounds variables are now also decoded according to CF conventions + (:issue:`2565`). The previous behavior was to decode them only if they + had specific time attributes, now these attributes are copied + automatically from the corresponding time coordinate. This might + break downstream code that was relying on these variables to be + brake downstream code that was relying on these variables to be + not decoded. + By `Fabien Maussion `_. + +Enhancements +~~~~~~~~~~~~ + +- Ability to read and write consolidated metadata in zarr stores (:issue:`2558`). + By `Ryan Abernathey `_. +- :py:class:`CFTimeIndex` uses slicing for string indexing when possible (like + :py:class:`pandas.DatetimeIndex`), which avoids unnecessary copies. + By `Stephan Hoyer `_ +- Enable passing ``rasterio.io.DatasetReader`` or ``rasterio.vrt.WarpedVRT`` to + ``open_rasterio`` instead of file path string. Allows for in-memory + reprojection, see (:issue:`2588`). + By `Scott Henderson `_. +- Like :py:class:`pandas.DatetimeIndex`, :py:class:`CFTimeIndex` now supports + "dayofyear" and "dayofweek" accessors (:issue:`2597`). Note this requires a + version of cftime greater than 1.0.2. By `Spencer Clark + `_. +- The option ``'warn_for_unclosed_files'`` (False by default) has been added to + allow users to enable a warning when files opened by xarray are deallocated + but were not explicitly closed. This is mostly useful for debugging; we + recommend enabling it in your test suites if you use xarray for IO. + By `Stephan Hoyer `_ +- Support Dask ``HighLevelGraphs`` by `Matthew Rocklin `_. +- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now supports the + ``loffset`` kwarg just like pandas. + By `Deepak Cherian `_ +- Datasets are now guaranteed to have a ``'source'`` encoding, so the source + file name is always stored (:issue:`2550`). + By `Tom Nicholas `_. +- The ``apply`` methods for ``DatasetGroupBy``, ``DataArrayGroupBy``, + ``DatasetResample`` and ``DataArrayResample`` now support passing positional + arguments to the applied function as a tuple to the ``args`` argument. + By `Matti Eskelinen `_. +- 0d slices of ndarrays are now obtained directly through indexing, rather than + extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel + Wennberg `_. +- Added support for ``fill_value`` with + :py:meth:`~xarray.DataArray.shift` and :py:meth:`~xarray.Dataset.shift` + By `Maximilian Roos `_ + +Bug fixes +~~~~~~~~~ + +- Ensure files are automatically closed, if possible, when no longer referenced + by a Python variable (:issue:`2560`). + By `Stephan Hoyer `_ +- Fixed possible race conditions when reading/writing to disk in parallel + (:issue:`2595`). + By `Stephan Hoyer `_ +- Fix h5netcdf saving scalars with filters or chunks (:issue:`2563`). + By `Martin Raspaud `_. +- Fix parsing of ``_Unsigned`` attribute set by OPENDAP servers. (:issue:`2583`). + By `Deepak Cherian `_ +- Fix failure in time encoding when exporting to netCDF with versions of pandas + less than 0.21.1 (:issue:`2623`). By `Spencer Clark + `_. +- Fix MultiIndex selection to update label and level (:issue:`2619`). + By `Keisuke Fujii `_. + +.. _whats-new.0.11.0: + +v0.11.0 (7 November 2018) +------------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Finished deprecations (changed behavior with this release): + + - ``Dataset.T`` has been removed as a shortcut for :py:meth:`Dataset.transpose`. + Call :py:meth:`Dataset.transpose` directly instead. + - Iterating over a ``Dataset`` now includes only data variables, not coordinates. + Similarly, calling ``len`` and ``bool`` on a ``Dataset`` now + includes only data variables. + - ``DataArray.__contains__`` (used by Python's ``in`` operator) now checks + array data, not coordinates. + - The old resample syntax from before xarray 0.10, e.g., + ``data.resample('1D', dim='time', how='mean')``, is no longer supported will + raise an error in most cases. You need to use the new resample syntax + instead, e.g., ``data.resample(time='1D').mean()`` or + ``data.resample({'time': '1D'}).mean()``. + + +- New deprecations (behavior will be changed in xarray 0.12): + + - Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` + without dimension argument will change in the next release. + Now we warn a FutureWarning. + By `Keisuke Fujii `_. + - The ``inplace`` kwarg of a number of `DataArray` and `Dataset` methods is being + deprecated and will be removed in the next release. + By `Deepak Cherian `_. + + +- Refactored storage backends: + + - Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. A + global least-recently-used cache is used to store open files; the default + limit of 128 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. + + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. By `Stephan Hoyer `_ + + +- Support for non-standard calendars used in climate science: + + - Xarray will now always use :py:class:`cftime.datetime` objects, rather + than by default trying to coerce them into ``np.datetime64[ns]`` objects. + A :py:class:`~xarray.CFTimeIndex` will be used for indexing along time + coordinates in these cases. + - A new method :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` has been added + to aid in converting from a :py:class:`~xarray.CFTimeIndex` to a + :py:class:`pandas.DatetimeIndex` for the remaining use-cases where + using a :py:class:`~xarray.CFTimeIndex` is still a limitation (e.g. for + resample or plotting). + - Setting the ``enable_cftimeindex`` option is now a no-op and emits a + ``FutureWarning``. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`xarray.DataArray.plot.line` can now accept multidimensional + coordinate variables as input. `hue` must be a dimension name in this case. + (:issue:`2407`) + By `Deepak Cherian `_. +- Added support for Python 3.7. (:issue:`2271`). + By `Joe Hamman `_. +- Added support for plotting data with `pandas.Interval` coordinates, such as those + created by :py:meth:`~xarray.DataArray.groupby_bins` + By `Maximilian Maahn `_. +- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a + CFTimeIndex by a specified frequency. (:issue:`2244`). + By `Spencer Clark `_. +- Added support for using ``cftime.datetime`` coordinates with + :py:meth:`~xarray.DataArray.differentiate`, + :py:meth:`~xarray.Dataset.differentiate`, + :py:meth:`~xarray.DataArray.interp`, and + :py:meth:`~xarray.Dataset.interp`. + By `Spencer Clark `_ +- There is now a global option to either always keep or always discard + dataset and dataarray attrs upon operations. The option is set with + ``xarray.set_options(keep_attrs=True)``, and the default is to use the old + behaviour. + By `Tom Nicholas `_. +- Added a new backend for the GRIB file format based on ECMWF *cfgrib* + python driver and *ecCodes* C-library. (:issue:`2475`) + By `Alessandro Amici `_, + sponsored by `ECMWF `_. +- Resample now supports a dictionary mapping from dimension to frequency as + its first argument, e.g., ``data.resample({'time': '1D'}).mean()``. This is + consistent with other xarray functions that accept either dictionaries or + keyword arguments. By `Stephan Hoyer `_. + +- The preferred way to access tutorial data is now to load it lazily with + :py:meth:`xarray.tutorial.open_dataset`. + :py:meth:`xarray.tutorial.load_dataset` calls `Dataset.load()` prior + to returning (and is now deprecated). This was changed in order to facilitate + using tutorial datasets with dask. + By `Joe Hamman `_. +- ``DataArray`` can now use ``xr.set_option(keep_attrs=True)`` and retain attributes in binary operations, + such as (``+, -, * ,/``). Default behaviour is unchanged (*Attributes will be dismissed*). By `Michael Blaschek `_ + +Bug fixes +~~~~~~~~~ + +- ``FacetGrid`` now properly uses the ``cbar_kwargs`` keyword argument. + (:issue:`1504`, :issue:`1717`) + By `Deepak Cherian `_. +- Addition and subtraction operators used with a CFTimeIndex now preserve the + index's type. (:issue:`2244`). + By `Spencer Clark `_. +- We now properly handle arrays of ``datetime.datetime`` and ``datetime.timedelta`` + provided as coordinates. (:issue:`2512`) + By `Deepak Cherian `_. +- ``xarray.DataArray.roll`` correctly handles multidimensional arrays. + (:issue:`2445`) + By `Keisuke Fujii `_. +- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override + the norm's ``vmin`` and ``vmax``. (:issue:`2381`) + By `Deepak Cherian `_. +- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. + (:issue:`2240`) + By `Keisuke Fujii `_. +- Restore matplotlib's default of plotting dashed negative contours when + a single color is passed to ``DataArray.contour()`` e.g. ``colors='k'``. + By `Deepak Cherian `_. + + +- Fix a bug that caused some indexing operations on arrays opened with + ``open_rasterio`` to error (:issue:`2454`). + By `Stephan Hoyer `_. + +- Subtracting one CFTimeIndex from another now returns a + ``pandas.TimedeltaIndex``, analogous to the behavior for DatetimeIndexes + (:issue:`2484`). By `Spencer Clark `_. +- Adding a TimedeltaIndex to, or subtracting a TimedeltaIndex from a + CFTimeIndex is now allowed (:issue:`2484`). + By `Spencer Clark `_. +- Avoid use of Dask's deprecated ``get=`` parameter in tests + by `Matthew Rocklin `_. +- An ``OverflowError`` is now accurately raised and caught during the + encoding process if a reference date is used that is so distant that + the dates must be encoded using cftime rather than NumPy (:issue:`2272`). + By `Spencer Clark `_. + +- Chunked datasets can now roundtrip to Zarr storage continually + with `to_zarr` and ``open_zarr`` (:issue:`2300`). + By `Lily Wang `_. + +.. _whats-new.0.10.9: + +v0.10.9 (21 September 2018) +--------------------------- + +This minor release contains a number of backwards compatible enhancements. + +Announcements of note: + +- Xarray is now a NumFOCUS fiscally sponsored project! Read + `the announcement `_ + for more details. +- We have a new :doc:`roadmap` that outlines our future development plans. + +- ``Dataset.apply`` now properly documents the way `func` is called. + By `Matti Eskelinen `_. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.differentiate` and + :py:meth:`~xarray.Dataset.differentiate` are newly added. + (:issue:`1332`) + By `Keisuke Fujii `_. + +- Default colormap for sequential and divergent data can now be set via + :py:func:`~xarray.set_options()` + (:issue:`2394`) + By `Julius Busecke `_. + +- min_count option is newly supported in :py:meth:`~xarray.DataArray.sum`, + :py:meth:`~xarray.DataArray.prod` and :py:meth:`~xarray.Dataset.sum`, and + :py:meth:`~xarray.Dataset.prod`. + (:issue:`2230`) + By `Keisuke Fujii `_. + +- :py:func:`~plot.plot()` now accepts the kwargs + ``xscale, yscale, xlim, ylim, xticks, yticks`` just like pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. + By `Deepak Cherian `_. (:issue:`2224`) + +- DataArray coordinates and Dataset coordinates and data variables are + now displayed as `a b ... y z` rather than `a b c d ...`. + (:issue:`1186`) + By `Seth P `_. +- A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in + generating dates from standard or non-standard calendars. By `Spencer Clark + `_. + +- When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')`` + (:issue:`2284`) + By `Deepak Cherian `_. + +- A clear error message is now displayed if a ``set`` or ``dict`` is passed in place of an array + (:issue:`2331`) + By `Maximilian Roos `_. + +- Applying ``unstack`` to a large DataArray or Dataset is now much faster if the MultiIndex has not been modified after stacking the indices. + (:issue:`1560`) + By `Maximilian Maahn `_. + +- You can now control whether or not to offset the coordinates when using + the ``roll`` method and the current behavior, coordinates rolled by default, + raises a deprecation warning unless explicitly setting the keyword argument. + (:issue:`1875`) + By `Andrew Huang `_. + +- You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset. + By `Julia Signell `_. + +- Added the ability to pass a data kwarg to ``copy`` to create a new object with the + same metadata as the original object but using new values. + By `Julia Signell `_. + +Bug fixes +~~~~~~~~~ + +- ``xarray.plot.imshow()`` correctly uses the ``origin`` argument. + (:issue:`2379`) + By `Deepak Cherian `_. + +- Fixed ``DataArray.to_iris()`` failure while creating ``DimCoord`` by + falling back to creating ``AuxCoord``. Fixed dependency on ``var_name`` + attribute being set. + (:issue:`2201`) + By `Thomas Voigt `_. +- Fixed a bug in ``zarr`` backend which prevented use with datasets with + invalid chunk size encoding after reading from an existing store + (:issue:`2278`). + By `Joe Hamman `_. + +- Tests can be run in parallel with pytest-xdist + By `Tony Tung `_. + +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + +- Now raises a ValueError when there is a conflict between dimension names and + level names of MultiIndex. (:issue:`2299`) + By `Keisuke Fujii `_. + +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + +- Now :py:func:`~xarray.apply_ufunc` raises a ValueError when the size of + ``input_core_dims`` is inconsistent with the number of arguments. + (:issue:`2341`) + By `Keisuke Fujii `_. + +- Fixed ``Dataset.filter_by_attrs()`` behavior not matching ``netCDF4.Dataset.get_variables_by_attributes()``. + When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()`` it will now return a Dataset with variables which pass + all the filters. + (:issue:`2315`) + By `Andrew Barna `_. + +.. _whats-new.0.10.8: + +v0.10.8 (18 July 2018) +---------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Xarray no longer supports python 3.4. Additionally, the minimum supported + versions of the following dependencies has been updated and/or clarified: + + - pandas: 0.18 -> 0.19 + - NumPy: 1.11 -> 1.12 + - Dask: 0.9 -> 0.16 + - Matplotlib: unspecified -> 1.5 + + (:issue:`2204`). By `Joe Hamman `_. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.interp_like` and + :py:meth:`~xarray.Dataset.interp_like` methods are newly added. + (:issue:`2218`) + By `Keisuke Fujii `_. + +- Added support for curvilinear and unstructured generic grids + to :py:meth:`~xarray.DataArray.to_cdms2` and + :py:meth:`~xarray.DataArray.from_cdms2` (:issue:`2262`). + By `Stephane Raynaud `_. + +Bug fixes +~~~~~~~~~ + +- Fixed a bug in ``zarr`` backend which prevented use with datasets with + incomplete chunks in multiple dimensions (:issue:`2225`). + By `Joe Hamman `_. + +- Fixed a bug in :py:meth:`~Dataset.to_netcdf` which prevented writing + datasets when the arrays had different chunk sizes (:issue:`2254`). + By `Mike Neish `_. + +- Fixed masking during the conversion to cdms2 objects by + :py:meth:`~xarray.DataArray.to_cdms2` (:issue:`2262`). + By `Stephane Raynaud `_. + +- Fixed a bug in 2D plots which incorrectly raised an error when 2D coordinates + weren't monotonic (:issue:`2250`). + By `Fabien Maussion `_. + +- Fixed warning raised in :py:meth:`~Dataset.to_netcdf` due to deprecation of + `effective_get` in dask (:issue:`2238`). + By `Joe Hamman `_. + +.. _whats-new.0.10.7: + +v0.10.7 (7 June 2018) +--------------------- + +Enhancements +~~~~~~~~~~~~ + +- Plot labels now make use of metadata that follow CF conventions + (:issue:`2135`). + By `Deepak Cherian `_ and `Ryan Abernathey `_. + +- Line plots now support facetting with ``row`` and ``col`` arguments + (:issue:`2107`). + By `Yohai Bar Sinai `_. + +- :py:meth:`~xarray.DataArray.interp` and :py:meth:`~xarray.Dataset.interp` + methods are newly added. + See :ref:`interp` for the detail. + (:issue:`2079`) + By `Keisuke Fujii `_. + +Bug fixes +~~~~~~~~~ + +- Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``. + The ``rasterio`` backend now returns pickleable objects (:issue:`2021`). + By `Joe Hamman `_. + +.. _whats-new.0.10.6: + +v0.10.6 (31 May 2018) +--------------------- + +The minor release includes a number of bug-fixes and backwards compatible +enhancements. + +Enhancements +~~~~~~~~~~~~ + +- New PseudoNetCDF backend for many Atmospheric data formats including + GEOS-Chem, CAMx, NOAA arlpacked bit and many others. See + ``io.PseudoNetCDF`` for more details. + By `Barron Henderson `_. + +- The :py:class:`Dataset` constructor now aligns :py:class:`DataArray` + arguments in ``data_vars`` to indexes set explicitly in ``coords``, + where previously an error would be raised. + (:issue:`674`) + By `Maximilian Roos `_. + +- :py:meth:`~DataArray.sel`, :py:meth:`~DataArray.isel` & :py:meth:`~DataArray.reindex`, + (and their :py:class:`Dataset` counterparts) now support supplying a ``dict`` + as a first argument, as an alternative to the existing approach + of supplying `kwargs`. This allows for more robust behavior + of dimension names which conflict with other keyword names, or are + not strings. + By `Maximilian Roos `_. + +- :py:meth:`~DataArray.rename` now supports supplying ``**kwargs``, as an + alternative to the existing approach of supplying a ``dict`` as the + first argument. + By `Maximilian Roos `_. + +- :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support + aggregation over multiple dimensions at the same time. This is the default + behavior when dimensions are not specified (previously this raised an error). + By `Stephan Hoyer `_ + +- :py:meth:`DataArray.dot` and :py:func:`dot` are partly supported with older + dask<0.17.4. (related to :issue:`2203`) + By `Keisuke Fujii `_. + +- Xarray now uses `Versioneer `__ + to manage its version strings. (:issue:`1300`). + By `Joe Hamman `_. + +Bug fixes +~~~~~~~~~ + +- Fixed a regression in 0.10.4, where explicitly specifying ``dtype='S1'`` or + ``dtype=str`` in ``encoding`` with ``to_netcdf()`` raised an error + (:issue:`2149`). + `Stephan Hoyer `_ + +- :py:func:`apply_ufunc` now directly validates output variables + (:issue:`1931`). + By `Stephan Hoyer `_. + +- Fixed a bug where ``to_netcdf(..., unlimited_dims='bar')`` yielded NetCDF + files with spurious 0-length dimensions (i.e. ``b``, ``a``, and ``r``) + (:issue:`2134`). + By `Joe Hamman `_. + +- Removed spurious warnings with ``Dataset.update(Dataset)`` (:issue:`2161`) + and ``array.equals(array)`` when ``array`` contains ``NaT`` (:issue:`2162`). + By `Stephan Hoyer `_. + +- Aggregations with :py:meth:`Dataset.reduce` (including ``mean``, ``sum``, + etc) no longer drop unrelated coordinates (:issue:`1470`). Also fixed a + bug where non-scalar data-variables that did not include the aggregation + dimension were improperly skipped. + By `Stephan Hoyer `_ + +- Fix :meth:`~DataArray.stack` with non-unique coordinates on pandas 0.23 + (:issue:`2160`). + By `Stephan Hoyer `_ + +- Selecting data indexed by a length-1 ``CFTimeIndex`` with a slice of strings + now behaves as it does when using a length-1 ``DatetimeIndex`` (i.e. it no + longer falsely returns an empty array when the slice includes the value in + the index) (:issue:`2165`). + By `Spencer Clark `_. + +- Fix ``DataArray.groupby().reduce()`` mutating coordinates on the input array + when grouping over dimension coordinates with duplicated entries + (:issue:`2153`). + By `Stephan Hoyer `_ + +- Fix ``Dataset.to_netcdf()`` cannot create group with ``engine="h5netcdf"`` + (:issue:`2177`). + By `Stephan Hoyer `_ + +.. _whats-new.0.10.4: + +v0.10.4 (16 May 2018) +---------------------- + +The minor release includes a number of bug-fixes and backwards compatible +enhancements. A highlight is ``CFTimeIndex``, which offers support for +non-standard calendars used in climate modeling. + +Documentation +~~~~~~~~~~~~~ + +- New FAQ entry, :ref:`ecosystem`. + By `Deepak Cherian `_. +- :ref:`assigning_values` now includes examples on how to select and assign + values to a :py:class:`~xarray.DataArray` with ``.loc``. + By `Chiara Lepore `_. + +Enhancements +~~~~~~~~~~~~ + +- Add an option for using a ``CFTimeIndex`` for indexing times with + non-standard calendars and/or outside the Timestamp-valid range; this index + enables a subset of the functionality of a standard + ``pandas.DatetimeIndex``. + See :ref:`CFTimeIndex` for full details. + (:issue:`789`, :issue:`1084`, :issue:`1252`) + By `Spencer Clark `_ with help from + `Stephan Hoyer `_. +- Allow for serialization of ``cftime.datetime`` objects (:issue:`789`, + :issue:`1084`, :issue:`2008`, :issue:`1252`) using the standalone ``cftime`` + library. + By `Spencer Clark `_. +- Support writing lists of strings as netCDF attributes (:issue:`2044`). + By `Dan Nowacki `_. +- :py:meth:`~xarray.Dataset.to_netcdf` with ``engine='h5netcdf'`` now accepts h5py + encoding settings ``compression`` and ``compression_opts``, along with the + NetCDF4-Python style settings ``gzip=True`` and ``complevel``. + This allows using any compression plugin installed in hdf5, e.g. LZF + (:issue:`1536`). By `Guido Imperiale `_. +- :py:meth:`~xarray.dot` on dask-backed data will now call :func:`dask.array.einsum`. + This greatly boosts speed and allows chunking on the core dims. + The function now requires dask >= 0.17.3 to work on dask-backed data + (:issue:`2074`). By `Guido Imperiale `_. +- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change + the direction of the respective axes. + By `Deepak Cherian `_. + +- Added the ``parallel`` option to :py:func:`open_mfdataset`. This option uses + ``dask.delayed`` to parallelize the open and preprocessing steps within + ``open_mfdataset``. This is expected to provide performance improvements when + opening many files, particularly when used in conjunction with dask's + multiprocessing or distributed schedulers (:issue:`1981`). + By `Joe Hamman `_. + +- New ``compute`` option in :py:meth:`~xarray.Dataset.to_netcdf`, + :py:meth:`~xarray.Dataset.to_zarr`, and :py:func:`~xarray.save_mfdataset` to + allow for the lazy computation of netCDF and zarr stores. This feature is + currently only supported by the netCDF4 and zarr backends. (:issue:`1784`). + By `Joe Hamman `_. + + +Bug fixes +~~~~~~~~~ + +- ``ValueError`` is raised when coordinates with the wrong size are assigned to + a :py:class:`DataArray`. (:issue:`2112`) + By `Keisuke Fujii `_. +- Fixed a bug in :py:meth:`~xarray.DataArray.rolling` with bottleneck. Also, + fixed a bug in rolling an integer dask array. (:issue:`2113`) + By `Keisuke Fujii `_. +- Fixed a bug where `keep_attrs=True` flag was neglected if + :py:func:`apply_ufunc` was used with :py:class:`Variable`. (:issue:`2114`) + By `Keisuke Fujii `_. +- When assigning a :py:class:`DataArray` to :py:class:`Dataset`, any conflicted + non-dimensional coordinates of the DataArray are now dropped. + (:issue:`2068`) + By `Keisuke Fujii `_. +- Better error handling in ``open_mfdataset`` (:issue:`2077`). + By `Stephan Hoyer `_. +- ``plot.line()`` does not call ``autofmt_xdate()`` anymore. Instead it changes + the rotation and horizontal alignment of labels without removing the x-axes of + any other subplots in the figure (if any). + By `Deepak Cherian `_. +- Colorbar limits are now determined by excluding ±Infs too. + By `Deepak Cherian `_. + By `Joe Hamman `_. +- Fixed ``to_iris`` to maintain lazy dask array after conversion (:issue:`2046`). + By `Alex Hilson `_ and `Stephan Hoyer `_. + +.. _whats-new.0.10.3: + +v0.10.3 (13 April 2018) +------------------------ + +The minor release includes a number of bug-fixes and backwards compatible enhancements. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.isin` and :py:meth:`~xarray.Dataset.isin` methods, + which test each value in the array for whether it is contained in the + supplied list, returning a bool array. See :ref:`selecting values with isin` + for full details. Similar to the ``np.isin`` function. + By `Maximilian Roos `_. +- Some speed improvement to construct :py:class:`~xarray.core.rolling.DataArrayRolling` + object (:issue:`1993`) + By `Keisuke Fujii `_. +- Handle variables with different values for ``missing_value`` and + ``_FillValue`` by masking values for both attributes; previously this + resulted in a ``ValueError``. (:issue:`2016`) + By `Ryan May `_. + +Bug fixes +~~~~~~~~~ + +- Fixed ``decode_cf`` function to operate lazily on dask arrays + (:issue:`1372`). By `Ryan Abernathey `_. +- Fixed labeled indexing with slice bounds given by xarray objects with + datetime64 or timedelta64 dtypes (:issue:`1240`). + By `Stephan Hoyer `_. +- Attempting to convert an xarray.Dataset into a numpy array now raises an + informative error message. + By `Stephan Hoyer `_. +- Fixed a bug in decode_cf_datetime where ``int32`` arrays weren't parsed + correctly (:issue:`2002`). + By `Fabien Maussion `_. +- When calling `xr.auto_combine()` or `xr.open_mfdataset()` with a `concat_dim`, + the resulting dataset will have that one-element dimension (it was + silently dropped, previously) (:issue:`1988`). + By `Ben Root `_. + +.. _whats-new.0.10.2: + +v0.10.2 (13 March 2018) +----------------------- + +The minor release includes a number of bug-fixes and enhancements, along with +one possibly **backwards incompatible change**. + +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The addition of ``__array_ufunc__`` for xarray objects (see below) means that + NumPy `ufunc methods`_ (e.g., ``np.add.reduce``) that previously worked on + ``xarray.DataArray`` objects by converting them into NumPy arrays will now + raise ``NotImplementedError`` instead. In all cases, the work-around is + simple: convert your objects explicitly into NumPy arrays before calling the + ufunc (e.g., with ``.values``). + +.. _ufunc methods: https://numpy.org/doc/stable/reference/ufuncs.html#methods + +Enhancements +~~~~~~~~~~~~ + +- Added :py:func:`~xarray.dot`, equivalent to :py:func:`numpy.einsum`. + Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, + which specifies the dimensions to sum over. + (:issue:`1951`) + By `Keisuke Fujii `_. + +- Support for writing xarray datasets to netCDF files (netcdf4 backend only) + when using the `dask.distributed `_ + scheduler (:issue:`1464`). + By `Joe Hamman `_. + +- Support lazy vectorized-indexing. After this change, flexible indexing such + as orthogonal/vectorized indexing, becomes possible for all the backend + arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`) + By `Keisuke Fujii `_. + +- Implemented NumPy's ``__array_ufunc__`` protocol for all xarray objects + (:issue:`1617`). This enables using NumPy ufuncs directly on + ``xarray.Dataset`` objects with recent versions of NumPy (v1.13 and newer): + + .. ipython:: python + + ds = xr.Dataset({"a": 1}) + np.sin(ds) + + This obliviates the need for the ``xarray.ufuncs`` module, which will be + deprecated in the future when xarray drops support for older versions of + NumPy. By `Stephan Hoyer `_. + +- Improve :py:func:`~xarray.DataArray.rolling` logic. + :py:func:`~xarray.core.rolling.DataArrayRolling` object now supports + :py:func:`~xarray.core.rolling.DataArrayRolling.construct` method that returns a view + of the DataArray / Dataset object with the rolling-window dimension added + to the last axis. This enables more flexible operation, such as strided + rolling, windowed rolling, ND-rolling, short-time FFT and convolution. + (:issue:`1831`, :issue:`1142`, :issue:`819`) + By `Keisuke Fujii `_. +- :py:func:`~plot.line()` learned to make plots with data on x-axis if so specified. (:issue:`575`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Raise an informative error message when using ``apply_ufunc`` with numpy + v1.11 (:issue:`1956`). + By `Stephan Hoyer `_. +- Fix the precision drop after indexing datetime64 arrays (:issue:`1932`). + By `Keisuke Fujii `_. +- Silenced irrelevant warnings issued by ``open_rasterio`` (:issue:`1964`). + By `Stephan Hoyer `_. +- Fix kwarg `colors` clashing with auto-inferred `cmap` (:issue:`1461`) + By `Deepak Cherian `_. +- Fix :py:func:`~xarray.plot.imshow` error when passed an RGB array with + size one in a spatial dimension. + By `Zac Hatfield-Dodds `_. + +.. _whats-new.0.10.1: + +v0.10.1 (25 February 2018) +-------------------------- + +The minor release includes a number of bug-fixes and backwards compatible enhancements. + +Documentation +~~~~~~~~~~~~~ + +- Added a new guide on :ref:`contributing` (:issue:`640`) + By `Joe Hamman `_. +- Added apply_ufunc example to :ref:`/examples/weather-data.ipynb#Toy-weather-data` (:issue:`1844`). + By `Liam Brannigan `_. +- New entry `Why don’t aggregations return Python scalars?` in the + :doc:`getting-started-guide/faq` (:issue:`1726`). + By `0x0L `_. + +Enhancements +~~~~~~~~~~~~ +**New functions and methods**: + +- Added :py:meth:`DataArray.to_iris` and + :py:meth:`DataArray.from_iris` for + converting data arrays to and from Iris_ Cubes with the same data and coordinates + (:issue:`621` and :issue:`37`). + By `Neil Parley `_ and `Duncan Watson-Parris `_. +- Experimental support for using `Zarr`_ as storage layer for xarray + (:issue:`1223`). + By `Ryan Abernathey `_ and + `Joe Hamman `_. +- New :py:meth:`~xarray.DataArray.rank` on arrays and datasets. Requires + bottleneck (:issue:`1731`). + By `0x0L `_. +- ``.dt`` accessor can now ceil, floor and round timestamps to specified frequency. + By `Deepak Cherian `_. + +**Plotting enhancements**: + +- :func:`xarray.plot.imshow` now handles RGB and RGBA images. + Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``. + By `Zac Hatfield-Dodds `_. +- :py:func:`~plot.contourf()` learned to contour 2D variables that have both a + 1D coordinate (e.g. time) and a 2D coordinate (e.g. depth as a function of + time) (:issue:`1737`). + By `Deepak Cherian `_. +- :py:func:`~plot.plot()` rotates x-axis ticks if x-axis is time. + By `Deepak Cherian `_. +- :py:func:`~plot.line()` can draw multiple lines if provided with a + 2D variable. + By `Deepak Cherian `_. + +**Other enhancements**: + +- Reduce methods such as :py:func:`DataArray.sum()` now handles object-type array. + + .. ipython:: python + + da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims="x") + da.sum() + + (:issue:`1866`) + By `Keisuke Fujii `_. +- Reduce methods such as :py:func:`DataArray.sum()` now accepts ``dtype`` + arguments. (:issue:`1838`) + By `Keisuke Fujii `_. +- Added nodatavals attribute to DataArray when using :py:func:`~xarray.open_rasterio`. (:issue:`1736`). + By `Alan Snow `_. +- Use ``pandas.Grouper`` class in xarray resample methods rather than the + deprecated ``pandas.TimeGrouper`` class (:issue:`1766`). + By `Joe Hamman `_. +- Experimental support for parsing ENVI metadata to coordinates and attributes + in :py:func:`xarray.open_rasterio`. + By `Matti Eskelinen `_. +- Reduce memory usage when decoding a variable with a scale_factor, by + converting 8-bit and 16-bit integers to float32 instead of float64 + (:pull:`1840`), and keeping float16 and float32 as float32 (:issue:`1842`). + Correspondingly, encoded variables may also be saved with a smaller dtype. + By `Zac Hatfield-Dodds `_. +- Speed of reindexing/alignment with dask array is orders of magnitude faster + when inserting missing values (:issue:`1847`). + By `Stephan Hoyer `_. +- Fix ``axis`` keyword ignored when applying ``np.squeeze`` to ``DataArray`` (:issue:`1487`). + By `Florian Pinault `_. +- ``netcdf4-python`` has moved the its time handling in the ``netcdftime`` module to + a standalone package (`netcdftime`_). As such, xarray now considers `netcdftime`_ + an optional dependency. One benefit of this change is that it allows for + encoding/decoding of datetimes with non-standard calendars without the + ``netcdf4-python`` dependency (:issue:`1084`). + By `Joe Hamman `_. + +.. _Zarr: http://zarr.readthedocs.io/ + +.. _Iris: http://scitools.org.uk/iris + +.. _netcdftime: https://unidata.github.io/netcdftime + +**New functions/methods** + +- New :py:meth:`~xarray.DataArray.rank` on arrays and datasets. Requires + bottleneck (:issue:`1731`). + By `0x0L `_. + +Bug fixes +~~~~~~~~~ +- Rolling aggregation with ``center=True`` option now gives the same result + with pandas including the last element (:issue:`1046`). + By `Keisuke Fujii `_. + +- Support indexing with a 0d-np.ndarray (:issue:`1921`). + By `Keisuke Fujii `_. +- Added warning in api.py of a netCDF4 bug that occurs when + the filepath has 88 characters (:issue:`1745`). + By `Liam Brannigan `_. +- Fixed encoding of multi-dimensional coordinates in + :py:meth:`~Dataset.to_netcdf` (:issue:`1763`). + By `Mike Neish `_. +- Fixed chunking with non-file-based rasterio datasets (:issue:`1816`) and + refactored rasterio test suite. + By `Ryan Abernathey `_ +- Bug fix in open_dataset(engine='pydap') (:issue:`1775`) + By `Keisuke Fujii `_. +- Bug fix in vectorized assignment (:issue:`1743`, :issue:`1744`). + Now item assignment to :py:meth:`~DataArray.__setitem__` checks +- Bug fix in vectorized assignment (:issue:`1743`, :issue:`1744`). + Now item assignment to :py:meth:`DataArray.__setitem__` checks + coordinates of target, destination and keys. If there are any conflict among + these coordinates, ``IndexError`` will be raised. + By `Keisuke Fujii `_. +- Properly point ``DataArray.__dask_scheduler__`` to + ``dask.threaded.get``. By `Matthew Rocklin `_. +- Bug fixes in :py:meth:`DataArray.plot.imshow`: all-NaN arrays and arrays + with size one in some dimension can now be plotted, which is good for + exploring satellite imagery (:issue:`1780`). + By `Zac Hatfield-Dodds `_. +- Fixed ``UnboundLocalError`` when opening netCDF file (:issue:`1781`). + By `Stephan Hoyer `_. +- The ``variables``, ``attrs``, and ``dimensions`` properties have been + deprecated as part of a bug fix addressing an issue where backends were + unintentionally loading the datastores data and attributes repeatedly during + writes (:issue:`1798`). + By `Joe Hamman `_. +- Compatibility fixes to plotting module for NumPy 1.14 and pandas 0.22 + (:issue:`1813`). + By `Joe Hamman `_. +- Bug fix in encoding coordinates with ``{'_FillValue': None}`` in netCDF + metadata (:issue:`1865`). + By `Chris Roth `_. +- Fix indexing with lists for arrays loaded from netCDF files with + ``engine='h5netcdf`` (:issue:`1864`). + By `Stephan Hoyer `_. +- Corrected a bug with incorrect coordinates for non-georeferenced geotiff + files (:issue:`1686`). Internally, we now use the rasterio coordinate + transform tool instead of doing the computations ourselves. A + ``parse_coordinates`` kwarg has been added to :py:func:`~open_rasterio` + (set to ``True`` per default). + By `Fabien Maussion `_. +- The colors of discrete colormaps are now the same regardless if `seaborn` + is installed or not (:issue:`1896`). + By `Fabien Maussion `_. +- Fixed dtype promotion rules in :py:func:`where` and :py:func:`concat` to + match pandas (:issue:`1847`). A combination of strings/numbers or + unicode/bytes now promote to object dtype, instead of strings or unicode. + By `Stephan Hoyer `_. +- Fixed bug where :py:meth:`~xarray.DataArray.isnull` was loading data + stored as dask arrays (:issue:`1937`). + By `Joe Hamman `_. + +.. _whats-new.0.10.0: + +v0.10.0 (20 November 2017) +-------------------------- + +This is a major release that includes bug fixes, new features and a few +backwards incompatible changes. Highlights include: + +- Indexing now supports broadcasting over dimensions, similar to NumPy's + vectorized indexing (but better!). +- :py:meth:`~DataArray.resample` has a new groupby-like API like pandas. +- :py:func:`~xarray.apply_ufunc` facilitates wrapping and parallelizing + functions written for NumPy arrays. +- Performance improvements, particularly for dask and :py:func:`open_mfdataset`. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- xarray now supports a form of vectorized indexing with broadcasting, where + the result of indexing depends on dimensions of indexers, + e.g., ``array.sel(x=ind)`` with ``ind.dims == ('y',)``. Alignment between + coordinates on indexed and indexing objects is also now enforced. + Due to these changes, existing uses of xarray objects to index other xarray + objects will break in some cases. + + The new indexing API is much more powerful, supporting outer, diagonal and + vectorized indexing in a single interface. + The ``isel_points`` and ``sel_points`` methods are deprecated, since they are + now redundant with the ``isel`` / ``sel`` methods. + See :ref:`vectorized_indexing` for the details (:issue:`1444`, + :issue:`1436`). + By `Keisuke Fujii `_ and + `Stephan Hoyer `_. + +- A new resampling interface to match pandas' groupby-like API was added to + :py:meth:`Dataset.resample` and :py:meth:`DataArray.resample` + (:issue:`1272`). :ref:`Timeseries resampling ` is + fully supported for data with arbitrary dimensions as is both downsampling + and upsampling (including linear, quadratic, cubic, and spline interpolation). + + Old syntax: + + .. ipython:: + :verbatim: + + In [1]: ds.resample("24H", dim="time", how="max") + Out[1]: + + [...] + + New syntax: + + .. ipython:: + :verbatim: + + In [1]: ds.resample(time="24H").max() + Out[1]: + + [...] + + Note that both versions are currently supported, but using the old syntax will + produce a warning encouraging users to adopt the new syntax. + By `Daniel Rothenberg `_. + +- Calling ``repr()`` or printing xarray objects at the command line or in a + Jupyter Notebook will not longer automatically compute dask variables or + load data on arrays lazily loaded from disk (:issue:`1522`). + By `Guido Imperiale `_. + +- Supplying ``coords`` as a dictionary to the ``DataArray`` constructor without + also supplying an explicit ``dims`` argument is no longer supported. This + behavior was deprecated in version 0.9 but will now raise an error + (:issue:`727`). + +- Several existing features have been deprecated and will change to new + behavior in xarray v0.11. If you use any of them with xarray v0.10, you + should see a ``FutureWarning`` that describes how to update your code: + + - ``Dataset.T`` has been deprecated an alias for ``Dataset.transpose()`` + (:issue:`1232`). In the next major version of xarray, it will provide short- + cut lookup for variables or attributes with name ``'T'``. + - ``DataArray.__contains__`` (e.g., ``key in data_array``) currently checks + for membership in ``DataArray.coords``. In the next major version of + xarray, it will check membership in the array data found in + ``DataArray.values`` instead (:issue:`1267`). + - Direct iteration over and counting a ``Dataset`` (e.g., ``[k for k in ds]``, + ``ds.keys()``, ``ds.values()``, ``len(ds)`` and ``if ds``) currently + includes all variables, both data and coordinates. For improved usability + and consistency with pandas, in the next major version of xarray these will + change to only include data variables (:issue:`884`). Use ``ds.variables``, + ``ds.data_vars`` or ``ds.coords`` as alternatives. + +- Changes to minimum versions of dependencies: + + - Old numpy < 1.11 and pandas < 0.18 are no longer supported (:issue:`1512`). + By `Keisuke Fujii `_. + - The minimum supported version bottleneck has increased to 1.1 + (:issue:`1279`). + By `Joe Hamman `_. + +Enhancements +~~~~~~~~~~~~ + +**New functions/methods** + +- New helper function :py:func:`~xarray.apply_ufunc` for wrapping functions + written to work on NumPy arrays to support labels on xarray objects + (:issue:`770`). ``apply_ufunc`` also support automatic parallelization for + many functions with dask. See :ref:`comput.wrapping-custom` and + :ref:`dask.automatic-parallelization` for details. + By `Stephan Hoyer `_. + +- Added new method :py:meth:`Dataset.to_dask_dataframe`, convert a dataset into + a dask dataframe. + This allows lazy loading of data from a dataset containing dask arrays (:issue:`1462`). + By `James Munroe `_. + +- New function :py:func:`~xarray.where` for conditionally switching between + values in xarray objects, like :py:func:`numpy.where`: + + .. ipython:: + :verbatim: + + In [1]: import xarray as xr + + In [2]: arr = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("x", "y")) + + In [3]: xr.where(arr % 2, "even", "odd") + Out[3]: + + array([['even', 'odd', 'even'], + ['odd', 'even', 'odd']], + dtype='`_. + +- Added :py:func:`~xarray.show_versions` function to aid in debugging + (:issue:`1485`). + By `Joe Hamman `_. + +**Performance improvements** + +- :py:func:`~xarray.concat` was computing variables that aren't in memory + (e.g. dask-based) multiple times; :py:func:`~xarray.open_mfdataset` + was loading them multiple times from disk. Now, both functions will instead + load them at most once and, if they do, store them in memory in the + concatenated array/dataset (:issue:`1521`). + By `Guido Imperiale `_. + +- Speed-up (x 100) of ``xarray.conventions.decode_cf_datetime``. + By `Christian Chwala `_. + +**IO related improvements** + +- Unicode strings (``str`` on Python 3) are now round-tripped successfully even + when written as character arrays (e.g., as netCDF3 files or when using + ``engine='scipy'``) (:issue:`1638`). This is controlled by the ``_Encoding`` + attribute convention, which is also understood directly by the netCDF4-Python + interface. See :ref:`io.string-encoding` for full details. + By `Stephan Hoyer `_. + +- Support for ``data_vars`` and ``coords`` keywords from + :py:func:`~xarray.concat` added to :py:func:`~xarray.open_mfdataset` + (:issue:`438`). Using these keyword arguments can significantly reduce + memory usage and increase speed. + By `Oleksandr Huziy `_. + +- Support for :py:class:`pathlib.Path` objects added to + :py:func:`~xarray.open_dataset`, :py:func:`~xarray.open_mfdataset`, + ``xarray.to_netcdf``, and :py:func:`~xarray.save_mfdataset` + (:issue:`799`): + + .. ipython:: + :verbatim: + + In [2]: from pathlib import Path # In Python 2, use pathlib2! + + In [3]: data_dir = Path("data/") + + In [4]: one_file = data_dir / "dta_for_month_01.nc" + + In [5]: xr.open_dataset(one_file) + Out[5]: + + [...] + + By `Willi Rath `_. + +- You can now explicitly disable any default ``_FillValue`` (``NaN`` for + floating point values) by passing the encoding ``{'_FillValue': None}`` + (:issue:`1598`). + By `Stephan Hoyer `_. + +- More attributes available in :py:attr:`~xarray.Dataset.attrs` dictionary when + raster files are opened with :py:func:`~xarray.open_rasterio`. + By `Greg Brener `_. + +- Support for NetCDF files using an ``_Unsigned`` attribute to indicate that a + a signed integer data type should be interpreted as unsigned bytes + (:issue:`1444`). + By `Eric Bruning `_. + +- Support using an existing, opened netCDF4 ``Dataset`` with + :py:class:`~xarray.backends.NetCDF4DataStore`. This permits creating an + :py:class:`~xarray.Dataset` from a netCDF4 ``Dataset`` that has been opened using + other means (:issue:`1459`). + By `Ryan May `_. + +- Changed :py:class:`~xarray.backends.PydapDataStore` to take a Pydap dataset. + This permits opening Opendap datasets that require authentication, by + instantiating a Pydap dataset with a session object. Also added + :py:meth:`xarray.backends.PydapDataStore.open` which takes a url and session + object (:issue:`1068`). + By `Philip Graae `_. + +- Support reading and writing unlimited dimensions with h5netcdf (:issue:`1636`). + By `Joe Hamman `_. + +**Other improvements** + +- Added ``_ipython_key_completions_`` to xarray objects, to enable + autocompletion for dictionary-like access in IPython, e.g., + ``ds['tem`` + tab -> ``ds['temperature']`` (:issue:`1628`). + By `Keisuke Fujii `_. + +- Support passing keyword arguments to ``load``, ``compute``, and ``persist`` + methods. Any keyword arguments supplied to these methods are passed on to + the corresponding dask function (:issue:`1523`). + By `Joe Hamman `_. + +- Encoding attributes are now preserved when xarray objects are concatenated. + The encoding is copied from the first object (:issue:`1297`). + By `Joe Hamman `_ and + `Gerrit Holl `_. + +- Support applying rolling window operations using bottleneck's moving window + functions on data stored as dask arrays (:issue:`1279`). + By `Joe Hamman `_. + +- Experimental support for the Dask collection interface (:issue:`1674`). + By `Matthew Rocklin `_. + +Bug fixes +~~~~~~~~~ + +- Suppress ``RuntimeWarning`` issued by ``numpy`` for "invalid value comparisons" + (e.g. ``NaN``). Xarray now behaves similarly to pandas in its treatment of + binary and unary operations on objects with NaNs (:issue:`1657`). + By `Joe Hamman `_. + +- Unsigned int support for reduce methods with ``skipna=True`` + (:issue:`1562`). + By `Keisuke Fujii `_. + +- Fixes to ensure xarray works properly with pandas 0.21: + + - Fix :py:meth:`~xarray.DataArray.isnull` method (:issue:`1549`). + - :py:meth:`~xarray.DataArray.to_series` and + :py:meth:`~xarray.Dataset.to_dataframe` should not return a ``pandas.MultiIndex`` + for 1D data (:issue:`1548`). + - Fix plotting with datetime64 axis labels (:issue:`1661`). + + By `Stephan Hoyer `_. + +- :py:func:`~xarray.open_rasterio` method now shifts the rasterio + coordinates so that they are centered in each pixel (:issue:`1468`). + By `Greg Brener `_. + +- :py:meth:`~xarray.Dataset.rename` method now doesn't throw errors + if some ``Variable`` is renamed to the same name as another ``Variable`` + as long as that other ``Variable`` is also renamed (:issue:`1477`). This + method now does throw when two ``Variables`` would end up with the same name + after the rename (since one of them would get overwritten in this case). + By `Prakhar Goel `_. + +- Fix :py:func:`xarray.testing.assert_allclose` to actually use ``atol`` and + ``rtol`` arguments when called on ``DataArray`` objects (:issue:`1488`). + By `Stephan Hoyer `_. + +- xarray ``quantile`` methods now properly raise a ``TypeError`` when applied to + objects with data stored as ``dask`` arrays (:issue:`1529`). + By `Joe Hamman `_. + +- Fix positional indexing to allow the use of unsigned integers (:issue:`1405`). + By `Joe Hamman `_ and + `Gerrit Holl `_. + +- Creating a :py:class:`Dataset` now raises ``MergeError`` if a coordinate + shares a name with a dimension but is comprised of arbitrary dimensions + (:issue:`1120`). + By `Joe Hamman `_. + +- :py:func:`~xarray.open_rasterio` method now skips rasterio's ``crs`` + attribute if its value is ``None`` (:issue:`1520`). + By `Leevi Annala `_. + +- Fix :py:func:`xarray.DataArray.to_netcdf` to return bytes when no path is + provided (:issue:`1410`). + By `Joe Hamman `_. + +- Fix :py:func:`xarray.save_mfdataset` to properly raise an informative error + when objects other than ``Dataset`` are provided (:issue:`1555`). + By `Joe Hamman `_. + +- :py:func:`xarray.Dataset.copy` would not preserve the encoding property + (:issue:`1586`). + By `Guido Imperiale `_. + +- :py:func:`xarray.concat` would eagerly load dask variables into memory if + the first argument was a numpy variable (:issue:`1588`). + By `Guido Imperiale `_. + +- Fix bug in :py:meth:`~xarray.Dataset.to_netcdf` when writing in append mode + (:issue:`1215`). + By `Joe Hamman `_. + +- Fix ``netCDF4`` backend to properly roundtrip the ``shuffle`` encoding option + (:issue:`1606`). + By `Joe Hamman `_. + +- Fix bug when using ``pytest`` class decorators to skipping certain unittests. + The previous behavior unintentionally causing additional tests to be skipped + (:issue:`1531`). By `Joe Hamman `_. + +- Fix pynio backend for upcoming release of pynio with Python 3 support + (:issue:`1611`). By `Ben Hillman `_. + +- Fix ``seaborn`` import warning for Seaborn versions 0.8 and newer when the + ``apionly`` module was deprecated. + (:issue:`1633`). By `Joe Hamman `_. + +- Fix COMPAT: MultiIndex checking is fragile + (:issue:`1833`). By `Florian Pinault `_. + +- Fix ``rasterio`` backend for Rasterio versions 1.0alpha10 and newer. + (:issue:`1641`). By `Chris Holden `_. + +Bug fixes after rc1 +~~~~~~~~~~~~~~~~~~~ + +- Suppress warning in IPython autocompletion, related to the deprecation + of ``.T`` attributes (:issue:`1675`). + By `Keisuke Fujii `_. + +- Fix a bug in lazily-indexing netCDF array. (:issue:`1688`) + By `Keisuke Fujii `_. + +- (Internal bug) MemoryCachedArray now supports the orthogonal indexing. + Also made some internal cleanups around array wrappers (:issue:`1429`). + By `Keisuke Fujii `_. + +- (Internal bug) MemoryCachedArray now always wraps ``np.ndarray`` by + ``NumpyIndexingAdapter``. (:issue:`1694`) + By `Keisuke Fujii `_. + +- Fix importing xarray when running Python with ``-OO`` (:issue:`1706`). + By `Stephan Hoyer `_. + +- Saving a netCDF file with a coordinates with a spaces in its names now raises + an appropriate warning (:issue:`1689`). + By `Stephan Hoyer `_. + +- Fix two bugs that were preventing dask arrays from being specified as + coordinates in the DataArray constructor (:issue:`1684`). + By `Joe Hamman `_. + +- Fixed ``apply_ufunc`` with ``dask='parallelized'`` for scalar arguments + (:issue:`1697`). + By `Stephan Hoyer `_. + +- Fix "Chunksize cannot exceed dimension size" error when writing netCDF4 files + loaded from disk (:issue:`1225`). + By `Stephan Hoyer `_. + +- Validate the shape of coordinates with names matching dimensions in the + DataArray constructor (:issue:`1709`). + By `Stephan Hoyer `_. + +- Raise ``NotImplementedError`` when attempting to save a MultiIndex to a + netCDF file (:issue:`1547`). + By `Stephan Hoyer `_. + +- Remove netCDF dependency from rasterio backend tests. + By `Matti Eskelinen `_ + +Bug fixes after rc2 +~~~~~~~~~~~~~~~~~~~ + +- Fixed unexpected behavior in ``Dataset.set_index()`` and + ``DataArray.set_index()`` introduced by pandas 0.21.0. Setting a new + index with a single variable resulted in 1-level + ``pandas.MultiIndex`` instead of a simple ``pandas.Index`` + (:issue:`1722`). By `Benoit Bovy `_. + +- Fixed unexpected memory loading of backend arrays after ``print``. + (:issue:`1720`). By `Keisuke Fujii `_. + +.. _whats-new.0.9.6: + +v0.9.6 (8 June 2017) +-------------------- + +This release includes a number of backwards compatible enhancements and bug +fixes. + +Enhancements +~~~~~~~~~~~~ + +- New :py:meth:`~xarray.Dataset.sortby` method to ``Dataset`` and ``DataArray`` + that enable sorting along dimensions (:issue:`967`). + See :ref:`the docs ` for examples. + By `Chun-Wei Yuan `_ and + `Kyle Heuton `_. + +- Add ``.dt`` accessor to DataArrays for computing datetime-like properties + for the values they contain, similar to ``pandas.Series`` (:issue:`358`). + By `Daniel Rothenberg `_. + +- Renamed internal dask arrays created by ``open_dataset`` to match new dask + conventions (:issue:`1343`). + By `Ryan Abernathey `_. + +- :py:meth:`~xarray.as_variable` is now part of the public API (:issue:`1303`). + By `Benoit Bovy `_. + +- :py:func:`~xarray.align` now supports ``join='exact'``, which raises + an error instead of aligning when indexes to be aligned are not equal. + By `Stephan Hoyer `_. + +- New function :py:func:`~xarray.open_rasterio` for opening raster files with + the `rasterio `_ library. + See :ref:`the docs ` for details. + By `Joe Hamman `_, + `Nic Wayand `_ and + `Fabien Maussion `_ + +Bug fixes +~~~~~~~~~ + +- Fix error from repeated indexing of datasets loaded from disk (:issue:`1374`). + By `Stephan Hoyer `_. + +- Fix a bug where ``.isel_points`` wrongly assigns unselected coordinate to + ``data_vars``. + By `Keisuke Fujii `_. + +- Tutorial datasets are now checked against a reference MD5 sum to confirm + successful download (:issue:`1392`). By `Matthew Gidden + `_. + +- ``DataArray.chunk()`` now accepts dask specific kwargs like + ``Dataset.chunk()`` does. By `Fabien Maussion `_. + +- Support for ``engine='pydap'`` with recent releases of Pydap (3.2.2+), + including on Python 3 (:issue:`1174`). + +Documentation +~~~~~~~~~~~~~ + +- A new `gallery `_ + allows to add interactive examples to the documentation. + By `Fabien Maussion `_. + +Testing +~~~~~~~ + +- Fix test suite failure caused by changes to ``pandas.cut`` function + (:issue:`1386`). + By `Ryan Abernathey `_. + +- Enhanced tests suite by use of ``@network`` decorator, which is + controlled via ``--run-network-tests`` command line argument + to ``py.test`` (:issue:`1393`). + By `Matthew Gidden `_. + +.. _whats-new.0.9.5: + +v0.9.5 (17 April, 2017) +----------------------- + +Remove an inadvertently introduced print statement. + +.. _whats-new.0.9.3: + +v0.9.3 (16 April, 2017) +----------------------- + +This minor release includes bug-fixes and backwards compatible enhancements. + +Enhancements +~~~~~~~~~~~~ + +- New :py:meth:`~xarray.DataArray.persist` method to Datasets and DataArrays to + enable persisting data in distributed memory when using Dask (:issue:`1344`). + By `Matthew Rocklin `_. + +- New :py:meth:`~xarray.DataArray.expand_dims` method for ``DataArray`` and + ``Dataset`` (:issue:`1326`). + By `Keisuke Fujii `_. + +Bug fixes +~~~~~~~~~ + +- Fix ``.where()`` with ``drop=True`` when arguments do not have indexes + (:issue:`1350`). This bug, introduced in v0.9, resulted in xarray producing + incorrect results in some cases. + By `Stephan Hoyer `_. + +- Fixed writing to file-like objects with :py:meth:`~xarray.Dataset.to_netcdf` + (:issue:`1320`). + `Stephan Hoyer `_. + +- Fixed explicitly setting ``engine='scipy'`` with ``to_netcdf`` when not + providing a path (:issue:`1321`). + `Stephan Hoyer `_. + +- Fixed open_dataarray does not pass properly its parameters to open_dataset + (:issue:`1359`). + `Stephan Hoyer `_. + +- Ensure test suite works when runs from an installed version of xarray + (:issue:`1336`). Use ``@pytest.mark.slow`` instead of a custom flag to mark + slow tests. + By `Stephan Hoyer `_ + +.. _whats-new.0.9.2: + +v0.9.2 (2 April 2017) +--------------------- + +The minor release includes bug-fixes and backwards compatible enhancements. + +Enhancements +~~~~~~~~~~~~ + +- ``rolling`` on Dataset is now supported (:issue:`859`). + +- ``.rolling()`` on Dataset is now supported (:issue:`859`). + By `Keisuke Fujii `_. + +- When bottleneck version 1.1 or later is installed, use bottleneck for rolling + ``var``, ``argmin``, ``argmax``, and ``rank`` computations. Also, rolling + median now accepts a ``min_periods`` argument (:issue:`1276`). + By `Joe Hamman `_. + +- When ``.plot()`` is called on a 2D DataArray and only one dimension is + specified with ``x=`` or ``y=``, the other dimension is now guessed + (:issue:`1291`). + By `Vincent Noel `_. + +- Added new method :py:meth:`~Dataset.assign_attrs` to ``DataArray`` and + ``Dataset``, a chained-method compatible implementation of the + ``dict.update`` method on attrs (:issue:`1281`). + By `Henry S. Harrison `_. + +- Added new ``autoclose=True`` argument to + :py:func:`~xarray.open_mfdataset` to explicitly close opened files when not in + use to prevent occurrence of an OS Error related to too many open files + (:issue:`1198`). + Note, the default is ``autoclose=False``, which is consistent with + previous xarray behavior. + By `Phillip J. Wolfram `_. + +- The ``repr()`` of ``Dataset`` and ``DataArray`` attributes uses a similar + format to coordinates and variables, with vertically aligned entries + truncated to fit on a single line (:issue:`1319`). Hopefully this will stop + people writing ``data.attrs = {}`` and discarding metadata in notebooks for + the sake of cleaner output. The full metadata is still available as + ``data.attrs``. + By `Zac Hatfield-Dodds `_. + +- Enhanced tests suite by use of ``@slow`` and ``@flaky`` decorators, which are + controlled via ``--run-flaky`` and ``--skip-slow`` command line arguments + to ``py.test`` (:issue:`1336`). + By `Stephan Hoyer `_ and + `Phillip J. Wolfram `_. + +- New aggregation on rolling objects :py:meth:`~core.rolling.DataArrayRolling.count` + which providing a rolling count of valid values (:issue:`1138`). + +Bug fixes +~~~~~~~~~ +- Rolling operations now keep preserve original dimension order (:issue:`1125`). + By `Keisuke Fujii `_. + +- Fixed ``sel`` with ``method='nearest'`` on Python 2.7 and 64-bit Windows + (:issue:`1140`). + `Stephan Hoyer `_. + +- Fixed ``where`` with ``drop='True'`` for empty masks (:issue:`1341`). + By `Stephan Hoyer `_ and + `Phillip J. Wolfram `_. + +.. _whats-new.0.9.1: + +v0.9.1 (30 January 2017) +------------------------ + +Renamed the "Unindexed dimensions" section in the ``Dataset`` and +``DataArray`` repr (added in v0.9.0) to "Dimensions without coordinates" +(:issue:`1199`). + +.. _whats-new.0.9.0: + +v0.9.0 (25 January 2017) +------------------------ + +This major release includes five months worth of enhancements and bug fixes from +24 contributors, including some significant changes that are not fully backwards +compatible. Highlights include: + +- Coordinates are now *optional* in the xarray data model, even for dimensions. +- Changes to caching, lazy loading and pickling to improve xarray's experience + for parallel computing. +- Improvements for accessing and manipulating ``pandas.MultiIndex`` levels. +- Many new methods and functions, including + :py:meth:`~DataArray.quantile`, + :py:meth:`~DataArray.cumsum`, + :py:meth:`~DataArray.cumprod` + :py:attr:`~DataArray.combine_first` + :py:meth:`~DataArray.set_index`, + :py:meth:`~DataArray.reset_index`, + :py:meth:`~DataArray.reorder_levels`, + :py:func:`~xarray.full_like`, + :py:func:`~xarray.zeros_like`, + :py:func:`~xarray.ones_like` + :py:func:`~xarray.open_dataarray`, + :py:meth:`~DataArray.compute`, + :py:meth:`Dataset.info`, + :py:func:`testing.assert_equal`, + :py:func:`testing.assert_identical`, and + :py:func:`testing.assert_allclose`. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Index coordinates for each dimensions are now optional, and no longer created + by default :issue:`1017`. You can identify such dimensions without coordinates + by their appearance in list of "Dimensions without coordinates" in the + ``Dataset`` or ``DataArray`` repr: + + .. ipython:: + :verbatim: + + In [1]: xr.Dataset({"foo": (("x", "y"), [[1, 2]])}) + Out[1]: + + Dimensions: (x: 1, y: 2) + Dimensions without coordinates: x, y + Data variables: + foo (x, y) int64 1 2 + + This has a number of implications: + + - :py:func:`~align` and :py:meth:`~Dataset.reindex` can now error, if + dimensions labels are missing and dimensions have different sizes. + - Because pandas does not support missing indexes, methods such as + ``to_dataframe``/``from_dataframe`` and ``stack``/``unstack`` no longer + roundtrip faithfully on all inputs. Use :py:meth:`~Dataset.reset_index` to + remove undesired indexes. + - ``Dataset.__delitem__`` and :py:meth:`~Dataset.drop` no longer delete/drop + variables that have dimensions matching a deleted/dropped variable. + - ``DataArray.coords.__delitem__`` is now allowed on variables matching + dimension names. + - ``.sel`` and ``.loc`` now handle indexing along a dimension without + coordinate labels by doing integer based indexing. See + :ref:`indexing.missing_coordinates` for an example. + - :py:attr:`~Dataset.indexes` is no longer guaranteed to include all + dimensions names as keys. The new method :py:meth:`~Dataset.get_index` has + been added to get an index for a dimension guaranteed, falling back to + produce a default ``RangeIndex`` if necessary. + +- The default behavior of ``merge`` is now ``compat='no_conflicts'``, so some + merges will now succeed in cases that previously raised + ``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the + previous default. See :ref:`combining.no_conflicts` for more details. + +- Reading :py:attr:`~DataArray.values` no longer always caches values in a NumPy + array :issue:`1128`. Caching of ``.values`` on variables read from netCDF + files on disk is still the default when :py:func:`open_dataset` is called with + ``cache=True``. + By `Guido Imperiale `_ and + `Stephan Hoyer `_. +- Pickling a ``Dataset`` or ``DataArray`` linked to a file on disk no longer + caches its values into memory before pickling (:issue:`1128`). Instead, pickle + stores file paths and restores objects by reopening file references. This + enables preliminary, experimental use of xarray for opening files with + `dask.distributed `_. + By `Stephan Hoyer `_. +- Coordinates used to index a dimension are now loaded eagerly into + :py:class:`pandas.Index` objects, instead of loading the values lazily. + By `Guido Imperiale `_. +- Automatic levels for 2d plots are now guaranteed to land on ``vmin`` and + ``vmax`` when these kwargs are explicitly provided (:issue:`1191`). The + automated level selection logic also slightly changed. + By `Fabien Maussion `_. + +- ``DataArray.rename()`` behavior changed to strictly change the ``DataArray.name`` + if called with string argument, or strictly change coordinate names if called with + dict-like argument. + By `Markus Gonser `_. + +- By default ``to_netcdf()`` add a ``_FillValue = NaN`` attributes to float types. + By `Frederic Laliberte `_. + +- ``repr`` on ``DataArray`` objects uses an shortened display for NumPy array + data that is less likely to overflow onto multiple pages (:issue:`1207`). + By `Stephan Hoyer `_. + +- xarray no longer supports python 3.3, versions of dask prior to v0.9.0, + or versions of bottleneck prior to v1.0. + +Deprecations +~~~~~~~~~~~~ + +- Renamed the ``Coordinate`` class from xarray's low level API to + :py:class:`~xarray.IndexVariable`. ``Variable.to_variable`` and + ``Variable.to_coord`` have been renamed to + :py:meth:`~xarray.Variable.to_base_variable` and + :py:meth:`~xarray.Variable.to_index_variable`. +- Deprecated supplying ``coords`` as a dictionary to the ``DataArray`` + constructor without also supplying an explicit ``dims`` argument. The old + behavior encouraged relying on the iteration order of dictionaries, which is + a bad practice (:issue:`727`). +- Removed a number of methods deprecated since v0.7.0 or earlier: + ``load_data``, ``vars``, ``drop_vars``, ``dump``, ``dumps`` and the + ``variables`` keyword argument to ``Dataset``. +- Removed the dummy module that enabled ``import xray``. + +Enhancements +~~~~~~~~~~~~ + +- Added new method :py:meth:`~DataArray.combine_first` to ``DataArray`` and + ``Dataset``, based on the pandas method of the same name (see :ref:`combine`). + By `Chun-Wei Yuan `_. + +- Added the ability to change default automatic alignment (arithmetic_join="inner") + for binary operations via :py:func:`~xarray.set_options()` + (see :ref:`math automatic alignment`). + By `Chun-Wei Yuan `_. + +- Add checking of ``attr`` names and values when saving to netCDF, raising useful + error messages if they are invalid. (:issue:`911`). + By `Robin Wilson `_. +- Added ability to save ``DataArray`` objects directly to netCDF files using + :py:meth:`~xarray.DataArray.to_netcdf`, and to load directly from netCDF files + using :py:func:`~xarray.open_dataarray` (:issue:`915`). These remove the need + to convert a ``DataArray`` to a ``Dataset`` before saving as a netCDF file, + and deals with names to ensure a perfect 'roundtrip' capability. + By `Robin Wilson `_. +- Multi-index levels are now accessible as "virtual" coordinate variables, + e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index + (see :ref:`coordinates`). ``sel`` also accepts providing multi-index levels + as keyword arguments, e.g., ``ds.sel(time='2000-01')`` + (see :ref:`multi-level indexing`). + By `Benoit Bovy `_. +- Added ``set_index``, ``reset_index`` and ``reorder_levels`` methods to + easily create and manipulate (multi-)indexes (see :ref:`reshape.set_index`). + By `Benoit Bovy `_. +- Added the ``compat`` option ``'no_conflicts'`` to ``merge``, allowing the + combination of xarray objects with disjoint (:issue:`742`) or + overlapping (:issue:`835`) coordinates as long as all present data agrees. + By `Johnnie Gray `_. See + :ref:`combining.no_conflicts` for more details. +- It is now possible to set ``concat_dim=None`` explicitly in + :py:func:`~xarray.open_mfdataset` to disable inferring a dimension along + which to concatenate. + By `Stephan Hoyer `_. +- Added methods :py:meth:`DataArray.compute`, :py:meth:`Dataset.compute`, and + :py:meth:`Variable.compute` as a non-mutating alternative to + :py:meth:`~DataArray.load`. + By `Guido Imperiale `_. +- Adds DataArray and Dataset methods :py:meth:`~xarray.DataArray.cumsum` and + :py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram + `_. + +- New properties :py:attr:`Dataset.sizes` and :py:attr:`DataArray.sizes` for + providing consistent access to dimension length on both ``Dataset`` and + ``DataArray`` (:issue:`921`). + By `Stephan Hoyer `_. +- New keyword argument ``drop=True`` for :py:meth:`~DataArray.sel`, + :py:meth:`~DataArray.isel` and :py:meth:`~DataArray.squeeze` for dropping + scalar coordinates that arise from indexing. + ``DataArray`` (:issue:`242`). + By `Stephan Hoyer `_. + +- New top-level functions :py:func:`~xarray.full_like`, + :py:func:`~xarray.zeros_like`, and :py:func:`~xarray.ones_like` + By `Guido Imperiale `_. +- Overriding a preexisting attribute with + :py:func:`~xarray.register_dataset_accessor` or + :py:func:`~xarray.register_dataarray_accessor` now issues a warning instead of + raising an error (:issue:`1082`). + By `Stephan Hoyer `_. +- Options for axes sharing between subplots are exposed to + :py:class:`~xarray.plot.FacetGrid` and :py:func:`~xarray.plot.plot`, so axes + sharing can be disabled for polar plots. + By `Bas Hoonhout `_. +- New utility functions :py:func:`~xarray.testing.assert_equal`, + :py:func:`~xarray.testing.assert_identical`, and + :py:func:`~xarray.testing.assert_allclose` for asserting relationships + between xarray objects, designed for use in a pytest test suite. +- ``figsize``, ``size`` and ``aspect`` plot arguments are now supported for all + plots (:issue:`897`). See :ref:`plotting.figsize` for more details. + By `Stephan Hoyer `_ and + `Fabien Maussion `_. +- New :py:meth:`~Dataset.info` method to summarize ``Dataset`` variables + and attributes. The method prints to a buffer (e.g. ``stdout``) with output + similar to what the command line utility ``ncdump -h`` produces (:issue:`1150`). + By `Joe Hamman `_. +- Added the ability write unlimited netCDF dimensions with the ``scipy`` and + ``netcdf4`` backends via the new ``xray.Dataset.encoding`` attribute + or via the ``unlimited_dims`` argument to ``xray.Dataset.to_netcdf``. + By `Joe Hamman `_. +- New :py:meth:`~DataArray.quantile` method to calculate quantiles from + DataArray objects (:issue:`1187`). + By `Joe Hamman `_. + + +Bug fixes +~~~~~~~~~ +- ``groupby_bins`` now restores empty bins by default (:issue:`1019`). + By `Ryan Abernathey `_. + +- Fix issues for dates outside the valid range of pandas timestamps + (:issue:`975`). By `Mathias Hauser `_. + +- Unstacking produced flipped array after stacking decreasing coordinate values + (:issue:`980`). + By `Stephan Hoyer `_. + +- Setting ``dtype`` via the ``encoding`` parameter of ``to_netcdf`` failed if + the encoded dtype was the same as the dtype of the original array + (:issue:`873`). + By `Stephan Hoyer `_. + +- Fix issues with variables where both attributes ``_FillValue`` and + ``missing_value`` are set to ``NaN`` (:issue:`997`). + By `Marco Zühlke `_. + +- ``.where()`` and ``.fillna()`` now preserve attributes (:issue:`1009`). + By `Fabien Maussion `_. + +- Applying :py:func:`broadcast()` to an xarray object based on the dask backend + won't accidentally convert the array from dask to numpy anymore (:issue:`978`). + By `Guido Imperiale `_. + +- ``Dataset.concat()`` now preserves variables order (:issue:`1027`). + By `Fabien Maussion `_. + +- Fixed an issue with pcolormesh (:issue:`781`). A new + ``infer_intervals`` keyword gives control on whether the cell intervals + should be computed or not. + By `Fabien Maussion `_. + +- Grouping over an dimension with non-unique values with ``groupby`` gives + correct groups. + By `Stephan Hoyer `_. + +- Fixed accessing coordinate variables with non-string names from ``.coords``. + By `Stephan Hoyer `_. + +- :py:meth:`~xarray.DataArray.rename` now simultaneously renames the array and + any coordinate with the same name, when supplied via a :py:class:`dict` + (:issue:`1116`). + By `Yves Delley `_. + +- Fixed sub-optimal performance in certain operations with object arrays (:issue:`1121`). + By `Yves Delley `_. + +- Fix ``.groupby(group)`` when ``group`` has datetime dtype (:issue:`1132`). + By `Jonas Sølvsteen `_. + +- Fixed a bug with facetgrid (the ``norm`` keyword was ignored, :issue:`1159`). + By `Fabien Maussion `_. + +- Resolved a concurrency bug that could cause Python to crash when + simultaneously reading and writing netCDF4 files with dask (:issue:`1172`). + By `Stephan Hoyer `_. + +- Fix to make ``.copy()`` actually copy dask arrays, which will be relevant for + future releases of dask in which dask arrays will be mutable (:issue:`1180`). + By `Stephan Hoyer `_. + +- Fix opening NetCDF files with multi-dimensional time variables + (:issue:`1229`). + By `Stephan Hoyer `_. + +Performance improvements +~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``xarray.Dataset.isel_points`` and ``xarray.Dataset.sel_points`` now + use vectorised indexing in numpy and dask (:issue:`1161`), which can + result in several orders of magnitude speedup. + By `Jonathan Chambers `_. + +.. _whats-new.0.8.2: + +v0.8.2 (18 August 2016) +----------------------- + +This release includes a number of bug fixes and minor enhancements. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:func:`~xarray.broadcast` and :py:func:`~xarray.concat` now auto-align + inputs, using ``join=outer``. Previously, these functions raised + ``ValueError`` for non-aligned inputs. + By `Guido Imperiale `_. + +Enhancements +~~~~~~~~~~~~ + +- New documentation on :ref:`panel transition`. By + `Maximilian Roos `_. +- New ``Dataset`` and ``DataArray`` methods :py:meth:`~xarray.Dataset.to_dict` + and :py:meth:`~xarray.Dataset.from_dict` to allow easy conversion between + dictionaries and xarray objects (:issue:`432`). See + :ref:`dictionary IO` for more details. + By `Julia Signell `_. +- Added ``exclude`` and ``indexes`` optional parameters to :py:func:`~xarray.align`, + and ``exclude`` optional parameter to :py:func:`~xarray.broadcast`. + By `Guido Imperiale `_. +- Better error message when assigning variables without dimensions + (:issue:`971`). By `Stephan Hoyer `_. +- Better error message when reindex/align fails due to duplicate index values + (:issue:`956`). By `Stephan Hoyer `_. + +Bug fixes +~~~~~~~~~ + +- Ensure xarray works with h5netcdf v0.3.0 for arrays with ``dtype=str`` + (:issue:`953`). By `Stephan Hoyer `_. +- ``Dataset.__dir__()`` (i.e. the method python calls to get autocomplete + options) failed if one of the dataset's keys was not a string (:issue:`852`). + By `Maximilian Roos `_. +- ``Dataset`` constructor can now take arbitrary objects as values + (:issue:`647`). By `Maximilian Roos `_. +- Clarified ``copy`` argument for :py:meth:`~xarray.DataArray.reindex` and + :py:func:`~xarray.align`, which now consistently always return new xarray + objects (:issue:`927`). +- Fix ``open_mfdataset`` with ``engine='pynio'`` (:issue:`936`). + By `Stephan Hoyer `_. +- ``groupby_bins`` sorted bin labels as strings (:issue:`952`). + By `Stephan Hoyer `_. +- Fix bug introduced by v0.8.0 that broke assignment to datasets when both the + left and right side have the same non-unique index values (:issue:`956`). + +.. _whats-new.0.8.1: + +v0.8.1 (5 August 2016) +---------------------- + +Bug fixes +~~~~~~~~~ + +- Fix bug in v0.8.0 that broke assignment to Datasets with non-unique + indexes (:issue:`943`). By `Stephan Hoyer `_. + +.. _whats-new.0.8.0: + +v0.8.0 (2 August 2016) +---------------------- + +This release includes four months of new features and bug fixes, including +several breaking changes. + +.. _v0.8.0.breaking: + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Dropped support for Python 2.6 (:issue:`855`). +- Indexing on multi-index now drop levels, which is consistent with pandas. + It also changes the name of the dimension / coordinate when the multi-index is + reduced to a single index (:issue:`802`). +- Contour plots no longer add a colorbar per default (:issue:`866`). Filled + contour plots are unchanged. +- ``DataArray.values`` and ``.data`` now always returns an NumPy array-like + object, even for 0-dimensional arrays with object dtype (:issue:`867`). + Previously, ``.values`` returned native Python objects in such cases. To + convert the values of scalar arrays to Python objects, use the ``.item()`` + method. + +Enhancements +~~~~~~~~~~~~ + +- Groupby operations now support grouping over multidimensional variables. A new + method called :py:meth:`~xarray.Dataset.groupby_bins` has also been added to + allow users to specify bins for grouping. The new features are described in + :ref:`groupby.multidim` and :ref:`/examples/multidimensional-coords.ipynb`. + By `Ryan Abernathey `_. + +- DataArray and Dataset method :py:meth:`where` now supports a ``drop=True`` + option that clips coordinate elements that are fully masked. By + `Phillip J. Wolfram `_. + +- New top level :py:func:`merge` function allows for combining variables from + any number of ``Dataset`` and/or ``DataArray`` variables. See :ref:`merge` + for more details. By `Stephan Hoyer `_. + +- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now support the + ``keep_attrs=False`` option that determines whether variable and dataset + attributes are retained in the resampled object. By + `Jeremy McGibbon `_. + +- Better multi-index support in :py:meth:`DataArray.sel`, + :py:meth:`DataArray.loc`, :py:meth:`Dataset.sel` and + :py:meth:`Dataset.loc`, which now behave more closely to pandas and + which also accept dictionaries for indexing based on given level names + and labels (see :ref:`multi-level indexing`). + By `Benoit Bovy `_. + +- New (experimental) decorators :py:func:`~xarray.register_dataset_accessor` and + :py:func:`~xarray.register_dataarray_accessor` for registering custom xarray + extensions without subclassing. They are described in the new documentation + page on :ref:`internals`. By `Stephan Hoyer `_. + +- Round trip boolean datatypes. Previously, writing boolean datatypes to netCDF + formats would raise an error since netCDF does not have a `bool` datatype. + This feature reads/writes a `dtype` attribute to boolean variables in netCDF + files. By `Joe Hamman `_. + +- 2D plotting methods now have two new keywords (`cbar_ax` and `cbar_kwargs`), + allowing more control on the colorbar (:issue:`872`). + By `Fabien Maussion `_. + +- New Dataset method :py:meth:`Dataset.filter_by_attrs`, akin to + ``netCDF4.Dataset.get_variables_by_attributes``, to easily filter + data variables using its attributes. + `Filipe Fernandes `_. + +Bug fixes +~~~~~~~~~ + +- Attributes were being retained by default for some resampling + operations when they should not. With the ``keep_attrs=False`` option, they + will no longer be retained by default. This may be backwards-incompatible + with some scripts, but the attributes may be kept by adding the + ``keep_attrs=True`` option. By + `Jeremy McGibbon `_. + +- Concatenating xarray objects along an axis with a MultiIndex or PeriodIndex + preserves the nature of the index (:issue:`875`). By + `Stephan Hoyer `_. + +- Fixed bug in arithmetic operations on DataArray objects whose dimensions + are numpy structured arrays or recarrays :issue:`861`, :issue:`837`. By + `Maciek Swat `_. + +- ``decode_cf_timedelta`` now accepts arrays with ``ndim`` >1 (:issue:`842`). + This fixes issue :issue:`665`. + `Filipe Fernandes `_. + +- Fix a bug where `xarray.ufuncs` that take two arguments would incorrectly + use to numpy functions instead of dask.array functions (:issue:`876`). By + `Stephan Hoyer `_. + +- Support for pickling functions from ``xarray.ufuncs`` (:issue:`901`). By + `Stephan Hoyer `_. + +- ``Variable.copy(deep=True)`` no longer converts MultiIndex into a base Index + (:issue:`769`). By `Benoit Bovy `_. + +- Fixes for groupby on dimensions with a multi-index (:issue:`867`). By + `Stephan Hoyer `_. + +- Fix printing datasets with unicode attributes on Python 2 (:issue:`892`). By + `Stephan Hoyer `_. + +- Fixed incorrect test for dask version (:issue:`891`). By + `Stephan Hoyer `_. + +- Fixed `dim` argument for `isel_points`/`sel_points` when a `pandas.Index` is + passed. By `Stephan Hoyer `_. + +- :py:func:`~xarray.plot.contour` now plots the correct number of contours + (:issue:`866`). By `Fabien Maussion `_. + +.. _whats-new.0.7.2: + +v0.7.2 (13 March 2016) +---------------------- + +This release includes two new, entirely backwards compatible features and +several bug fixes. + +Enhancements +~~~~~~~~~~~~ + +- New DataArray method :py:meth:`DataArray.dot` for calculating the dot + product of two DataArrays along shared dimensions. By + `Dean Pospisil `_. + +- Rolling window operations on DataArray objects are now supported via a new + :py:meth:`DataArray.rolling` method. For example: + + .. ipython:: + :verbatim: + + In [1]: import xarray as xr + ...: import numpy as np + + In [2]: arr = xr.DataArray(np.arange(0, 7.5, 0.5).reshape(3, 5), dims=("x", "y")) + + In [3]: arr + Out[3]: + + array([[ 0. , 0.5, 1. , 1.5, 2. ], + [ 2.5, 3. , 3.5, 4. , 4.5], + [ 5. , 5.5, 6. , 6.5, 7. ]]) + Coordinates: + * x (x) int64 0 1 2 + * y (y) int64 0 1 2 3 4 + + In [4]: arr.rolling(y=3, min_periods=2).mean() + Out[4]: + + array([[ nan, 0.25, 0.5 , 1. , 1.5 ], + [ nan, 2.75, 3. , 3.5 , 4. ], + [ nan, 5.25, 5.5 , 6. , 6.5 ]]) + Coordinates: + * x (x) int64 0 1 2 + * y (y) int64 0 1 2 3 4 + + See :ref:`comput.rolling` for more details. By + `Joe Hamman `_. + +Bug fixes +~~~~~~~~~ + +- Fixed an issue where plots using pcolormesh and Cartopy axes were being distorted + by the inference of the axis interval breaks. This change chooses not to modify + the coordinate variables when the axes have the attribute ``projection``, allowing + Cartopy to handle the extent of pcolormesh plots (:issue:`781`). By + `Joe Hamman `_. + +- 2D plots now better handle additional coordinates which are not ``DataArray`` + dimensions (:issue:`788`). By `Fabien Maussion `_. + + +.. _whats-new.0.7.1: + +v0.7.1 (16 February 2016) +------------------------- + +This is a bug fix release that includes two small, backwards compatible enhancements. +We recommend that all users upgrade. + +Enhancements +~~~~~~~~~~~~ + +- Numerical operations now return empty objects on no overlapping labels rather + than raising ``ValueError`` (:issue:`739`). +- :py:class:`~pandas.Series` is now supported as valid input to the ``Dataset`` + constructor (:issue:`740`). + +Bug fixes +~~~~~~~~~ + +- Restore checks for shape consistency between data and coordinates in the + DataArray constructor (:issue:`758`). +- Single dimension variables no longer transpose as part of a broader + ``.transpose``. This behavior was causing ``pandas.PeriodIndex`` dimensions + to lose their type (:issue:`749`) +- :py:class:`~xarray.Dataset` labels remain as their native type on ``.to_dataset``. + Previously they were coerced to strings (:issue:`745`) +- Fixed a bug where replacing a ``DataArray`` index coordinate would improperly + align the coordinate (:issue:`725`). +- ``DataArray.reindex_like`` now maintains the dtype of complex numbers when + reindexing leads to NaN values (:issue:`738`). +- ``Dataset.rename`` and ``DataArray.rename`` support the old and new names + being the same (:issue:`724`). +- Fix :py:meth:`~xarray.Dataset.from_dataframe` for DataFrames with Categorical + column and a MultiIndex index (:issue:`737`). +- Fixes to ensure xarray works properly after the upcoming pandas v0.18 and + NumPy v1.11 releases. + +Acknowledgments +~~~~~~~~~~~~~~~ + +The following individuals contributed to this release: + +- Edward Richards +- Maximilian Roos +- Rafael Guedes +- Spencer Hill +- Stephan Hoyer + +.. _whats-new.0.7.0: + +v0.7.0 (21 January 2016) +------------------------ + +This major release includes redesign of :py:class:`~xarray.DataArray` +internals, as well as new methods for reshaping, rolling and shifting +data. It includes preliminary support for :py:class:`pandas.MultiIndex`, +as well as a number of other features and bug fixes, several of which +offer improved compatibility with pandas. + +New name +~~~~~~~~ + +The project formerly known as "xray" is now "xarray", pronounced "x-array"! +This avoids a namespace conflict with the entire field of x-ray science. Renaming +our project seemed like the right thing to do, especially because some +scientists who work with actual x-rays are interested in using this project in +their work. Thanks for your understanding and patience in this transition. You +can now find our documentation and code repository at new URLs: + +- https://docs.xarray.dev +- https://github.com/pydata/xarray/ + +To ease the transition, we have simultaneously released v0.7.0 of both +``xray`` and ``xarray`` on the Python Package Index. These packages are +identical. For now, ``import xray`` still works, except it issues a +deprecation warning. This will be the last xray release. Going forward, we +recommend switching your import statements to ``import xarray as xr``. + +.. _v0.7.0.breaking: + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The internal data model used by ``xray.DataArray`` has been + rewritten to fix several outstanding issues (:issue:`367`, :issue:`634`, + `this stackoverflow report`_). Internally, ``DataArray`` is now implemented + in terms of ``._variable`` and ``._coords`` attributes instead of holding + variables in a ``Dataset`` object. + + This refactor ensures that if a DataArray has the + same name as one of its coordinates, the array and the coordinate no longer + share the same data. + + In practice, this means that creating a DataArray with the same ``name`` as + one of its dimensions no longer automatically uses that array to label the + corresponding coordinate. You will now need to provide coordinate labels + explicitly. Here's the old behavior: + + .. ipython:: + :verbatim: + + In [2]: xray.DataArray([4, 5, 6], dims="x", name="x") + Out[2]: + + array([4, 5, 6]) + Coordinates: + * x (x) int64 4 5 6 + + and the new behavior (compare the values of the ``x`` coordinate): + + .. ipython:: + :verbatim: + + In [2]: xray.DataArray([4, 5, 6], dims="x", name="x") + Out[2]: + + array([4, 5, 6]) + Coordinates: + * x (x) int64 0 1 2 + +- It is no longer possible to convert a DataArray to a Dataset with + ``xray.DataArray.to_dataset`` if it is unnamed. This will now + raise ``ValueError``. If the array is unnamed, you need to supply the + ``name`` argument. + +.. _this stackoverflow report: http://stackoverflow.com/questions/33158558/python-xray-extract-first-and-last-time-value-within-each-month-of-a-timeseries + +Enhancements +~~~~~~~~~~~~ + +- Basic support for :py:class:`~pandas.MultiIndex` coordinates on xray objects, including + indexing, :py:meth:`~DataArray.stack` and :py:meth:`~DataArray.unstack`: + + .. ipython:: + :verbatim: + + In [7]: df = pd.DataFrame({"foo": range(3), "x": ["a", "b", "b"], "y": [0, 0, 1]}) + + In [8]: s = df.set_index(["x", "y"])["foo"] + + In [12]: arr = xray.DataArray(s, dims="z") + + In [13]: arr + Out[13]: + + array([0, 1, 2]) + Coordinates: + * z (z) object ('a', 0) ('b', 0) ('b', 1) + + In [19]: arr.indexes["z"] + Out[19]: + MultiIndex(levels=[[u'a', u'b'], [0, 1]], + labels=[[0, 1, 1], [0, 0, 1]], + names=[u'x', u'y']) + + In [14]: arr.unstack("z") + Out[14]: + + array([[ 0., nan], + [ 1., 2.]]) + Coordinates: + * x (x) object 'a' 'b' + * y (y) int64 0 1 + + In [26]: arr.unstack("z").stack(z=("x", "y")) + Out[26]: + + array([ 0., nan, 1., 2.]) + Coordinates: + * z (z) object ('a', 0) ('a', 1) ('b', 0) ('b', 1) + + See :ref:`reshape.stack` for more details. + + .. warning:: + + xray's MultiIndex support is still experimental, and we have a long to- + do list of desired additions (:issue:`719`), including better display of + multi-index levels when printing a ``Dataset``, and support for saving + datasets with a MultiIndex to a netCDF file. User contributions in this + area would be greatly appreciated. + +- Support for reading GRIB, HDF4 and other file formats via PyNIO_. +- Better error message when a variable is supplied with the same name as + one of its dimensions. +- Plotting: more control on colormap parameters (:issue:`642`). ``vmin`` and + ``vmax`` will not be silently ignored anymore. Setting ``center=False`` + prevents automatic selection of a divergent colormap. +- New ``xray.Dataset.shift`` and ``xray.Dataset.roll`` methods + for shifting/rotating datasets or arrays along a dimension: + + .. ipython:: python + :okwarning: + + array = xray.DataArray([5, 6, 7, 8], dims="x") + array.shift(x=2) + array.roll(x=2) + + Notice that ``shift`` moves data independently of coordinates, but ``roll`` + moves both data and coordinates. +- Assigning a ``pandas`` object directly as a ``Dataset`` variable is now permitted. Its + index names correspond to the ``dims`` of the ``Dataset``, and its data is aligned. +- Passing a :py:class:`pandas.DataFrame` or ``pandas.Panel`` to a Dataset constructor + is now permitted. +- New function ``xray.broadcast`` for explicitly broadcasting + ``DataArray`` and ``Dataset`` objects against each other. For example: + + .. ipython:: python + + a = xray.DataArray([1, 2, 3], dims="x") + b = xray.DataArray([5, 6], dims="y") + a + b + a2, b2 = xray.broadcast(a, b) + a2 + b2 + +.. _PyNIO: https://www.pyngl.ucar.edu/Nio.shtml + +Bug fixes +~~~~~~~~~ + +- Fixes for several issues found on ``DataArray`` objects with the same name + as one of their coordinates (see :ref:`v0.7.0.breaking` for more details). +- ``DataArray.to_masked_array`` always returns masked array with mask being an + array (not a scalar value) (:issue:`684`) +- Allows for (imperfect) repr of Coords when underlying index is PeriodIndex (:issue:`645`). +- Fixes for several issues found on ``DataArray`` objects with the same name + as one of their coordinates (see :ref:`v0.7.0.breaking` for more details). +- Attempting to assign a ``Dataset`` or ``DataArray`` variable/attribute using + attribute-style syntax (e.g., ``ds.foo = 42``) now raises an error rather + than silently failing (:issue:`656`, :issue:`714`). +- You can now pass pandas objects with non-numpy dtypes (e.g., ``categorical`` + or ``datetime64`` with a timezone) into xray without an error + (:issue:`716`). + +Acknowledgments +~~~~~~~~~~~~~~~ + +The following individuals contributed to this release: + +- Antony Lee +- Fabien Maussion +- Joe Hamman +- Maximilian Roos +- Stephan Hoyer +- Takeshi Kanmae +- femtotrader + +v0.6.1 (21 October 2015) +------------------------ + +This release contains a number of bug and compatibility fixes, as well +as enhancements to plotting, indexing and writing files to disk. + +Note that the minimum required version of dask for use with xray is now +version 0.6. + +API Changes +~~~~~~~~~~~ + +- The handling of colormaps and discrete color lists for 2D plots in + ``xray.DataArray.plot`` was changed to provide more compatibility + with matplotlib's ``contour`` and ``contourf`` functions (:issue:`538`). + Now discrete lists of colors should be specified using ``colors`` keyword, + rather than ``cmap``. + +Enhancements +~~~~~~~~~~~~ + +- Faceted plotting through ``xray.plot.FacetGrid`` and the + ``xray.plot.plot`` method. See :ref:`plotting.faceting` for more details + and examples. +- ``xray.Dataset.sel`` and ``xray.Dataset.reindex`` now support + the ``tolerance`` argument for controlling nearest-neighbor selection + (:issue:`629`): + + .. ipython:: + :verbatim: + + In [5]: array = xray.DataArray([1, 2, 3], dims="x") + + In [6]: array.reindex(x=[0.9, 1.5], method="nearest", tolerance=0.2) + Out[6]: + + array([ 2., nan]) + Coordinates: + * x (x) float64 0.9 1.5 + + This feature requires pandas v0.17 or newer. +- New ``encoding`` argument in ``xray.Dataset.to_netcdf`` for writing + netCDF files with compression, as described in the new documentation + section on :ref:`io.netcdf.writing_encoded`. +- Add ``xray.Dataset.real`` and ``xray.Dataset.imag`` + attributes to Dataset and DataArray (:issue:`553`). +- More informative error message with ``xray.Dataset.from_dataframe`` + if the frame has duplicate columns. +- xray now uses deterministic names for dask arrays it creates or opens from + disk. This allows xray users to take advantage of dask's nascent support for + caching intermediate computation results. See :issue:`555` for an example. + +Bug fixes +~~~~~~~~~ + +- Forwards compatibility with the latest pandas release (v0.17.0). We were + using some internal pandas routines for datetime conversion, which + unfortunately have now changed upstream (:issue:`569`). +- Aggregation functions now correctly skip ``NaN`` for data for ``complex128`` + dtype (:issue:`554`). +- Fixed indexing 0d arrays with unicode dtype (:issue:`568`). +- ``xray.DataArray.name`` and Dataset keys must be a string or None to + be written to netCDF (:issue:`533`). +- ``xray.DataArray.where`` now uses dask instead of numpy if either the + array or ``other`` is a dask array. Previously, if ``other`` was a numpy array + the method was evaluated eagerly. +- Global attributes are now handled more consistently when loading remote + datasets using ``engine='pydap'`` (:issue:`574`). +- It is now possible to assign to the ``.data`` attribute of DataArray objects. +- ``coordinates`` attribute is now kept in the encoding dictionary after + decoding (:issue:`610`). +- Compatibility with numpy 1.10 (:issue:`617`). + +Acknowledgments +~~~~~~~~~~~~~~~ + +The following individuals contributed to this release: + +- Ryan Abernathey +- Pete Cable +- Clark Fitzgerald +- Joe Hamman +- Stephan Hoyer +- Scott Sinclair + +v0.6.0 (21 August 2015) +----------------------- + +This release includes numerous bug fixes and enhancements. Highlights +include the introduction of a plotting module and the new Dataset and DataArray +methods ``xray.Dataset.isel_points``, ``xray.Dataset.sel_points``, +``xray.Dataset.where`` and ``xray.Dataset.diff``. There are no +breaking changes from v0.5.2. + +Enhancements +~~~~~~~~~~~~ + +- Plotting methods have been implemented on DataArray objects + ``xray.DataArray.plot`` through integration with matplotlib + (:issue:`185`). For an introduction, see :ref:`plotting`. +- Variables in netCDF files with multiple missing values are now decoded as NaN + after issuing a warning if open_dataset is called with mask_and_scale=True. +- We clarified our rules for when the result from an xray operation is a copy + vs. a view (see :ref:`copies_vs_views` for more details). +- Dataset variables are now written to netCDF files in order of appearance + when using the netcdf4 backend (:issue:`479`). + +- Added ``xray.Dataset.isel_points`` and ``xray.Dataset.sel_points`` + to support pointwise indexing of Datasets and DataArrays (:issue:`475`). + + .. ipython:: + :verbatim: + + In [1]: da = xray.DataArray( + ...: np.arange(56).reshape((7, 8)), + ...: coords={"x": list("abcdefg"), "y": 10 * np.arange(8)}, + ...: dims=["x", "y"], + ...: ) + + In [2]: da + Out[2]: + + array([[ 0, 1, 2, 3, 4, 5, 6, 7], + [ 8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + [32, 33, 34, 35, 36, 37, 38, 39], + [40, 41, 42, 43, 44, 45, 46, 47], + [48, 49, 50, 51, 52, 53, 54, 55]]) + Coordinates: + * y (y) int64 0 10 20 30 40 50 60 70 + * x (x) |S1 'a' 'b' 'c' 'd' 'e' 'f' 'g' + + # we can index by position along each dimension + In [3]: da.isel_points(x=[0, 1, 6], y=[0, 1, 0], dim="points") + Out[3]: + + array([ 0, 9, 48]) + Coordinates: + y (points) int64 0 10 0 + x (points) |S1 'a' 'b' 'g' + * points (points) int64 0 1 2 + + # or equivalently by label + In [9]: da.sel_points(x=["a", "b", "g"], y=[0, 10, 0], dim="points") + Out[9]: + + array([ 0, 9, 48]) + Coordinates: + y (points) int64 0 10 0 + x (points) |S1 'a' 'b' 'g' + * points (points) int64 0 1 2 + +- New ``xray.Dataset.where`` method for masking xray objects according + to some criteria. This works particularly well with multi-dimensional data: + + .. ipython:: python + + ds = xray.Dataset(coords={"x": range(100), "y": range(100)}) + ds["distance"] = np.sqrt(ds.x**2 + ds.y**2) + + @savefig where_example.png width=4in height=4in + ds.distance.where(ds.distance < 100).plot() + +- Added new methods ``xray.DataArray.diff`` and ``xray.Dataset.diff`` + for finite difference calculations along a given axis. + +- New ``xray.DataArray.to_masked_array`` convenience method for + returning a numpy.ma.MaskedArray. + + .. ipython:: python + + da = xray.DataArray(np.random.random_sample(size=(5, 4))) + da.where(da < 0.5) + da.where(da < 0.5).to_masked_array(copy=True) + +- Added new flag "drop_variables" to ``xray.open_dataset`` for + excluding variables from being parsed. This may be useful to drop + variables with problems or inconsistent values. + +Bug fixes +~~~~~~~~~ + +- Fixed aggregation functions (e.g., sum and mean) on big-endian arrays when + bottleneck is installed (:issue:`489`). +- Dataset aggregation functions dropped variables with unsigned integer dtype + (:issue:`505`). +- ``.any()`` and ``.all()`` were not lazy when used on xray objects containing + dask arrays. +- Fixed an error when attempting to saving datetime64 variables to netCDF + files when the first element is ``NaT`` (:issue:`528`). +- Fix pickle on DataArray objects (:issue:`515`). +- Fixed unnecessary coercion of float64 to float32 when using netcdf3 and + netcdf4_classic formats (:issue:`526`). + +v0.5.2 (16 July 2015) +--------------------- + +This release contains bug fixes, several additional options for opening and +saving netCDF files, and a backwards incompatible rewrite of the advanced +options for ``xray.concat``. + +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The optional arguments ``concat_over`` and ``mode`` in ``xray.concat`` have + been removed and replaced by ``data_vars`` and ``coords``. The new arguments are both + more easily understood and more robustly implemented, and allowed us to fix a bug + where ``concat`` accidentally loaded data into memory. If you set values for + these optional arguments manually, you will need to update your code. The default + behavior should be unchanged. + +Enhancements +~~~~~~~~~~~~ + +- ``xray.open_mfdataset`` now supports a ``preprocess`` argument for + preprocessing datasets prior to concatenaton. This is useful if datasets + cannot be otherwise merged automatically, e.g., if the original datasets + have conflicting index coordinates (:issue:`443`). +- ``xray.open_dataset`` and ``xray.open_mfdataset`` now use a + global thread lock by default for reading from netCDF files with dask. This + avoids possible segmentation faults for reading from netCDF4 files when HDF5 + is not configured properly for concurrent access (:issue:`444`). +- Added support for serializing arrays of complex numbers with `engine='h5netcdf'`. +- The new ``xray.save_mfdataset`` function allows for saving multiple + datasets to disk simultaneously. This is useful when processing large datasets + with dask.array. For example, to save a dataset too big to fit into memory + to one file per year, we could write: + + .. ipython:: + :verbatim: + + In [1]: years, datasets = zip(*ds.groupby("time.year")) + + In [2]: paths = ["%s.nc" % y for y in years] + + In [3]: xray.save_mfdataset(datasets, paths) + +Bug fixes +~~~~~~~~~ + +- Fixed ``min``, ``max``, ``argmin`` and ``argmax`` for arrays with string or + unicode types (:issue:`453`). +- ``xray.open_dataset`` and ``xray.open_mfdataset`` support + supplying chunks as a single integer. +- Fixed a bug in serializing scalar datetime variable to netCDF. +- Fixed a bug that could occur in serialization of 0-dimensional integer arrays. +- Fixed a bug where concatenating DataArrays was not always lazy (:issue:`464`). +- When reading datasets with h5netcdf, bytes attributes are decoded to strings. + This allows conventions decoding to work properly on Python 3 (:issue:`451`). + +v0.5.1 (15 June 2015) +--------------------- + +This minor release fixes a few bugs and an inconsistency with pandas. It also +adds the ``pipe`` method, copied from pandas. + +Enhancements +~~~~~~~~~~~~ + +- Added ``xray.Dataset.pipe``, replicating the `new pandas method`_ in version + 0.16.2. See :ref:`transforming datasets` for more details. +- ``xray.Dataset.assign`` and ``xray.Dataset.assign_coords`` + now assign new variables in sorted (alphabetical) order, mirroring the + behavior in pandas. Previously, the order was arbitrary. + +.. _new pandas method: http://pandas.pydata.org/pandas-docs/version/0.16.2/whatsnew.html#pipe + +Bug fixes +~~~~~~~~~ + +- ``xray.concat`` fails in an edge case involving identical coordinate variables (:issue:`425`) +- We now decode variables loaded from netCDF3 files with the scipy engine using native + endianness (:issue:`416`). This resolves an issue when aggregating these arrays with + bottleneck installed. + +v0.5 (1 June 2015) +------------------ + +Highlights +~~~~~~~~~~ + +The headline feature in this release is experimental support for out-of-core +computing (data that doesn't fit into memory) with :doc:`user-guide/dask`. This includes a new +top-level function ``xray.open_mfdataset`` that makes it easy to open +a collection of netCDF (using dask) as a single ``xray.Dataset`` object. For +more on dask, read the `blog post introducing xray + dask`_ and the new +documentation section :doc:`user-guide/dask`. + +.. _blog post introducing xray + dask: https://www.anaconda.com/blog/developer-blog/xray-dask-out-core-labeled-arrays-python/ + +Dask makes it possible to harness parallelism and manipulate gigantic datasets +with xray. It is currently an optional dependency, but it may become required +in the future. + +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The logic used for choosing which variables are concatenated with + ``xray.concat`` has changed. Previously, by default any variables + which were equal across a dimension were not concatenated. This lead to some + surprising behavior, where the behavior of groupby and concat operations + could depend on runtime values (:issue:`268`). For example: + + .. ipython:: + :verbatim: + + In [1]: ds = xray.Dataset({"x": 0}) + + In [2]: xray.concat([ds, ds], dim="y") + Out[2]: + + Dimensions: () + Coordinates: + *empty* + Data variables: + x int64 0 + + Now, the default always concatenates data variables: + + .. ipython:: python + :suppress: + + ds = xray.Dataset({"x": 0}) + + .. ipython:: python + + xray.concat([ds, ds], dim="y") + + To obtain the old behavior, supply the argument ``concat_over=[]``. + +Enhancements +~~~~~~~~~~~~ + +- New ``xray.Dataset.to_dataarray`` and enhanced + ``xray.DataArray.to_dataset`` methods make it easy to switch back + and forth between arrays and datasets: + + .. ipython:: python + + ds = xray.Dataset( + {"a": 1, "b": ("x", [1, 2, 3])}, + coords={"c": 42}, + attrs={"Conventions": "None"}, + ) + ds.to_dataarray() + ds.to_dataarray().to_dataset(dim="variable") + +- New ``xray.Dataset.fillna`` method to fill missing values, modeled + off the pandas method of the same name: + + .. ipython:: python + + array = xray.DataArray([np.nan, 1, np.nan, 3], dims="x") + array.fillna(0) + + ``fillna`` works on both ``Dataset`` and ``DataArray`` objects, and uses + index based alignment and broadcasting like standard binary operations. It + also can be applied by group, as illustrated in + :ref:`/examples/weather-data.ipynb#Fill-missing-values-with-climatology`. +- New ``xray.Dataset.assign`` and ``xray.Dataset.assign_coords`` + methods patterned off the new :py:meth:`DataFrame.assign ` + method in pandas: + + .. ipython:: python + + ds = xray.Dataset({"y": ("x", [1, 2, 3])}) + ds.assign(z=lambda ds: ds.y**2) + ds.assign_coords(z=("x", ["a", "b", "c"])) + + These methods return a new Dataset (or DataArray) with updated data or + coordinate variables. +- ``xray.Dataset.sel`` now supports the ``method`` parameter, which works + like the parameter of the same name on ``xray.Dataset.reindex``. It + provides a simple interface for doing nearest-neighbor interpolation: + + .. use verbatim because I can't seem to install pandas 0.16.1 on RTD :( + + .. ipython:: + :verbatim: + + In [12]: ds.sel(x=1.1, method="nearest") + Out[12]: + + Dimensions: () + Coordinates: + x int64 1 + Data variables: + y int64 2 + + In [13]: ds.sel(x=[1.1, 2.1], method="pad") + Out[13]: + + Dimensions: (x: 2) + Coordinates: + * x (x) int64 1 2 + Data variables: + y (x) int64 2 3 + + See :ref:`nearest neighbor lookups` for more details. +- You can now control the underlying backend used for accessing remote + datasets (via OPeNDAP) by specifying ``engine='netcdf4'`` or + ``engine='pydap'``. +- xray now provides experimental support for reading and writing netCDF4 files directly + via `h5py`_ with the `h5netcdf`_ package, avoiding the netCDF4-Python package. You + will need to install h5netcdf and specify ``engine='h5netcdf'`` to try this + feature. +- Accessing data from remote datasets now has retrying logic (with exponential + backoff) that should make it robust to occasional bad responses from DAP + servers. +- You can control the width of the Dataset repr with ``xray.set_options``. + It can be used either as a context manager, in which case the default is restored + outside the context: + + .. ipython:: python + + ds = xray.Dataset({"x": np.arange(1000)}) + with xray.set_options(display_width=40): + print(ds) + + Or to set a global option: + + .. ipython:: + :verbatim: + + In [1]: xray.set_options(display_width=80) + + The default value for the ``display_width`` option is 80. + +.. _h5py: http://www.h5py.org/ +.. _h5netcdf: https://github.com/shoyer/h5netcdf + +Deprecations +~~~~~~~~~~~~ + +- The method ``load_data()`` has been renamed to the more succinct + ``xray.Dataset.load``. + +v0.4.1 (18 March 2015) +---------------------- + +The release contains bug fixes and several new features. All changes should be +fully backwards compatible. + +Enhancements +~~~~~~~~~~~~ + +- New documentation sections on :ref:`time-series` and + :ref:`combining multiple files`. +- ``xray.Dataset.resample`` lets you resample a dataset or data array to + a new temporal resolution. The syntax is the `same as pandas`_, except you + need to supply the time dimension explicitly: + + .. ipython:: python + :verbatim: + + time = pd.date_range("2000-01-01", freq="6H", periods=10) + array = xray.DataArray(np.arange(10), [("time", time)]) + array.resample("1D", dim="time") + + You can specify how to do the resampling with the ``how`` argument and other + options such as ``closed`` and ``label`` let you control labeling: + + .. ipython:: python + :verbatim: + + array.resample("1D", dim="time", how="sum", label="right") + + If the desired temporal resolution is higher than the original data + (upsampling), xray will insert missing values: + + .. ipython:: python + :verbatim: + + array.resample("3H", "time") + +- ``first`` and ``last`` methods on groupby objects let you take the first or + last examples from each group along the grouped axis: + + .. ipython:: python + :verbatim: + + array.groupby("time.day").first() + + These methods combine well with ``resample``: + + .. ipython:: python + :verbatim: + + array.resample("1D", dim="time", how="first") + + +- ``xray.Dataset.swap_dims`` allows for easily swapping one dimension + out for another: + + .. ipython:: python + + ds = xray.Dataset({"x": range(3), "y": ("x", list("abc"))}) + ds + ds.swap_dims({"x": "y"}) + + This was possible in earlier versions of xray, but required some contortions. +- ``xray.open_dataset`` and ``xray.Dataset.to_netcdf`` now + accept an ``engine`` argument to explicitly select which underlying library + (netcdf4 or scipy) is used for reading/writing a netCDF file. + +.. _same as pandas: http://pandas.pydata.org/pandas-docs/stable/timeseries.html#up-and-downsampling + +Bug fixes +~~~~~~~~~ + +- Fixed a bug where data netCDF variables read from disk with + ``engine='scipy'`` could still be associated with the file on disk, even + after closing the file (:issue:`341`). This manifested itself in warnings + about mmapped arrays and segmentation faults (if the data was accessed). +- Silenced spurious warnings about all-NaN slices when using nan-aware + aggregation methods (:issue:`344`). +- Dataset aggregations with ``keep_attrs=True`` now preserve attributes on + data variables, not just the dataset itself. +- Tests for xray now pass when run on Windows (:issue:`360`). +- Fixed a regression in v0.4 where saving to netCDF could fail with the error + ``ValueError: could not automatically determine time units``. + +v0.4 (2 March, 2015) +-------------------- + +This is one of the biggest releases yet for xray: it includes some major +changes that may break existing code, along with the usual collection of minor +enhancements and bug fixes. On the plus side, this release includes all +hitherto planned breaking changes, so the upgrade path for xray should be +smoother going forward. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- We now automatically align index labels in arithmetic, dataset construction, + merging and updating. This means the need for manually invoking methods like + ``xray.align`` and ``xray.Dataset.reindex_like`` should be + vastly reduced. + + :ref:`For arithmetic`, we align + based on the **intersection** of labels: + + .. ipython:: python + + lhs = xray.DataArray([1, 2, 3], [("x", [0, 1, 2])]) + rhs = xray.DataArray([2, 3, 4], [("x", [1, 2, 3])]) + lhs + rhs + + :ref:`For dataset construction and merging`, we align based on the + **union** of labels: + + .. ipython:: python + + xray.Dataset({"foo": lhs, "bar": rhs}) + + :ref:`For update and __setitem__`, we align based on the **original** + object: + + .. ipython:: python + + lhs.coords["rhs"] = rhs + lhs + +- Aggregations like ``mean`` or ``median`` now skip missing values by default: + + .. ipython:: python + + xray.DataArray([1, 2, np.nan, 3]).mean() + + You can turn this behavior off by supplying the keyword argument + ``skipna=False``. + + These operations are lightning fast thanks to integration with bottleneck_, + which is a new optional dependency for xray (numpy is used if bottleneck is + not installed). +- Scalar coordinates no longer conflict with constant arrays with the same + value (e.g., in arithmetic, merging datasets and concat), even if they have + different shape (:issue:`243`). For example, the coordinate ``c`` here + persists through arithmetic, even though it has different shapes on each + DataArray: + + .. ipython:: python + + a = xray.DataArray([1, 2], coords={"c": 0}, dims="x") + b = xray.DataArray([1, 2], coords={"c": ("x", [0, 0])}, dims="x") + (a + b).coords + + This functionality can be controlled through the ``compat`` option, which + has also been added to the ``xray.Dataset`` constructor. +- Datetime shortcuts such as ``'time.month'`` now return a ``DataArray`` with + the name ``'month'``, not ``'time.month'`` (:issue:`345`). This makes it + easier to index the resulting arrays when they are used with ``groupby``: + + .. ipython:: python + + time = xray.DataArray( + pd.date_range("2000-01-01", periods=365), dims="time", name="time" + ) + counts = time.groupby("time.month").count() + counts.sel(month=2) + + Previously, you would need to use something like + ``counts.sel(**{'time.month': 2}})``, which is much more awkward. +- The ``season`` datetime shortcut now returns an array of string labels + such `'DJF'`: + + .. code-block:: ipython + + In[92]: ds = xray.Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) + + In[93]: ds["t.season"] + Out[93]: + + array(['DJF', 'DJF', 'MAM', ..., 'SON', 'SON', 'DJF'], dtype='`_. +- Use functions that return generic ndarrays with DataArray.groupby.apply and + Dataset.apply (:issue:`327` and :issue:`329`). Thanks Jeff Gerard! +- Consolidated the functionality of ``dumps`` (writing a dataset to a netCDF3 + bytestring) into ``xray.Dataset.to_netcdf`` (:issue:`333`). +- ``xray.Dataset.to_netcdf`` now supports writing to groups in netCDF4 + files (:issue:`333`). It also finally has a full docstring -- you should read + it! +- ``xray.open_dataset`` and ``xray.Dataset.to_netcdf`` now + work on netCDF3 files when netcdf4-python is not installed as long as scipy + is available (:issue:`333`). +- The new ``xray.Dataset.drop`` and ``xray.DataArray.drop`` methods + makes it easy to drop explicitly listed variables or index labels: + + .. ipython:: python + :okwarning: + + # drop variables + ds = xray.Dataset({"x": 0, "y": 1}) + ds.drop("x") + + # drop index labels + arr = xray.DataArray([1, 2, 3], coords=[("x", list("abc"))]) + arr.drop(["a", "c"], dim="x") + +- ``xray.Dataset.broadcast_equals`` has been added to correspond to + the new ``compat`` option. +- Long attributes are now truncated at 500 characters when printing a dataset + (:issue:`338`). This should make things more convenient for working with + datasets interactively. +- Added a new documentation example, :ref:`/examples/monthly-means.ipynb`. Thanks Joe + Hamman! + +Bug fixes +~~~~~~~~~ + +- Several bug fixes related to decoding time units from netCDF files + (:issue:`316`, :issue:`330`). Thanks Stefan Pfenninger! +- xray no longer requires ``decode_coords=False`` when reading datasets with + unparsable coordinate attributes (:issue:`308`). +- Fixed ``DataArray.loc`` indexing with ``...`` (:issue:`318`). +- Fixed an edge case that resulting in an error when reindexing + multi-dimensional variables (:issue:`315`). +- Slicing with negative step sizes (:issue:`312`). +- Invalid conversion of string arrays to numeric dtype (:issue:`305`). +- Fixed``repr()`` on dataset objects with non-standard dates (:issue:`347`). + +Deprecations +~~~~~~~~~~~~ + +- ``dump`` and ``dumps`` have been deprecated in favor of + ``xray.Dataset.to_netcdf``. +- ``drop_vars`` has been deprecated in favor of ``xray.Dataset.drop``. + +Future plans +~~~~~~~~~~~~ + +The biggest feature I'm excited about working toward in the immediate future +is supporting out-of-core operations in xray using Dask_, a part of the Blaze_ +project. For a preview of using Dask with weather data, read +`this blog post`_ by Matthew Rocklin. See :issue:`328` for more details. + +.. _Dask: https://dask.org +.. _Blaze: https://blaze.pydata.org +.. _this blog post: https://matthewrocklin.com/blog/work/2015/02/13/Towards-OOC-Slicing-and-Stacking + +v0.3.2 (23 December, 2014) +-------------------------- + +This release focused on bug-fixes, speedups and resolving some niggling +inconsistencies. + +There are a few cases where the behavior of xray differs from the previous +version. However, I expect that in almost all cases your code will continue to +run unmodified. + +.. warning:: + + xray now requires pandas v0.15.0 or later. This was necessary for + supporting TimedeltaIndex without too many painful hacks. + +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Arrays of :py:class:`datetime.datetime` objects are now automatically cast to + ``datetime64[ns]`` arrays when stored in an xray object, using machinery + borrowed from pandas: + + .. ipython:: python + + from datetime import datetime + + xray.Dataset({"t": [datetime(2000, 1, 1)]}) + +- xray now has support (including serialization to netCDF) for + :py:class:`~pandas.TimedeltaIndex`. :py:class:`datetime.timedelta` objects + are thus accordingly cast to ``timedelta64[ns]`` objects when appropriate. +- Masked arrays are now properly coerced to use ``NaN`` as a sentinel value + (:issue:`259`). + +Enhancements +~~~~~~~~~~~~ + +- Due to popular demand, we have added experimental attribute style access as + a shortcut for dataset variables, coordinates and attributes: + + .. ipython:: python + + ds = xray.Dataset({"tmin": ([], 25, {"units": "celsius"})}) + ds.tmin.units + + Tab-completion for these variables should work in editors such as IPython. + However, setting variables or attributes in this fashion is not yet + supported because there are some unresolved ambiguities (:issue:`300`). +- You can now use a dictionary for indexing with labeled dimensions. This + provides a safe way to do assignment with labeled dimensions: + + .. ipython:: python + + array = xray.DataArray(np.zeros(5), dims=["x"]) + array[dict(x=slice(3))] = 1 + array + +- Non-index coordinates can now be faithfully written to and restored from + netCDF files. This is done according to CF conventions when possible by + using the ``coordinates`` attribute on a data variable. When not possible, + xray defines a global ``coordinates`` attribute. +- Preliminary support for converting ``xray.DataArray`` objects to and from + CDAT_ ``cdms2`` variables. +- We sped up any operation that involves creating a new Dataset or DataArray + (e.g., indexing, aggregation, arithmetic) by a factor of 30 to 50%. The full + speed up requires cyordereddict_ to be installed. + +.. _CDAT: http://uvcdat.llnl.gov/ +.. _cyordereddict: https://github.com/shoyer/cyordereddict + +Bug fixes +~~~~~~~~~ + +- Fix for ``to_dataframe()`` with 0d string/object coordinates (:issue:`287`) +- Fix for ``to_netcdf`` with 0d string variable (:issue:`284`) +- Fix writing datetime64 arrays to netcdf if NaT is present (:issue:`270`) +- Fix align silently upcasts data arrays when NaNs are inserted (:issue:`264`) + +Future plans +~~~~~~~~~~~~ + +- I am contemplating switching to the terms "coordinate variables" and "data + variables" instead of the (currently used) "coordinates" and "variables", + following their use in `CF Conventions`_ (:issue:`293`). This would mostly + have implications for the documentation, but I would also change the + ``Dataset`` attribute ``vars`` to ``data``. +- I no longer certain that automatic label alignment for arithmetic would be a + good idea for xray -- it is a feature from pandas that I have not missed + (:issue:`186`). +- The main API breakage that I *do* anticipate in the next release is finally + making all aggregation operations skip missing values by default + (:issue:`130`). I'm pretty sick of writing ``ds.reduce(np.nanmean, 'time')``. +- The next version of xray (0.4) will remove deprecated features and aliases + whose use currently raises a warning. + +If you have opinions about any of these anticipated changes, I would love to +hear them -- please add a note to any of the referenced GitHub issues. + +.. _CF Conventions: http://cfconventions.org/Data/cf-conventions/cf-conventions-1.6/build/cf-conventions.html + +v0.3.1 (22 October, 2014) +------------------------- + +This is mostly a bug-fix release to make xray compatible with the latest +release of pandas (v0.15). + +We added several features to better support working with missing values and +exporting xray objects to pandas. We also reorganized the internal API for +serializing and deserializing datasets, but this change should be almost +entirely transparent to users. + +Other than breaking the experimental DataStore API, there should be no +backwards incompatible changes. + +New features +~~~~~~~~~~~~ + +- Added ``xray.Dataset.count`` and ``xray.Dataset.dropna`` + methods, copied from pandas, for working with missing values (:issue:`247`, + :issue:`58`). +- Added ``xray.DataArray.to_pandas`` for + converting a data array into the pandas object with the same dimensionality + (1D to Series, 2D to DataFrame, etc.) (:issue:`255`). +- Support for reading gzipped netCDF3 files (:issue:`239`). +- Reduced memory usage when writing netCDF files (:issue:`251`). +- 'missing_value' is now supported as an alias for the '_FillValue' attribute + on netCDF variables (:issue:`245`). +- Trivial indexes, equivalent to ``range(n)`` where ``n`` is the length of the + dimension, are no longer written to disk (:issue:`245`). + +Bug fixes +~~~~~~~~~ + +- Compatibility fixes for pandas v0.15 (:issue:`262`). +- Fixes for display and indexing of ``NaT`` (not-a-time) (:issue:`238`, + :issue:`240`) +- Fix slicing by label was an argument is a data array (:issue:`250`). +- Test data is now shipped with the source distribution (:issue:`253`). +- Ensure order does not matter when doing arithmetic with scalar data arrays + (:issue:`254`). +- Order of dimensions preserved with ``DataArray.to_dataframe`` (:issue:`260`). + +v0.3 (21 September 2014) +------------------------ + +New features +~~~~~~~~~~~~ + +- **Revamped coordinates**: "coordinates" now refer to all arrays that are not + used to index a dimension. Coordinates are intended to allow for keeping track + of arrays of metadata that describe the grid on which the points in "variable" + arrays lie. They are preserved (when unambiguous) even though mathematical + operations. +- **Dataset math** ``xray.Dataset`` objects now support all arithmetic + operations directly. Dataset-array operations map across all dataset + variables; dataset-dataset operations act on each pair of variables with the + same name. +- **GroupBy math**: This provides a convenient shortcut for normalizing by the + average value of a group. +- The dataset ``__repr__`` method has been entirely overhauled; dataset + objects now show their values when printed. +- You can now index a dataset with a list of variables to return a new dataset: + ``ds[['foo', 'bar']]``. + +Backwards incompatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``Dataset.__eq__`` and ``Dataset.__ne__`` are now element-wise operations + instead of comparing all values to obtain a single boolean. Use the method + ``xray.Dataset.equals`` instead. + +Deprecations +~~~~~~~~~~~~ + +- ``Dataset.noncoords`` is deprecated: use ``Dataset.vars`` instead. +- ``Dataset.select_vars`` deprecated: index a ``Dataset`` with a list of + variable names instead. +- ``DataArray.select_vars`` and ``DataArray.drop_vars`` deprecated: use + ``xray.DataArray.reset_coords`` instead. + +v0.2 (14 August 2014) +--------------------- + +This is major release that includes some new features and quite a few bug +fixes. Here are the highlights: + +- There is now a direct constructor for ``DataArray`` objects, which makes it + possible to create a DataArray without using a Dataset. This is highlighted + in the refreshed ``tutorial``. +- You can perform aggregation operations like ``mean`` directly on + ``xray.Dataset`` objects, thanks to Joe Hamman. These aggregation + methods also worked on grouped datasets. +- xray now works on Python 2.6, thanks to Anna Kuznetsova. +- A number of methods and attributes were given more sensible (usually shorter) + names: ``labeled`` -> ``sel``, ``indexed`` -> ``isel``, ``select`` -> + ``select_vars``, ``unselect`` -> ``drop_vars``, ``dimensions`` -> ``dims``, + ``coordinates`` -> ``coords``, ``attributes`` -> ``attrs``. +- New ``xray.Dataset.load_data`` and ``xray.Dataset.close`` + methods for datasets facilitate lower level of control of data loaded from + disk. + +v0.1.1 (20 May 2014) +-------------------- + +xray 0.1.1 is a bug-fix release that includes changes that should be almost +entirely backwards compatible with v0.1: + +- Python 3 support (:issue:`53`) +- Required numpy version relaxed to 1.7 (:issue:`129`) +- Return numpy.datetime64 arrays for non-standard calendars (:issue:`126`) +- Support for opening datasets associated with NetCDF4 groups (:issue:`127`) +- Bug-fixes for concatenating datetime arrays (:issue:`134`) + +Special thanks to new contributors Thomas Kluyver, Joe Hamman and Alistair +Miles. + +v0.1 (2 May 2014) +----------------- + +Initial release. diff --git a/test/fixtures/whole_applications/xarray/licenses/ANYTREE_LICENSE b/test/fixtures/whole_applications/xarray/licenses/ANYTREE_LICENSE new file mode 100644 index 0000000..8dada3e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/ANYTREE_LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/test/fixtures/whole_applications/xarray/licenses/DASK_LICENSE b/test/fixtures/whole_applications/xarray/licenses/DASK_LICENSE new file mode 100644 index 0000000..e98784c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/DASK_LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2014-2018, Anaconda, Inc. and contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +Neither the name of Anaconda nor the names of any contributors may be used to +endorse or promote products derived from this software without specific prior +written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. diff --git a/test/fixtures/whole_applications/xarray/licenses/ICOMOON_LICENSE b/test/fixtures/whole_applications/xarray/licenses/ICOMOON_LICENSE new file mode 100644 index 0000000..4ea99c2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/ICOMOON_LICENSE @@ -0,0 +1,395 @@ +Attribution 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution 4.0 International Public License ("Public License"). To the +extent this Public License may be interpreted as a contract, You are +granted the Licensed Rights in consideration of Your acceptance of +these terms and conditions, and the Licensor grants You such rights in +consideration of benefits the Licensor receives from making the +Licensed Material available under these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + j. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part; and + + b. produce, reproduce, and Share Adapted Material. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/test/fixtures/whole_applications/xarray/licenses/NUMPY_LICENSE b/test/fixtures/whole_applications/xarray/licenses/NUMPY_LICENSE new file mode 100644 index 0000000..7e972cf --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/NUMPY_LICENSE @@ -0,0 +1,30 @@ +Copyright (c) 2005-2011, NumPy Developers. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of the NumPy Developers nor the names of any + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/test/fixtures/whole_applications/xarray/licenses/PANDAS_LICENSE b/test/fixtures/whole_applications/xarray/licenses/PANDAS_LICENSE new file mode 100644 index 0000000..8026eb4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/PANDAS_LICENSE @@ -0,0 +1,36 @@ +pandas license +============== + +Copyright (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +All rights reserved. + +Copyright (c) 2008-2011 AQR Capital Management, LLC +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of the copyright holder nor the names of any + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/test/fixtures/whole_applications/xarray/licenses/PYTHON_LICENSE b/test/fixtures/whole_applications/xarray/licenses/PYTHON_LICENSE new file mode 100644 index 0000000..88251f5 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/PYTHON_LICENSE @@ -0,0 +1,254 @@ +A. HISTORY OF THE SOFTWARE +========================== + +Python was created in the early 1990s by Guido van Rossum at Stichting +Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +as a successor of a language called ABC. Guido remains Python's +principal author, although it includes many contributions from others. + +In 1995, Guido continued his work on Python at the Corporation for +National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +in Reston, Virginia where he released several versions of the +software. + +In May 2000, Guido and the Python core development team moved to +BeOpen.com to form the BeOpen PythonLabs team. In October of the same +year, the PythonLabs team moved to Digital Creations (now Zope +Corporation, see http://www.zope.com). In 2001, the Python Software +Foundation (PSF, see http://www.python.org/psf/) was formed, a +non-profit organization created specifically to own Python-related +Intellectual Property. Zope Corporation is a sponsoring member of +the PSF. + +All Python releases are Open Source (see http://www.opensource.org for +the Open Source Definition). Historically, most, but not all, Python +releases have also been GPL-compatible; the table below summarizes +the various releases. + + Release Derived Year Owner GPL- + from compatible? (1) + + 0.9.0 thru 1.2 1991-1995 CWI yes + 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes + 1.6 1.5.2 2000 CNRI no + 2.0 1.6 2000 BeOpen.com no + 1.6.1 1.6 2001 CNRI yes (2) + 2.1 2.0+1.6.1 2001 PSF no + 2.0.1 2.0+1.6.1 2001 PSF yes + 2.1.1 2.1+2.0.1 2001 PSF yes + 2.1.2 2.1.1 2002 PSF yes + 2.1.3 2.1.2 2002 PSF yes + 2.2 and above 2.1.1 2001-now PSF yes + +Footnotes: + +(1) GPL-compatible doesn't mean that we're distributing Python under + the GPL. All Python licenses, unlike the GPL, let you distribute + a modified version without making your changes open source. The + GPL-compatible licenses make it possible to combine Python with + other software that is released under the GPL; the others don't. + +(2) According to Richard Stallman, 1.6.1 is not GPL-compatible, + because its license has a choice of law clause. According to + CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 + is "not incompatible" with the GPL. + +Thanks to the many outside volunteers who have worked under Guido's +direction to make these releases possible. + + +B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +=============================================================== + +PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +-------------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +are retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + + +BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +------------------------------------------- + +BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 + +1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +Individual or Organization ("Licensee") accessing and otherwise using +this software in source or binary form and its associated +documentation ("the Software"). + +2. Subject to the terms and conditions of this BeOpen Python License +Agreement, BeOpen hereby grants Licensee a non-exclusive, +royalty-free, world-wide license to reproduce, analyze, test, perform +and/or display publicly, prepare derivative works, distribute, and +otherwise use the Software alone or in any derivative version, +provided, however, that the BeOpen Python License is retained in the +Software, alone or in any derivative version prepared by Licensee. + +3. BeOpen is making the Software available to Licensee on an "AS IS" +basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +5. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +6. This License Agreement shall be governed by and interpreted in all +respects by the law of the State of California, excluding conflict of +law provisions. Nothing in this License Agreement shall be deemed to +create any relationship of agency, partnership, or joint venture +between BeOpen and Licensee. This License Agreement does not grant +permission to use BeOpen trademarks or trade names in a trademark +sense to endorse or promote products or services of Licensee, or any +third party. As an exception, the "BeOpen Python" logos available at +http://www.pythonlabs.com/logos.html may be used according to the +permissions granted on that web page. + +7. By copying, installing or otherwise using the software, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + + +CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +--------------------------------------- + +1. This LICENSE AGREEMENT is between the Corporation for National +Research Initiatives, having an office at 1895 Preston White Drive, +Reston, VA 20191 ("CNRI"), and the Individual or Organization +("Licensee") accessing and otherwise using Python 1.6.1 software in +source or binary form and its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, CNRI +hereby grants Licensee a nonexclusive, royalty-free, world-wide +license to reproduce, analyze, test, perform and/or display publicly, +prepare derivative works, distribute, and otherwise use Python 1.6.1 +alone or in any derivative version, provided, however, that CNRI's +License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +1995-2001 Corporation for National Research Initiatives; All Rights +Reserved" are retained in Python 1.6.1 alone or in any derivative +version prepared by Licensee. Alternately, in lieu of CNRI's License +Agreement, Licensee may substitute the following text (omitting the +quotes): "Python 1.6.1 is made available subject to the terms and +conditions in CNRI's License Agreement. This Agreement together with +Python 1.6.1 may be located on the Internet using the following +unique, persistent identifier (known as a handle): 1895.22/1013. This +Agreement may also be obtained from a proxy server on the Internet +using the following URL: http://hdl.handle.net/1895.22/1013". + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python 1.6.1 or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python 1.6.1. + +4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. This License Agreement shall be governed by the federal +intellectual property law of the United States, including without +limitation the federal copyright law, and, to the extent such +U.S. federal law does not apply, by the law of the Commonwealth of +Virginia, excluding Virginia's conflict of law provisions. +Notwithstanding the foregoing, with regard to derivative works based +on Python 1.6.1 that incorporate non-separable material that was +previously distributed under the GNU General Public License (GPL), the +law of the Commonwealth of Virginia shall govern this License +Agreement only as to issues arising under or with respect to +Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +License Agreement shall be deemed to create any relationship of +agency, partnership, or joint venture between CNRI and Licensee. This +License Agreement does not grant permission to use CNRI trademarks or +trade name in a trademark sense to endorse or promote products or +services of Licensee, or any third party. + +8. By clicking on the "ACCEPT" button where indicated, or by copying, +installing or otherwise using Python 1.6.1, Licensee agrees to be +bound by the terms and conditions of this License Agreement. + + ACCEPT + + +CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +-------------------------------------------------- + +Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +The Netherlands. All rights reserved. + +Permission to use, copy, modify, and distribute this software and its +documentation for any purpose and without fee is hereby granted, +provided that the above copyright notice appear in all copies and that +both that copyright notice and this permission notice appear in +supporting documentation, and that the name of Stichting Mathematisch +Centrum or CWI not be used in advertising or publicity pertaining to +distribution of the software without specific, written prior +permission. + +STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/test/fixtures/whole_applications/xarray/licenses/SCIKIT_LEARN_LICENSE b/test/fixtures/whole_applications/xarray/licenses/SCIKIT_LEARN_LICENSE new file mode 100644 index 0000000..63bc7ee --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/SCIKIT_LEARN_LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2007-2021 The scikit-learn developers. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE diff --git a/test/fixtures/whole_applications/xarray/licenses/SEABORN_LICENSE b/test/fixtures/whole_applications/xarray/licenses/SEABORN_LICENSE new file mode 100644 index 0000000..c6b4209 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/licenses/SEABORN_LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012-2013, Michael L. Waskom +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/test/fixtures/whole_applications/xarray/properties/README.md b/test/fixtures/whole_applications/xarray/properties/README.md new file mode 100644 index 0000000..86c1d41 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/properties/README.md @@ -0,0 +1,22 @@ +# Property-based tests using Hypothesis + +This directory contains property-based tests using a library +called [Hypothesis](https://github.com/HypothesisWorks/hypothesis-python). + +The property tests for xarray are a work in progress - more are always welcome. +They are stored in a separate directory because they tend to run more examples +and thus take longer, and so that local development can run a test suite +without needing to `pip install hypothesis`. + +## Hang on, "property-based" tests? + +Instead of making assertions about operations on a particular piece of +data, you use Hypothesis to describe a *kind* of data, then make assertions +that should hold for *any* example of this kind. + +For example: "given a 2d ndarray of dtype uint8 `arr`, +`xr.DataArray(arr).plot.imshow()` never raises an exception". + +Hypothesis will then try many random examples, and report a minimised +failing input for each error it finds. +[See the docs for more info.](https://hypothesis.readthedocs.io/en/master/) diff --git a/test/fixtures/whole_applications/xarray/properties/conftest.py b/test/fixtures/whole_applications/xarray/properties/conftest.py new file mode 100644 index 0000000..30e6381 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/properties/conftest.py @@ -0,0 +1,29 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--run-slow-hypothesis", + action="store_true", + default=False, + help="run slow hypothesis tests", + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--run-slow-hypothesis"): + return + skip_slow_hyp = pytest.mark.skip(reason="need --run-slow-hypothesis option to run") + for item in items: + if "slow_hypothesis" in item.keywords: + item.add_marker(skip_slow_hyp) + + +try: + from hypothesis import settings +except ImportError: + pass +else: + # Run for a while - arrays are a bigger search space than usual + settings.register_profile("ci", deadline=None, print_blob=True) + settings.load_profile("ci") diff --git a/test/fixtures/whole_applications/xarray/properties/test_encode_decode.py b/test/fixtures/whole_applications/xarray/properties/test_encode_decode.py new file mode 100644 index 0000000..60e1bbe --- /dev/null +++ b/test/fixtures/whole_applications/xarray/properties/test_encode_decode.py @@ -0,0 +1,52 @@ +""" +Property-based tests for encoding/decoding methods. + +These ones pass, just as you'd hope! + +""" + +import pytest + +pytest.importorskip("hypothesis") +# isort: split + +import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given + +import xarray as xr + +an_array = npst.arrays( + dtype=st.one_of( + npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() + ), + shape=npst.array_shapes(max_side=3), # max_side specified for performance +) + + +@pytest.mark.slow +@given(st.data(), an_array) +def test_CFMask_coder_roundtrip(data, arr) -> None: + names = data.draw( + st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( + tuple + ) + ) + original = xr.Variable(names, arr) + coder = xr.coding.variables.CFMaskCoder() + roundtripped = coder.decode(coder.encode(original)) + xr.testing.assert_identical(original, roundtripped) + + +@pytest.mark.slow +@given(st.data(), an_array) +def test_CFScaleOffset_coder_roundtrip(data, arr) -> None: + names = data.draw( + st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( + tuple + ) + ) + original = xr.Variable(names, arr) + coder = xr.coding.variables.CFScaleOffsetCoder() + roundtripped = coder.decode(coder.encode(original)) + xr.testing.assert_identical(original, roundtripped) diff --git a/test/fixtures/whole_applications/xarray/properties/test_index_manipulation.py b/test/fixtures/whole_applications/xarray/properties/test_index_manipulation.py new file mode 100644 index 0000000..77b7fcb --- /dev/null +++ b/test/fixtures/whole_applications/xarray/properties/test_index_manipulation.py @@ -0,0 +1,273 @@ +import itertools + +import numpy as np +import pytest + +import xarray as xr +from xarray import Dataset +from xarray.testing import _assert_internal_invariants + +pytest.importorskip("hypothesis") +pytestmark = pytest.mark.slow_hypothesis + +import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import note, settings +from hypothesis.stateful import ( + RuleBasedStateMachine, + initialize, + invariant, + precondition, + rule, +) + +import xarray.testing.strategies as xrst + + +@st.composite +def unique(draw, strategy): + # https://stackoverflow.com/questions/73737073/create-hypothesis-strategy-that-returns-unique-values + seen = draw(st.shared(st.builds(set), key="key-for-unique-elems")) + return draw( + strategy.filter(lambda x: x not in seen).map(lambda x: seen.add(x) or x) + ) + + +# Share to ensure we get unique names on each draw, +# so we don't try to add two variables with the same name +# or stack to a dimension with a name that already exists in the Dataset. +UNIQUE_NAME = unique(strategy=xrst.names()) +DIM_NAME = xrst.dimension_names(name_strategy=UNIQUE_NAME, min_dims=1, max_dims=1) +index_variables = st.builds( + xr.Variable, + data=npst.arrays( + dtype=xrst.pandas_index_dtypes(), + shape=npst.array_shapes(min_dims=1, max_dims=1), + elements=dict(allow_nan=False, allow_infinity=False, allow_subnormal=False), + unique=True, + ), + dims=DIM_NAME, + attrs=xrst.attrs(), +) + + +def add_dim_coord_and_data_var(ds, var): + (name,) = var.dims + # dim coord + ds[name] = var + # non-dim coord of same size; this allows renaming + ds[name + "_"] = var + + +class DatasetStateMachine(RuleBasedStateMachine): + # Can't use bundles because we'd need pre-conditions on consumes(bundle) + # indexed_dims = Bundle("indexed_dims") + # multi_indexed_dims = Bundle("multi_indexed_dims") + + def __init__(self): + super().__init__() + self.dataset = Dataset() + self.check_default_indexes = True + + # We track these separately as lists so we can guarantee order of iteration over them. + # Order of iteration over Dataset.dims is not guaranteed + self.indexed_dims = [] + self.multi_indexed_dims = [] + + @initialize(var=index_variables) + def init_ds(self, var): + """Initialize the Dataset so that at least one rule will always fire.""" + (name,) = var.dims + add_dim_coord_and_data_var(self.dataset, var) + + self.indexed_dims.append(name) + + # TODO: stacking with a timedelta64 index and unstacking converts it to object + @rule(var=index_variables) + def add_dim_coord(self, var): + (name,) = var.dims + note(f"adding dimension coordinate {name}") + add_dim_coord_and_data_var(self.dataset, var) + + self.indexed_dims.append(name) + + @rule(var=index_variables) + def assign_coords(self, var): + (name,) = var.dims + note(f"assign_coords: {name}") + self.dataset = self.dataset.assign_coords({name: var}) + + self.indexed_dims.append(name) + + @property + def has_indexed_dims(self) -> bool: + return bool(self.indexed_dims + self.multi_indexed_dims) + + @rule(data=st.data()) + @precondition(lambda self: self.has_indexed_dims) + def reset_index(self, data): + dim = data.draw(st.sampled_from(self.indexed_dims + self.multi_indexed_dims)) + self.check_default_indexes = False + note(f"> resetting {dim}") + self.dataset = self.dataset.reset_index(dim) + + if dim in self.indexed_dims: + del self.indexed_dims[self.indexed_dims.index(dim)] + elif dim in self.multi_indexed_dims: + del self.multi_indexed_dims[self.multi_indexed_dims.index(dim)] + + @rule(newname=UNIQUE_NAME, data=st.data(), create_index=st.booleans()) + @precondition(lambda self: bool(self.indexed_dims)) + def stack(self, newname, data, create_index): + oldnames = data.draw( + st.lists( + st.sampled_from(self.indexed_dims), + min_size=1, + max_size=3 if create_index else None, + unique=True, + ) + ) + note(f"> stacking {oldnames} as {newname}") + self.dataset = self.dataset.stack( + {newname: oldnames}, create_index=create_index + ) + + if create_index: + self.multi_indexed_dims += [newname] + + # if create_index is False, then we just drop these + for dim in oldnames: + del self.indexed_dims[self.indexed_dims.index(dim)] + + @rule(data=st.data()) + @precondition(lambda self: bool(self.multi_indexed_dims)) + def unstack(self, data): + # TODO: add None + dim = data.draw(st.sampled_from(self.multi_indexed_dims)) + note(f"> unstacking {dim}") + if dim is not None: + pd_index = self.dataset.xindexes[dim].index + self.dataset = self.dataset.unstack(dim) + + del self.multi_indexed_dims[self.multi_indexed_dims.index(dim)] + + if dim is not None: + self.indexed_dims.extend(pd_index.names) + else: + # TODO: fix this + pass + + @rule(newname=UNIQUE_NAME, data=st.data()) + @precondition(lambda self: bool(self.dataset.variables)) + def rename_vars(self, newname, data): + dim = data.draw(st.sampled_from(sorted(self.dataset.variables))) + # benbovy: "skip the default indexes invariant test when the name of an + # existing dimension coordinate is passed as input kwarg or dict key + # to .rename_vars()." + self.check_default_indexes = False + note(f"> renaming {dim} to {newname}") + self.dataset = self.dataset.rename_vars({dim: newname}) + + if dim in self.indexed_dims: + del self.indexed_dims[self.indexed_dims.index(dim)] + elif dim in self.multi_indexed_dims: + del self.multi_indexed_dims[self.multi_indexed_dims.index(dim)] + + @precondition(lambda self: bool(self.dataset.dims)) + @rule(data=st.data()) + def drop_dims(self, data): + dims = data.draw( + st.lists( + st.sampled_from(sorted(tuple(self.dataset.dims))), + min_size=1, + unique=True, + ) + ) + note(f"> drop_dims: {dims}") + self.dataset = self.dataset.drop_dims(dims) + + for dim in dims: + if dim in self.indexed_dims: + del self.indexed_dims[self.indexed_dims.index(dim)] + elif dim in self.multi_indexed_dims: + del self.multi_indexed_dims[self.multi_indexed_dims.index(dim)] + + @precondition(lambda self: bool(self.indexed_dims)) + @rule(data=st.data()) + def drop_indexes(self, data): + self.check_default_indexes = False + + dims = data.draw( + st.lists(st.sampled_from(self.indexed_dims), min_size=1, unique=True) + ) + note(f"> drop_indexes: {dims}") + self.dataset = self.dataset.drop_indexes(dims) + + for dim in dims: + if dim in self.indexed_dims: + del self.indexed_dims[self.indexed_dims.index(dim)] + elif dim in self.multi_indexed_dims: + del self.multi_indexed_dims[self.multi_indexed_dims.index(dim)] + + @property + def swappable_dims(self): + ds = self.dataset + options = [] + for dim in self.indexed_dims: + choices = [ + name + for name, var in ds._variables.items() + if var.dims == (dim,) + # TODO: Avoid swapping a dimension to itself + and name != dim + ] + options.extend( + (a, b) for a, b in itertools.zip_longest((dim,), choices, fillvalue=dim) + ) + return options + + @rule(data=st.data()) + # TODO: swap_dims is basically all broken if a multiindex is present + # TODO: Avoid swapping from Index to a MultiIndex level + # TODO: Avoid swapping from MultiIndex to a level of the same MultiIndex + # TODO: Avoid swapping when a MultiIndex is present + @precondition(lambda self: not bool(self.multi_indexed_dims)) + @precondition(lambda self: bool(self.swappable_dims)) + def swap_dims(self, data): + ds = self.dataset + options = self.swappable_dims + dim, to = data.draw(st.sampled_from(options)) + note( + f"> swapping {dim} to {to}, found swappable dims: {options}, all_dims: {tuple(self.dataset.dims)}" + ) + self.dataset = ds.swap_dims({dim: to}) + + del self.indexed_dims[self.indexed_dims.index(dim)] + self.indexed_dims += [to] + + @invariant() + def assert_invariants(self): + # note(f"> ===\n\n {self.dataset!r} \n===\n\n") + _assert_internal_invariants(self.dataset, self.check_default_indexes) + + +DatasetStateMachine.TestCase.settings = settings(max_examples=300, deadline=None) +DatasetTest = DatasetStateMachine.TestCase + + +@pytest.mark.skip(reason="failure detected by hypothesis") +def test_unstack_object(): + import xarray as xr + + ds = xr.Dataset() + ds["0"] = np.array(["", "\x000"], dtype=object) + ds.stack({"1": ["0"]}).unstack() + + +@pytest.mark.skip(reason="failure detected by hypothesis") +def test_unstack_timedelta_index(): + import xarray as xr + + ds = xr.Dataset() + ds["0"] = np.array([0, 1, 2, 3], dtype="timedelta64[ns]") + ds.stack({"1": ["0"]}).unstack() diff --git a/test/fixtures/whole_applications/xarray/properties/test_pandas_roundtrip.py b/test/fixtures/whole_applications/xarray/properties/test_pandas_roundtrip.py new file mode 100644 index 0000000..0249aa5 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/properties/test_pandas_roundtrip.py @@ -0,0 +1,127 @@ +""" +Property-based tests for roundtripping between xarray and pandas objects. +""" + +from functools import partial + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.tests import has_pandas_3 + +pytest.importorskip("hypothesis") +import hypothesis.extra.numpy as npst # isort:skip +import hypothesis.extra.pandas as pdst # isort:skip +import hypothesis.strategies as st # isort:skip +from hypothesis import given # isort:skip + +numeric_dtypes = st.one_of( + npst.unsigned_integer_dtypes(endianness="="), + npst.integer_dtypes(endianness="="), + npst.floating_dtypes(endianness="="), +) + +numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt)) + +an_array = npst.arrays( + dtype=numeric_dtypes, + shape=npst.array_shapes(max_dims=2), # can only convert 1D/2D to pandas +) + + +datetime_with_tz_strategy = st.datetimes(timezones=st.timezones()) +dataframe_strategy = pdst.data_frames( + [ + pdst.column("datetime_col", elements=datetime_with_tz_strategy), + pdst.column("other_col", elements=st.integers()), + ], + index=pdst.range_indexes(min_size=1, max_size=10), +) + + +@st.composite +def datasets_1d_vars(draw) -> xr.Dataset: + """Generate datasets with only 1D variables + + Suitable for converting to pandas dataframes. + """ + # Generate an index for the dataset + idx = draw(pdst.indexes(dtype="u8", min_size=0, max_size=100)) + + # Generate 1-3 variables, 1D with the same length as the index + vars_strategy = st.dictionaries( + keys=st.text(), + values=npst.arrays(dtype=numeric_dtypes, shape=len(idx)).map( + partial(xr.Variable, ("rows",)) + ), + min_size=1, + max_size=3, + ) + return xr.Dataset(draw(vars_strategy), coords={"rows": idx}) + + +@given(st.data(), an_array) +def test_roundtrip_dataarray(data, arr) -> None: + names = data.draw( + st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( + tuple + ) + ) + coords = {name: np.arange(n) for (name, n) in zip(names, arr.shape)} + original = xr.DataArray(arr, dims=names, coords=coords) + roundtripped = xr.DataArray(original.to_pandas()) + xr.testing.assert_identical(original, roundtripped) + + +@given(datasets_1d_vars()) +def test_roundtrip_dataset(dataset) -> None: + df = dataset.to_dataframe() + assert isinstance(df, pd.DataFrame) + roundtripped = xr.Dataset(df) + xr.testing.assert_identical(dataset, roundtripped) + + +@given(numeric_series, st.text()) +def test_roundtrip_pandas_series(ser, ix_name) -> None: + # Need to name the index, otherwise Xarray calls it 'dim_0'. + ser.index.name = ix_name + arr = xr.DataArray(ser) + roundtripped = arr.to_pandas() + pd.testing.assert_series_equal(ser, roundtripped) + xr.testing.assert_identical(arr, roundtripped.to_xarray()) + + +# Dataframes with columns of all the same dtype - for roundtrip to DataArray +numeric_homogeneous_dataframe = numeric_dtypes.flatmap( + lambda dt: pdst.data_frames(columns=pdst.columns(["a", "b", "c"], dtype=dt)) +) + + +@pytest.mark.xfail +@given(numeric_homogeneous_dataframe) +def test_roundtrip_pandas_dataframe(df) -> None: + # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'. + df.index.name = "rows" + df.columns.name = "cols" + arr = xr.DataArray(df) + roundtripped = arr.to_pandas() + pd.testing.assert_frame_equal(df, roundtripped) + xr.testing.assert_identical(arr, roundtripped.to_xarray()) + + +@pytest.mark.skipif( + has_pandas_3, + reason="fails to roundtrip on pandas 3 (see https://github.com/pydata/xarray/issues/9098)", +) +@given(df=dataframe_strategy) +def test_roundtrip_pandas_dataframe_datetime(df) -> None: + # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'. + df.index.name = "rows" + df.columns.name = "cols" + dataset = xr.Dataset.from_dataframe(df) + roundtripped = dataset.to_dataframe() + roundtripped.columns.name = "cols" # why? + pd.testing.assert_frame_equal(df, roundtripped) + xr.testing.assert_identical(dataset, roundtripped.to_xarray()) diff --git a/test/fixtures/whole_applications/xarray/pyproject.toml b/test/fixtures/whole_applications/xarray/pyproject.toml new file mode 100644 index 0000000..db64d7a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/pyproject.toml @@ -0,0 +1,354 @@ +[project] +authors = [ + {name = "xarray Developers", email = "xarray@googlegroups.com"}, +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Intended Audience :: Science/Research", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", +] +description = "N-D labeled arrays and datasets in Python" +dynamic = ["version"] +license = {text = "Apache-2.0"} +name = "xarray" +readme = "README.md" +requires-python = ">=3.9" + +dependencies = [ + "numpy>=1.23", + "packaging>=23.1", + "pandas>=2.0", +] + +[project.optional-dependencies] +accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] +complete = ["xarray[accel,io,parallel,viz,dev]"] +dev = [ + "hypothesis", + "mypy", + "pre-commit", + "pytest", + "pytest-cov", + "pytest-env", + "pytest-xdist", + "pytest-timeout", + "ruff", + "xarray[complete]", +] +io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] +parallel = ["dask[complete]"] +viz = ["matplotlib", "seaborn", "nc-time-axis"] + +[project.urls] +Documentation = "https://docs.xarray.dev" +SciPy2015-talk = "https://www.youtube.com/watch?v=X0pAhJgySxk" +homepage = "https://xarray.dev/" +issue-tracker = "https://github.com/pydata/xarray/issues" +source-code = "https://github.com/pydata/xarray" + +[project.entry-points."xarray.chunkmanagers"] +dask = "xarray.namedarray.daskmanager:DaskManager" + +[build-system] +build-backend = "setuptools.build_meta" +requires = [ + "setuptools>=42", + "setuptools-scm>=7", +] + +[tool.setuptools] +packages = ["xarray"] + +[tool.setuptools_scm] +fallback_version = "9999" + +[tool.coverage.run] +omit = [ + "*/xarray/tests/*", + "*/xarray/core/dask_array_compat.py", + "*/xarray/core/npcompat.py", + "*/xarray/core/pdcompat.py", + "*/xarray/core/pycompat.py", + "*/xarray/core/types.py", +] +source = ["xarray"] + +[tool.coverage.report] +exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] + +[tool.mypy] +enable_error_code = "redundant-self" +exclude = [ + 'xarray/util/generate_.*\.py', + 'xarray/datatree_/doc/.*\.py', +] +files = "xarray" +show_error_codes = true +show_error_context = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true + +# Much of the numerical computing stack doesn't have type annotations yet. +[[tool.mypy.overrides]] +ignore_missing_imports = true +module = [ + "affine.*", + "bottleneck.*", + "cartopy.*", + "cf_units.*", + "cfgrib.*", + "cftime.*", + "cloudpickle.*", + "cubed.*", + "cupy.*", + "dask.types.*", + "fsspec.*", + "h5netcdf.*", + "h5py.*", + "iris.*", + "matplotlib.*", + "mpl_toolkits.*", + "nc_time_axis.*", + "numbagg.*", + "netCDF4.*", + "netcdftime.*", + "opt_einsum.*", + "pandas.*", + "pint.*", + "pooch.*", + "pyarrow.*", + "pydap.*", + "pytest.*", + "scipy.*", + "seaborn.*", + "setuptools", + "sparse.*", + "toolz.*", + "zarr.*", + "numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped + "array_api_strict.*", +] + +# Gradually we want to add more modules to this list, ratcheting up our total +# coverage. Once a module is here, functions are checked by mypy regardless of +# whether they have type annotations. It would be especially useful to have test +# files listed here, because without them being checked, we don't have a great +# way of testing our annotations. +[[tool.mypy.overrides]] +check_untyped_defs = true +module = [ + "xarray.core.accessor_dt", + "xarray.core.accessor_str", + "xarray.core.alignment", + "xarray.core.computation", + "xarray.core.rolling_exp", + "xarray.indexes.*", + "xarray.tests.*", +] +# This then excludes some modules from the above list. (So ideally we remove +# from here in time...) +[[tool.mypy.overrides]] +check_untyped_defs = false +module = [ + "xarray.tests.test_coarsen", + "xarray.tests.test_coding_times", + "xarray.tests.test_combine", + "xarray.tests.test_computation", + "xarray.tests.test_concat", + "xarray.tests.test_coordinates", + "xarray.tests.test_dask", + "xarray.tests.test_dataarray", + "xarray.tests.test_duck_array_ops", + "xarray.tests.test_indexing", + "xarray.tests.test_merge", + "xarray.tests.test_missing", + "xarray.tests.test_parallelcompat", + "xarray.tests.test_sparse", + "xarray.tests.test_ufuncs", + "xarray.tests.test_units", + "xarray.tests.test_utils", + "xarray.tests.test_variable", + "xarray.tests.test_weighted", +] + +# Use strict = true whenever namedarray has become standalone. In the meantime +# don't forget to add all new files related to namedarray here: +# ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +[[tool.mypy.overrides]] +# Start off with these +warn_unused_ignores = true + +# Getting these passing should be easy +strict_concatenate = true +strict_equality = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_any_generics = true +disallow_subclassing_any = true +disallow_untyped_decorators = true + +# These next few are various gradations of forcing use of type annotations +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] + +[tool.pyright] +# include = ["src"] +# exclude = ["**/node_modules", +# "**/__pycache__", +# "src/experimental", +# "src/typestubs" +# ] +# ignore = ["src/oldstuff"] +defineConstant = {DEBUG = true} +# stubPath = "src/stubs" +# venv = "env367" + +# Enabling this means that developers who have disabled the warning locally — +# because not all dependencies are installable — are overridden +# reportMissingImports = true +reportMissingTypeStubs = false + +# pythonVersion = "3.6" +# pythonPlatform = "Linux" + +# executionEnvironments = [ +# { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, +# { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, +# { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, +# { root = "src" } +# ] + +[tool.ruff] +builtins = ["ellipsis"] +extend-exclude = [ + "doc", + "_typed_ops.pyi", +] +target-version = "py39" + +[tool.ruff.lint] +# E402: module level import not at top of file +# E501: line too long - let black worry about that +# E731: do not assign a lambda expression, use a def +extend-safe-fixes = [ + "TID252", # absolute imports +] +ignore = [ + "E402", + "E501", + "E731", +] +select = [ + "F", # Pyflakes + "E", # Pycodestyle + "W", + "TID", # flake8-tidy-imports (absolute imports) + "I", # isort + "UP", # Pyupgrade +] + +[tool.ruff.lint.per-file-ignores] +# don't enforce absolute imports +"asv_bench/**" = ["TID252"] + +[tool.ruff.lint.isort] +known-first-party = ["xarray"] + +[tool.ruff.lint.flake8-tidy-imports] +# Disallow all relative imports. +ban-relative-imports = "all" + +[tool.pytest.ini_options] +addopts = ["--strict-config", "--strict-markers"] + +# We want to forbid warnings from within xarray in our tests — instead we should +# fix our own code, or mark the test itself as expecting a warning. So this: +# - Converts any warning from xarray into an error +# - Allows some warnings ("default") which the test suite currently raises, +# since it wasn't practical to fix them all before merging this config. The +# warnings are reported in CI (since it uses `default`, not `ignore`). +# +# Over time, we can remove these rules allowing warnings. A valued contribution +# is removing a line, seeing what breaks, and then fixing the library code or +# tests so that it doesn't raise warnings. +# +# There are some instance where we'll want to add to these rules: +# - While we only raise errors on warnings from within xarray, a dependency can +# raise a warning with a stacklevel such that it's interpreted to be raised +# from xarray and this will mistakenly convert it to an error. If that +# happens, please feel free to add a rule switching it to `default` here, and +# disabling the error. +# - If these settings get in the way of making progress, it's also acceptable to +# temporarily add additional `default` rules. +# - But we should only add `ignore` rules if we're confident that we'll never +# need to address a warning. + +filterwarnings = [ + "error:::xarray.*", + "default:No index created:UserWarning:xarray.core.dataset", + "default::UserWarning:xarray.tests.test_coding_times", + "default::UserWarning:xarray.tests.test_computation", + "default::UserWarning:xarray.tests.test_dataset", + "default:`ancestors` has been deprecated:DeprecationWarning:xarray.core.treenode", + "default:`iter_lineage` has been deprecated:DeprecationWarning:xarray.core.treenode", + "default:`lineage` has been deprecated:DeprecationWarning:xarray.core.treenode", + "default:coords should be an ndarray:DeprecationWarning:xarray.tests.test_variable", + "default:deallocating CachingFileManager:RuntimeWarning:xarray.backends.*", + "default:deallocating CachingFileManager:RuntimeWarning:xarray.backends.netCDF4_", + "default:deallocating CachingFileManager:RuntimeWarning:xarray.core.indexing", + "default:Failed to decode variable.*NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays:DeprecationWarning", + "default:The `interpolation` argument to quantile was renamed to `method`:FutureWarning:xarray.*", + "default:invalid value encountered in cast:RuntimeWarning:xarray.core.duck_array_ops", + "default:invalid value encountered in cast:RuntimeWarning:xarray.conventions", + "default:invalid value encountered in cast:RuntimeWarning:xarray.tests.test_units", + "default:invalid value encountered in cast:RuntimeWarning:xarray.tests.test_array_api", + "default:NumPy will stop allowing conversion of:DeprecationWarning", + "default:shape should be provided:DeprecationWarning:xarray.tests.test_variable", + "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable", + "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", + "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", + "default:::xarray.tests.test_strategies", + # TODO: remove once we know how to deal with a changed signature in protocols + "ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed.", +] + +log_cli_level = "INFO" +markers = [ + "flaky: flaky tests", + "network: tests requiring a network connection", + "slow: slow tests", + "slow_hypothesis: slow hypothesis tests", +] +minversion = "7" +python_files = "test_*.py" +testpaths = ["xarray/tests", "properties"] + +[tool.aliases] +test = "pytest" + +[tool.repo-review] +ignore = [ + "PP308", # This option creates a large amount of log lines. +] diff --git a/test/fixtures/whole_applications/xarray/setup.py b/test/fixtures/whole_applications/xarray/setup.py new file mode 100755 index 0000000..6934351 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/setup.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +from setuptools import setup + +setup(use_scm_version={"fallback_version": "9999"}) diff --git a/test/fixtures/whole_applications/xarray/xarray/__init__.py b/test/fixtures/whole_applications/xarray/xarray/__init__.py new file mode 100644 index 0000000..0c0d599 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/__init__.py @@ -0,0 +1,114 @@ +from importlib.metadata import version as _version + +from xarray import testing, tutorial +from xarray.backends.api import ( + load_dataarray, + load_dataset, + open_dataarray, + open_dataset, + open_mfdataset, + save_mfdataset, +) +from xarray.backends.zarr import open_zarr +from xarray.coding.cftime_offsets import cftime_range, date_range, date_range_like +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.coding.frequencies import infer_freq +from xarray.conventions import SerializationWarning, decode_cf +from xarray.core.alignment import align, broadcast +from xarray.core.combine import combine_by_coords, combine_nested +from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like +from xarray.core.computation import ( + apply_ufunc, + corr, + cov, + cross, + dot, + polyval, + unify_chunks, + where, +) +from xarray.core.concat import concat +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.extensions import ( + register_dataarray_accessor, + register_dataset_accessor, +) +from xarray.core.indexes import Index +from xarray.core.indexing import IndexSelResult +from xarray.core.merge import Context, MergeError, merge +from xarray.core.options import get_options, set_options +from xarray.core.parallel import map_blocks +from xarray.core.variable import IndexVariable, Variable, as_variable +from xarray.namedarray.core import NamedArray +from xarray.util.print_versions import show_versions + +try: + __version__ = _version("xarray") +except Exception: + # Local copy or not installed with setuptools. + # Disable minimum version checks on downstream libraries. + __version__ = "9999" + +# A hardcoded __all__ variable is necessary to appease +# `mypy --strict` running in projects that import xarray. +__all__ = ( + # Sub-packages + "testing", + "tutorial", + # Top-level functions + "align", + "apply_ufunc", + "as_variable", + "broadcast", + "cftime_range", + "combine_by_coords", + "combine_nested", + "concat", + "date_range", + "date_range_like", + "decode_cf", + "dot", + "cov", + "corr", + "cross", + "full_like", + "get_options", + "infer_freq", + "load_dataarray", + "load_dataset", + "map_blocks", + "merge", + "ones_like", + "open_dataarray", + "open_dataset", + "open_mfdataset", + "open_zarr", + "polyval", + "register_dataarray_accessor", + "register_dataset_accessor", + "save_mfdataset", + "set_options", + "show_versions", + "unify_chunks", + "where", + "zeros_like", + # Classes + "CFTimeIndex", + "Context", + "Coordinates", + "DataArray", + "Dataset", + "Index", + "IndexSelResult", + "IndexVariable", + "Variable", + "NamedArray", + # Exceptions + "MergeError", + "SerializationWarning", + # Constants + "__version__", + "ALL_DIMS", +) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/__init__.py b/test/fixtures/whole_applications/xarray/xarray/backends/__init__.py new file mode 100644 index 0000000..550b9e2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/__init__.py @@ -0,0 +1,43 @@ +"""Backend objects for saving and loading data + +DataStores provide a uniform interface for saving and loading data in different +formats. They should not be used directly, but rather through Dataset objects. +""" + +from xarray.backends.common import AbstractDataStore, BackendArray, BackendEntrypoint +from xarray.backends.file_manager import ( + CachingFileManager, + DummyFileManager, + FileManager, +) +from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint, H5NetCDFStore +from xarray.backends.memory import InMemoryDataStore +from xarray.backends.netCDF4_ import NetCDF4BackendEntrypoint, NetCDF4DataStore +from xarray.backends.plugins import list_engines, refresh_engines +from xarray.backends.pydap_ import PydapBackendEntrypoint, PydapDataStore +from xarray.backends.scipy_ import ScipyBackendEntrypoint, ScipyDataStore +from xarray.backends.store import StoreBackendEntrypoint +from xarray.backends.zarr import ZarrBackendEntrypoint, ZarrStore + +__all__ = [ + "AbstractDataStore", + "BackendArray", + "BackendEntrypoint", + "FileManager", + "CachingFileManager", + "DummyFileManager", + "InMemoryDataStore", + "NetCDF4DataStore", + "PydapDataStore", + "ScipyDataStore", + "H5NetCDFStore", + "ZarrStore", + "H5netcdfBackendEntrypoint", + "NetCDF4BackendEntrypoint", + "PydapBackendEntrypoint", + "ScipyBackendEntrypoint", + "StoreBackendEntrypoint", + "ZarrBackendEntrypoint", + "list_engines", + "refresh_engines", +] diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/api.py b/test/fixtures/whole_applications/xarray/xarray/backends/api.py new file mode 100644 index 0000000..4b7f105 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/api.py @@ -0,0 +1,1709 @@ +from __future__ import annotations + +import os +from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence +from functools import partial +from io import BytesIO +from numbers import Number +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Literal, + Union, + cast, + overload, +) + +import numpy as np + +from xarray import backends, conventions +from xarray.backends import plugins +from xarray.backends.common import ( + AbstractDataStore, + ArrayWriter, + _find_absolute_paths, + _normalize_path, +) +from xarray.backends.locks import _get_scheduler +from xarray.core import indexing +from xarray.core.combine import ( + _infer_concat_order_from_positions, + _nested_combine, + combine_by_coords, +) +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk +from xarray.core.indexes import Index +from xarray.core.types import NetcdfWriteModes, ZarrWriteModes +from xarray.core.utils import is_remote_uri +from xarray.namedarray.daskmanager import DaskManager +from xarray.namedarray.parallelcompat import guess_chunkmanager + +if TYPE_CHECKING: + try: + from dask.delayed import Delayed + except ImportError: + Delayed = None # type: ignore + from io import BufferedIOBase + + from xarray.backends.common import BackendEntrypoint + from xarray.core.types import ( + CombineAttrsOptions, + CompatOptions, + JoinOptions, + NestedSequence, + T_Chunks, + ) + + T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] + T_Engine = Union[ + T_NetcdfEngine, + Literal["pydap", "zarr"], + type[BackendEntrypoint], + str, # no nice typing support for custom backends + None, + ] + T_NetcdfTypes = Literal[ + "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" + ] + from xarray.core.datatree import DataTree + +DATAARRAY_NAME = "__xarray_dataarray_name__" +DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" + +ENGINES = { + "netcdf4": backends.NetCDF4DataStore.open, + "scipy": backends.ScipyDataStore, + "pydap": backends.PydapDataStore.open, + "h5netcdf": backends.H5NetCDFStore.open, + "zarr": backends.ZarrStore.open_group, +} + + +def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: + engine: Literal["netcdf4", "pydap"] + try: + import netCDF4 # noqa: F401 + + engine = "netcdf4" + except ImportError: # pragma: no cover + try: + import pydap # noqa: F401 + + engine = "pydap" + except ImportError: + raise ValueError( + "netCDF4 or pydap is required for accessing " + "remote datasets via OPeNDAP" + ) + return engine + + +def _get_default_engine_gz() -> Literal["scipy"]: + try: + import scipy # noqa: F401 + + engine: Final = "scipy" + except ImportError: # pragma: no cover + raise ValueError("scipy is required for accessing .gz files") + return engine + + +def _get_default_engine_netcdf() -> Literal["netcdf4", "scipy"]: + engine: Literal["netcdf4", "scipy"] + try: + import netCDF4 # noqa: F401 + + engine = "netcdf4" + except ImportError: # pragma: no cover + try: + import scipy.io.netcdf # noqa: F401 + + engine = "scipy" + except ImportError: + raise ValueError( + "cannot read or write netCDF files without " + "netCDF4-python or scipy installed" + ) + return engine + + +def _get_default_engine(path: str, allow_remote: bool = False) -> T_NetcdfEngine: + if allow_remote and is_remote_uri(path): + return _get_default_engine_remote_uri() # type: ignore[return-value] + elif path.endswith(".gz"): + return _get_default_engine_gz() + else: + return _get_default_engine_netcdf() + + +def _validate_dataset_names(dataset: Dataset) -> None: + """DataArray.name and Dataset keys must be a string or None""" + + def check_name(name: Hashable): + if isinstance(name, str): + if not name: + raise ValueError( + f"Invalid name {name!r} for DataArray or Dataset key: " + "string must be length 1 or greater for " + "serialization to netCDF files" + ) + elif name is not None: + raise TypeError( + f"Invalid name {name!r} for DataArray or Dataset key: " + "must be either a string or None for serialization to netCDF " + "files" + ) + + for k in dataset.variables: + check_name(k) + + +def _validate_attrs(dataset, invalid_netcdf=False): + """`attrs` must have a string key and a value which is either: a number, + a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_. + + Notes + ----- + A numpy.bool_ is only allowed when using the h5netcdf engine with + `invalid_netcdf=True`. + """ + + valid_types = (str, Number, np.ndarray, np.number, list, tuple) + if invalid_netcdf: + valid_types += (np.bool_,) + + def check_attr(name, value, valid_types): + if isinstance(name, str): + if not name: + raise ValueError( + f"Invalid name for attr {name!r}: string must be " + "length 1 or greater for serialization to " + "netCDF files" + ) + else: + raise TypeError( + f"Invalid name for attr: {name!r} must be a string for " + "serialization to netCDF files" + ) + + if not isinstance(value, valid_types): + raise TypeError( + f"Invalid value for attr {name!r}: {value!r}. For serialization to " + "netCDF files, its value must be of one of the following types: " + f"{', '.join([vtype.__name__ for vtype in valid_types])}" + ) + + # Check attrs on the dataset itself + for k, v in dataset.attrs.items(): + check_attr(k, v, valid_types) + + # Check attrs on each variable within the dataset + for variable in dataset.variables.values(): + for k, v in variable.attrs.items(): + check_attr(k, v, valid_types) + + +def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders): + for d in list(decoders): + if decode_cf is False and d in open_backend_dataset_parameters: + decoders[d] = False + if decoders[d] is None: + decoders.pop(d) + return decoders + + +def _get_mtime(filename_or_obj): + # if passed an actual file path, augment the token with + # the file modification time + mtime = None + + try: + path = os.fspath(filename_or_obj) + except TypeError: + path = None + + if path and not is_remote_uri(path): + mtime = os.path.getmtime(os.path.expanduser(filename_or_obj)) + + return mtime + + +def _protect_dataset_variables_inplace(dataset, cache): + for name, variable in dataset.variables.items(): + if name not in dataset._indexes: + # no need to protect IndexVariable objects + data = indexing.CopyOnWriteArray(variable._data) + if cache: + data = indexing.MemoryCachedArray(data) + variable.data = data + + +def _finalize_store(write, store): + """Finalize this store by explicitly syncing and closing""" + del write # ensure writing is done first + store.close() + + +def _multi_file_closer(closers): + for closer in closers: + closer() + + +def load_dataset(filename_or_obj, **kwargs) -> Dataset: + """Open, load into memory, and close a Dataset from a file or file-like + object. + + This is a thin wrapper around :py:meth:`~xarray.open_dataset`. It differs + from `open_dataset` in that it loads the Dataset into memory, closes the + file, and returns the Dataset. In contrast, `open_dataset` keeps the file + handle open and lazy loads its contents. All parameters are passed directly + to `open_dataset`. See that documentation for further details. + + Returns + ------- + dataset : Dataset + The newly created Dataset. + + See Also + -------- + open_dataset + """ + if "cache" in kwargs: + raise TypeError("cache has no effect in this context") + + with open_dataset(filename_or_obj, **kwargs) as ds: + return ds.load() + + +def load_dataarray(filename_or_obj, **kwargs): + """Open, load into memory, and close a DataArray from a file or file-like + object containing a single data variable. + + This is a thin wrapper around :py:meth:`~xarray.open_dataarray`. It differs + from `open_dataarray` in that it loads the Dataset into memory, closes the + file, and returns the Dataset. In contrast, `open_dataarray` keeps the file + handle open and lazy loads its contents. All parameters are passed directly + to `open_dataarray`. See that documentation for further details. + + Returns + ------- + datarray : DataArray + The newly created DataArray. + + See Also + -------- + open_dataarray + """ + if "cache" in kwargs: + raise TypeError("cache has no effect in this context") + + with open_dataarray(filename_or_obj, **kwargs) as da: + return da.load() + + +def _chunk_ds( + backend_ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + **extra_tokens, +): + chunkmanager = guess_chunkmanager(chunked_array_type) + + # TODO refactor to move this dask-specific logic inside the DaskManager class + if isinstance(chunkmanager, DaskManager): + from dask.base import tokenize + + mtime = _get_mtime(filename_or_obj) + token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) + name_prefix = "open_dataset-" + else: + # not used + token = (None,) + name_prefix = None + + variables = {} + for name, var in backend_ds.variables.items(): + var_chunks = _get_chunk(var, chunks, chunkmanager) + variables[name] = _maybe_chunk( + name, + var, + var_chunks, + overwrite_encoded_chunks=overwrite_encoded_chunks, + name_prefix=name_prefix, + token=token, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), + ) + return backend_ds._replace(variables) + + +def _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + **extra_tokens, +): + if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: + raise ValueError( + f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}." + ) + + _protect_dataset_variables_inplace(backend_ds, cache) + if chunks is None: + ds = backend_ds + else: + ds = _chunk_ds( + backend_ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + **extra_tokens, + ) + + ds.set_close(backend_ds._close) + + # Ensure source filename always stored in dataset object + if "source" not in ds.encoding and isinstance(filename_or_obj, (str, os.PathLike)): + ds.encoding["source"] = _normalize_path(filename_or_obj) + + return ds + + +def open_dataset( + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + engine: T_Engine = None, + chunks: T_Chunks = None, + cache: bool | None = None, + decode_cf: bool | None = None, + mask_and_scale: bool | None = None, + decode_times: bool | None = None, + decode_timedelta: bool | None = None, + use_cftime: bool | None = None, + concat_characters: bool | None = None, + decode_coords: Literal["coordinates", "all"] | bool | None = None, + drop_variables: str | Iterable[str] | None = None, + inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, + backend_kwargs: dict[str, Any] | None = None, + **kwargs, +) -> Dataset: + """Open and decode a dataset from a file or file-like object. + + Parameters + ---------- + filename_or_obj : str, Path, file-like or DataStore + Strings and Path objects are interpreted as a path to a netCDF file + or an OpenDAP URL and opened with python-netCDF4, unless the filename + ends with .gz, in which case the file is gunzipped and opened with + scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like + objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ + , installed backend \ + or subclass of xarray.backends.BackendEntrypoint, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) + can also be used. + chunks : int, dict, 'auto' or None, optional + If chunks is provided, it is used to load the new dataset into dask + arrays. ``chunks=-1`` loads the dataset with dask using a single + chunk for all arrays. ``chunks={}`` loads the dataset with dask using + engine preferred chunks if exposed by the backend, otherwise with + a single chunk for all arrays. In order to reproduce the default behavior + of ``xr.open_zarr(...)`` use ``xr.open_dataset(..., engine='zarr', chunks={})``. + ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. See dask chunking for more details. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. + decode_cf : bool, optional + Whether to decode these variables, assuming they were saved according + to CF conventions. + mask_and_scale : bool, optional + If True, replace array values equal to `_FillValue` with NA and scale + values according to the formula `original_values * scale_factor + + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are + taken from variable attributes (if they exist). If the `_FillValue` or + `missing_value` attribute contains multiple values a warning will be + issued and all array values matching one of the multiple values will + be replaced by NA. This keyword may not be supported by all the backends. + decode_times : bool, optional + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + This keyword may not be supported by all the backends. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + This keyword may not be supported by all the backends. + use_cftime: bool, optional + Only relevant if encoded dates come from a standard calendar + (e.g. "gregorian", "proleptic_gregorian", "standard", or not + specified). If None (default), attempt to decode times to + ``np.datetime64[ns]`` objects; if this is not possible, decode times to + ``cftime.datetime`` objects. If True, always decode times to + ``cftime.datetime`` objects, regardless of whether or not they can be + represented using ``np.datetime64[ns]`` objects. If False, always + decode times to ``np.datetime64[ns]`` objects; if this is not possible + raise an error. This keyword may not be supported by all the backends. + concat_characters : bool, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. Dimensions will only be concatenated over (and + removed) if they have no corresponding variable and if they are only + used as the last dimension of character arrays. + This keyword may not be supported by all the backends. + decode_coords : bool or {"coordinates", "all"}, optional + Controls which variables are set as coordinate variables: + + - "coordinates" or True: Set variables referred to in the + ``'coordinates'`` attribute of the datasets or individual variables + as coordinate variables. + - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and + other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. + drop_variables: str or iterable of str, optional + A variable or list of variables to exclude from being parsed from the + dataset. This may be useful to drop variables with problems or + inconsistent values. + inline_array: bool, default: False + How to include the array in the dask task graph. + By default(``inline_array=False``) the array is included in a task by + itself, and each chunk refers to that task by its key. With + ``inline_array=True``, Dask will instead inline the array directly + in the values of the task graph. See :py:func:`dask.array.from_array`. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + backend_kwargs: dict + Additional keyword arguments passed on to the engine open function, + equivalent to `**kwargs`. + **kwargs: dict + Additional keyword arguments passed on to the engine open function. + For example: + + - 'group': path to the netCDF4 group in the given file to open given as + a str,supported by "netcdf4", "h5netcdf", "zarr". + - 'lock': resource lock to use when reading data from disk. Only + relevant when using dask or another form of parallelism. By default, + appropriate locks are chosen to safely read and write files with the + currently active dask scheduler. Supported by "netcdf4", "h5netcdf", + "scipy". + + See engine open function for kwargs accepted by each specific engine. + + Returns + ------- + dataset : Dataset + The newly created dataset. + + Notes + ----- + ``open_dataset`` opens the file with read-only access. When you modify + values of a Dataset, even one linked to files on disk, only the in-memory + copy you are manipulating in xarray is modified: the original file on disk + is never touched. + + See Also + -------- + open_mfdataset + """ + + if cache is None: + cache = chunks is None + + if backend_kwargs is not None: + kwargs.update(backend_kwargs) + + if engine is None: + engine = plugins.guess_engine(filename_or_obj) + + if from_array_kwargs is None: + from_array_kwargs = {} + + backend = plugins.get_backend(engine) + + decoders = _resolve_decoders_kwargs( + decode_cf, + open_backend_dataset_parameters=backend.open_dataset_parameters, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + concat_characters=concat_characters, + use_cftime=use_cftime, + decode_coords=decode_coords, + ) + + overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + backend_ds = backend.open_dataset( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + ds = _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + return ds + + +def open_dataarray( + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + engine: T_Engine | None = None, + chunks: T_Chunks | None = None, + cache: bool | None = None, + decode_cf: bool | None = None, + mask_and_scale: bool | None = None, + decode_times: bool | None = None, + decode_timedelta: bool | None = None, + use_cftime: bool | None = None, + concat_characters: bool | None = None, + decode_coords: Literal["coordinates", "all"] | bool | None = None, + drop_variables: str | Iterable[str] | None = None, + inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, + backend_kwargs: dict[str, Any] | None = None, + **kwargs, +) -> DataArray: + """Open an DataArray from a file or file-like object containing a single + data variable. + + This is designed to read netCDF files with only one data variable. If + multiple variables are present then a ValueError is raised. + + Parameters + ---------- + filename_or_obj : str, Path, file-like or DataStore + Strings and Path objects are interpreted as a path to a netCDF file + or an OpenDAP URL and opened with python-netCDF4, unless the filename + ends with .gz, in which case the file is gunzipped and opened with + scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like + objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ + , installed backend \ + or subclass of xarray.backends.BackendEntrypoint, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". + chunks : int, dict, 'auto' or None, optional + If chunks is provided, it is used to load the new dataset into dask + arrays. ``chunks=-1`` loads the dataset with dask using a single + chunk for all arrays. `chunks={}`` loads the dataset with dask using + engine preferred chunks if exposed by the backend, otherwise with + a single chunk for all arrays. + ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. See dask chunking for more details. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. + decode_cf : bool, optional + Whether to decode these variables, assuming they were saved according + to CF conventions. + mask_and_scale : bool, optional + If True, replace array values equal to `_FillValue` with NA and scale + values according to the formula `original_values * scale_factor + + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are + taken from variable attributes (if they exist). If the `_FillValue` or + `missing_value` attribute contains multiple values a warning will be + issued and all array values matching one of the multiple values will + be replaced by NA. This keyword may not be supported by all the backends. + decode_times : bool, optional + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + This keyword may not be supported by all the backends. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + This keyword may not be supported by all the backends. + use_cftime: bool, optional + Only relevant if encoded dates come from a standard calendar + (e.g. "gregorian", "proleptic_gregorian", "standard", or not + specified). If None (default), attempt to decode times to + ``np.datetime64[ns]`` objects; if this is not possible, decode times to + ``cftime.datetime`` objects. If True, always decode times to + ``cftime.datetime`` objects, regardless of whether or not they can be + represented using ``np.datetime64[ns]`` objects. If False, always + decode times to ``np.datetime64[ns]`` objects; if this is not possible + raise an error. This keyword may not be supported by all the backends. + concat_characters : bool, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. Dimensions will only be concatenated over (and + removed) if they have no corresponding variable and if they are only + used as the last dimension of character arrays. + This keyword may not be supported by all the backends. + decode_coords : bool or {"coordinates", "all"}, optional + Controls which variables are set as coordinate variables: + + - "coordinates" or True: Set variables referred to in the + ``'coordinates'`` attribute of the datasets or individual variables + as coordinate variables. + - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and + other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. + drop_variables: str or iterable of str, optional + A variable or list of variables to exclude from being parsed from the + dataset. This may be useful to drop variables with problems or + inconsistent values. + inline_array: bool, default: False + How to include the array in the dask task graph. + By default(``inline_array=False``) the array is included in a task by + itself, and each chunk refers to that task by its key. With + ``inline_array=True``, Dask will instead inline the array directly + in the values of the task graph. See :py:func:`dask.array.from_array`. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + backend_kwargs: dict + Additional keyword arguments passed on to the engine open function, + equivalent to `**kwargs`. + **kwargs: dict + Additional keyword arguments passed on to the engine open function. + For example: + + - 'group': path to the netCDF4 group in the given file to open given as + a str,supported by "netcdf4", "h5netcdf", "zarr". + - 'lock': resource lock to use when reading data from disk. Only + relevant when using dask or another form of parallelism. By default, + appropriate locks are chosen to safely read and write files with the + currently active dask scheduler. Supported by "netcdf4", "h5netcdf", + "scipy". + + See engine open function for kwargs accepted by each specific engine. + + Notes + ----- + This is designed to be fully compatible with `DataArray.to_netcdf`. Saving + using `DataArray.to_netcdf` and then loading with this function will + produce an identical result. + + All parameters are passed directly to `xarray.open_dataset`. See that + documentation for further details. + + See also + -------- + open_dataset + """ + + dataset = open_dataset( + filename_or_obj, + decode_cf=decode_cf, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + engine=engine, + chunks=chunks, + cache=cache, + drop_variables=drop_variables, + inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + backend_kwargs=backend_kwargs, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + **kwargs, + ) + + if len(dataset.data_vars) != 1: + raise ValueError( + "Given file dataset contains more than one data " + "variable. Please read with xarray.open_dataset and " + "then select the variable you want." + ) + else: + (data_array,) = dataset.data_vars.values() + + data_array.set_close(dataset._close) + + # Reset names if they were changed during saving + # to ensure that we can 'roundtrip' perfectly + if DATAARRAY_NAME in dataset.attrs: + data_array.name = dataset.attrs[DATAARRAY_NAME] + del dataset.attrs[DATAARRAY_NAME] + + if data_array.name == DATAARRAY_VARIABLE: + data_array.name = None + + return data_array + + +def open_datatree( + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + engine: T_Engine = None, + **kwargs, +) -> DataTree: + """ + Open and decode a DataTree from a file or file-like object, creating one tree node for each group in the file. + + Parameters + ---------- + filename_or_obj : str, Path, file-like, or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. + engine : str, optional + Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`. + **kwargs : dict + Additional keyword arguments passed to :py:func:`~xarray.open_dataset` for each group. + Returns + ------- + xarray.DataTree + """ + if engine is None: + engine = plugins.guess_engine(filename_or_obj) + + backend = plugins.get_backend(engine) + + return backend.open_datatree(filename_or_obj, **kwargs) + + +def open_mfdataset( + paths: str | NestedSequence[str | os.PathLike], + chunks: T_Chunks | None = None, + concat_dim: ( + str + | DataArray + | Index + | Sequence[str] + | Sequence[DataArray] + | Sequence[Index] + | None + ) = None, + compat: CompatOptions = "no_conflicts", + preprocess: Callable[[Dataset], Dataset] | None = None, + engine: T_Engine | None = None, + data_vars: Literal["all", "minimal", "different"] | list[str] = "all", + coords="different", + combine: Literal["by_coords", "nested"] = "by_coords", + parallel: bool = False, + join: JoinOptions = "outer", + attrs_file: str | os.PathLike | None = None, + combine_attrs: CombineAttrsOptions = "override", + **kwargs, +) -> Dataset: + """Open multiple files as a single dataset. + + If combine='by_coords' then the function ``combine_by_coords`` is used to combine + the datasets into one before returning the result, and if combine='nested' then + ``combine_nested`` is used. The filepaths must be structured according to which + combining function is used, the details of which are given in the documentation for + ``combine_by_coords`` and ``combine_nested``. By default ``combine='by_coords'`` + will be used. Requires dask to be installed. See documentation for + details on dask [1]_. Global attributes from the ``attrs_file`` are used + for the combined dataset. + + Parameters + ---------- + paths : str or nested sequence of paths + Either a string glob in the form ``"path/to/my/files/*.nc"`` or an explicit list of + files to open. Paths can be given as strings or as pathlib Paths. If + concatenation along more than one dimension is desired, then ``paths`` must be a + nested list-of-lists (see ``combine_nested`` for details). (A string glob will + be expanded to a 1-dimensional list.) + chunks : int, dict, 'auto' or None, optional + Dictionary with keys given by dimension names and values given by chunk sizes. + In general, these should divide the dimensions of each dataset. If int, chunk + each dimension by ``chunks``. By default, chunks will be chosen to load entire + input files into memory at once. This has a major impact on performance: please + see the full documentation for more details [2]_. + concat_dim : str, DataArray, Index or a Sequence of these or None, optional + Dimensions to concatenate files along. You only need to provide this argument + if ``combine='nested'``, and if any of the dimensions along which you want to + concatenate is not a dimension in the original datasets, e.g., if you want to + stack a collection of 2D arrays along a third dimension. Set + ``concat_dim=[..., None, ...]`` explicitly to disable concatenation along a + particular dimension. Default is None, which for a 1D list of filepaths is + equivalent to opening the files separately and then merging them with + ``xarray.merge``. + combine : {"by_coords", "nested"}, optional + Whether ``xarray.combine_by_coords`` or ``xarray.combine_nested`` is used to + combine all the data. Default is to use ``xarray.combine_by_coords``. + compat : {"identical", "equals", "broadcast_equals", \ + "no_conflicts", "override"}, default: "no_conflicts" + String indicating how to compare variables of the same name for + potential conflicts when merging: + + * "broadcast_equals": all values must be equal when variables are + broadcast against each other to ensure common dimensions. + * "equals": all values and dimensions must be the same. + * "identical": all values, dimensions and attributes must be the + same. + * "no_conflicts": only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. + * "override": skip comparing and pick variable from first dataset + + preprocess : callable, optional + If provided, call this function on each dataset prior to concatenation. + You can find the file-name from which each dataset was loaded in + ``ds.encoding["source"]``. + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ + , installed backend \ + or subclass of xarray.backends.BackendEntrypoint, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". + data_vars : {"minimal", "different", "all"} or list of str, default: "all" + These data variables will be concatenated together: + * "minimal": Only data variables in which the dimension already + appears are included. + * "different": Data variables which are not equal (ignoring + attributes) across all datasets are also concatenated (as well as + all for which dimension already appears). Beware: this option may + load the data payload of data variables into memory if they are not + already loaded. + * "all": All data variables will be concatenated. + * list of str: The listed data variables will be concatenated, in + addition to the "minimal" data variables. + coords : {"minimal", "different", "all"} or list of str, optional + These coordinate variables will be concatenated together: + * "minimal": Only coordinates in which the dimension already appears + are included. + * "different": Coordinates which are not equal (ignoring attributes) + across all datasets are also concatenated (as well as all for which + dimension already appears). Beware: this option may load the data + payload of coordinate variables into memory if they are not already + loaded. + * "all": All coordinate variables will be concatenated, except + those corresponding to other dimensions. + * list of str: The listed coordinate variables will be concatenated, + in addition the "minimal" coordinates. + parallel : bool, default: False + If True, the open and preprocess steps of this function will be + performed in parallel using ``dask.delayed``. Default is False. + join : {"outer", "inner", "left", "right", "exact", "override"}, default: "outer" + String indicating how to combine differing indexes + (excluding concat_dim) in objects + + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be + aligned are not equal + - "override": if indexes are of same size, rewrite indexes to be + those of the first object with that dimension. Indexes for the same + dimension must have the same size in all objects. + attrs_file : str or path-like, optional + Path of the file used to read global attributes from. + By default global attributes are read from the first file provided, + with wildcard matches sorted by filename. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + **kwargs : optional + Additional arguments passed on to :py:func:`xarray.open_dataset`. For an + overview of some of the possible options, see the documentation of + :py:func:`xarray.open_dataset` + + Returns + ------- + xarray.Dataset + + Notes + ----- + ``open_mfdataset`` opens files with read-only access. When you modify values + of a Dataset, even one linked to files on disk, only the in-memory copy you + are manipulating in xarray is modified: the original file on disk is never + touched. + + See Also + -------- + combine_by_coords + combine_nested + open_dataset + + Examples + -------- + A user might want to pass additional arguments into ``preprocess`` when + applying some operation to many individual files that are being opened. One route + to do this is through the use of ``functools.partial``. + + >>> from functools import partial + >>> def _preprocess(x, lon_bnds, lat_bnds): + ... return x.sel(lon=slice(*lon_bnds), lat=slice(*lat_bnds)) + ... + >>> lon_bnds, lat_bnds = (-110, -105), (40, 45) + >>> partial_func = partial(_preprocess, lon_bnds=lon_bnds, lat_bnds=lat_bnds) + >>> ds = xr.open_mfdataset( + ... "file_*.nc", concat_dim="time", preprocess=partial_func + ... ) # doctest: +SKIP + + It is also possible to use any argument to ``open_dataset`` together + with ``open_mfdataset``, such as for example ``drop_variables``: + + >>> ds = xr.open_mfdataset( + ... "file.nc", drop_variables=["varname_1", "varname_2"] # any list of vars + ... ) # doctest: +SKIP + + References + ---------- + + .. [1] https://docs.xarray.dev/en/stable/dask.html + .. [2] https://docs.xarray.dev/en/stable/dask.html#chunking-and-performance + """ + paths = _find_absolute_paths(paths, engine=engine, **kwargs) + + if not paths: + raise OSError("no files to open") + + if combine == "nested": + if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: + concat_dim = [concat_dim] # type: ignore[assignment] + + # This creates a flat list which is easier to iterate over, whilst + # encoding the originally-supplied structure as "ids". + # The "ids" are not used at all if combine='by_coords`. + combined_ids_paths = _infer_concat_order_from_positions(paths) + ids, paths = ( + list(combined_ids_paths.keys()), + list(combined_ids_paths.values()), + ) + elif combine == "by_coords" and concat_dim is not None: + raise ValueError( + "When combine='by_coords', passing a value for `concat_dim` has no " + "effect. To manually combine along a specific dimension you should " + "instead specify combine='nested' along with a value for `concat_dim`.", + ) + + open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs) + + if parallel: + import dask + + # wrap the open_dataset, getattr, and preprocess with delayed + open_ = dask.delayed(open_dataset) + getattr_ = dask.delayed(getattr) + if preprocess is not None: + preprocess = dask.delayed(preprocess) + else: + open_ = open_dataset + getattr_ = getattr + + datasets = [open_(p, **open_kwargs) for p in paths] + closers = [getattr_(ds, "_close") for ds in datasets] + if preprocess is not None: + datasets = [preprocess(ds) for ds in datasets] + + if parallel: + # calling compute here will return the datasets/file_objs lists, + # the underlying datasets will still be stored as dask arrays + datasets, closers = dask.compute(datasets, closers) + + # Combine all datasets, closing them in case of a ValueError + try: + if combine == "nested": + # Combined nested list by successive concat and merge operations + # along each dimension, using structure given by "ids" + combined = _nested_combine( + datasets, + concat_dims=concat_dim, + compat=compat, + data_vars=data_vars, + coords=coords, + ids=ids, + join=join, + combine_attrs=combine_attrs, + ) + elif combine == "by_coords": + # Redo ordering from coordinates, ignoring how they were ordered + # previously + combined = combine_by_coords( + datasets, + compat=compat, + data_vars=data_vars, + coords=coords, + join=join, + combine_attrs=combine_attrs, + ) + else: + raise ValueError( + f"{combine} is an invalid option for the keyword argument" + " ``combine``" + ) + except ValueError: + for ds in datasets: + ds.close() + raise + + combined.set_close(partial(_multi_file_closer, closers)) + + # read global attributes from the attrs_file or from the first dataset + if attrs_file is not None: + if isinstance(attrs_file, os.PathLike): + attrs_file = cast(str, os.fspath(attrs_file)) + combined.attrs = datasets[paths.index(attrs_file)].attrs + + return combined + + +WRITEABLE_STORES: dict[T_NetcdfEngine, Callable] = { + "netcdf4": backends.NetCDF4DataStore.open, + "scipy": backends.ScipyDataStore, + "h5netcdf": backends.H5NetCDFStore.open, +} + + +# multifile=True returns writer and datastore +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + *, + multifile: Literal[True], + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore]: ... + + +# path=None writes to bytes +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> bytes: ... + + +# compute=False returns dask.Delayed +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> Delayed: ... + + +# default return None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: Literal[True] = True, + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> None: ... + + +# if compute cannot be evaluated at type check time +# we may get back either Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> Delayed | None: ... + + +# if multifile cannot be evaluated at type check time +# we may get back either writer and datastore or Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None: ... + + +# Any +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: ... + + +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: + """This function creates an appropriate datastore for writing a dataset to + disk as a netCDF file + + See `Dataset.to_netcdf` for full API docs. + + The ``multifile`` argument is only for the private use of save_mfdataset. + """ + if isinstance(path_or_file, os.PathLike): + path_or_file = os.fspath(path_or_file) + + if encoding is None: + encoding = {} + + if path_or_file is None: + if engine is None: + engine = "scipy" + elif engine != "scipy": + raise ValueError( + "invalid engine for creating bytes with " + f"to_netcdf: {engine!r}. Only the default engine " + "or engine='scipy' is supported" + ) + if not compute: + raise NotImplementedError( + "to_netcdf() with compute=False is not yet implemented when " + "returning bytes" + ) + elif isinstance(path_or_file, str): + if engine is None: + engine = _get_default_engine(path_or_file) + path_or_file = _normalize_path(path_or_file) + else: # file-like object + engine = "scipy" + + # validate Dataset keys, DataArray names, and attr keys/values + _validate_dataset_names(dataset) + _validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf") + + try: + store_open = WRITEABLE_STORES[engine] + except KeyError: + raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") + + if format is not None: + format = format.upper() # type: ignore[assignment] + + # handle scheduler specific logic + scheduler = _get_scheduler() + have_chunks = any(v.chunks is not None for v in dataset.variables.values()) + + autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] + if autoclose and engine == "scipy": + raise NotImplementedError( + f"Writing netCDF files with the {engine} backend " + f"is not currently supported with dask's {scheduler} scheduler" + ) + + target = path_or_file if path_or_file is not None else BytesIO() + kwargs = dict(autoclose=True) if autoclose else {} + if invalid_netcdf: + if engine == "h5netcdf": + kwargs["invalid_netcdf"] = invalid_netcdf + else: + raise ValueError( + f"unrecognized option 'invalid_netcdf' for engine {engine}" + ) + store = store_open(target, mode, format, group, **kwargs) + + if unlimited_dims is None: + unlimited_dims = dataset.encoding.get("unlimited_dims", None) + if unlimited_dims is not None: + if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable): + unlimited_dims = [unlimited_dims] + else: + unlimited_dims = list(unlimited_dims) + + writer = ArrayWriter() + + # TODO: figure out how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals + try: + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store( + dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims + ) + if autoclose: + store.close() + + if multifile: + return writer, store + + writes = writer.sync(compute=compute) + + if isinstance(target, BytesIO): + store.sync() + return target.getvalue() + finally: + if not multifile and compute: + store.close() + + if not compute: + import dask + + return dask.delayed(_finalize_store)(writes, store) + return None + + +def dump_to_store( + dataset, store, writer=None, encoder=None, encoding=None, unlimited_dims=None +): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims) + + +def save_mfdataset( + datasets, + paths, + mode="w", + format=None, + groups=None, + engine=None, + compute=True, + **kwargs, +): + """Write multiple datasets to disk as netCDF files simultaneously. + + This function is intended for use with datasets consisting of dask.array + objects, in which case it can write the multiple datasets to disk + simultaneously using a shared thread pool. + + When not using dask, it is no different than calling ``to_netcdf`` + repeatedly. + + Parameters + ---------- + datasets : list of Dataset + List of datasets to save. + paths : list of str or list of path-like objects + List of paths to which to save each corresponding dataset. + mode : {"w", "a"}, optional + Write ("w") or append ("a") mode. If mode="w", any existing file at + these locations will be overwritten. + format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_CLASSIC"}, optional + File format for the resulting netCDF file: + + * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API + features. + * NETCDF4_CLASSIC: Data is stored in an HDF5 file, using only + netCDF 3 compatible API features. + * NETCDF3_64BIT: 64-bit offset version of the netCDF 3 file format, + which fully supports 2+ GB files, but is only compatible with + clients linked against netCDF version 3.6.0 or later. + * NETCDF3_CLASSIC: The classic netCDF 3 file format. It does not + handle 2+ GB files very well. + + All formats are supported by the netCDF4-python library. + scipy.io.netcdf only supports the last two formats. + + The default format is NETCDF4 if you are saving a file to disk and + have the netCDF4-python library available. Otherwise, xarray falls + back to using scipy to write netCDF files and defaults to the + NETCDF3_64BIT format (scipy does not support netCDF4). + groups : list of str, optional + Paths to the netCDF4 group in each corresponding file to which to save + datasets (only works for format="NETCDF4"). The groups will be created + if necessary. + engine : {"netcdf4", "scipy", "h5netcdf"}, optional + Engine to use when writing netCDF files. If not provided, the + default engine is chosen based on available dependencies, with a + preference for "netcdf4" if writing to a file on disk. + See `Dataset.to_netcdf` for additional information. + compute : bool + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. + **kwargs : dict, optional + Additional arguments are passed along to ``to_netcdf``. + + Examples + -------- + Save a dataset into one netCDF per year of data: + + >>> ds = xr.Dataset( + ... {"a": ("time", np.linspace(0, 1, 48))}, + ... coords={"time": pd.date_range("2010-01-01", freq="ME", periods=48)}, + ... ) + >>> ds + Size: 768B + Dimensions: (time: 48) + Coordinates: + * time (time) datetime64[ns] 384B 2010-01-31 2010-02-28 ... 2013-12-31 + Data variables: + a (time) float64 384B 0.0 0.02128 0.04255 ... 0.9574 0.9787 1.0 + >>> years, datasets = zip(*ds.groupby("time.year")) + >>> paths = [f"{y}.nc" for y in years] + >>> xr.save_mfdataset(datasets, paths) + """ + if mode == "w" and len(set(paths)) < len(paths): + raise ValueError( + "cannot use mode='w' when writing multiple datasets to the same path" + ) + + for obj in datasets: + if not isinstance(obj, Dataset): + raise TypeError( + "save_mfdataset only supports writing Dataset " + f"objects, received type {type(obj)}" + ) + + if groups is None: + groups = [None] * len(datasets) + + if len({len(datasets), len(paths), len(groups)}) > 1: + raise ValueError( + "must supply lists of the same length for the " + "datasets, paths and groups arguments to " + "save_mfdataset" + ) + + writers, stores = zip( + *[ + to_netcdf( + ds, + path, + mode, + format, + group, + engine, + compute=compute, + multifile=True, + **kwargs, + ) + for ds, path, group in zip(datasets, paths, groups) + ] + ) + + try: + writes = [w.sync(compute=compute) for w in writers] + finally: + if compute: + for store in stores: + store.close() + + if not compute: + import dask + + return dask.delayed( + [dask.delayed(_finalize_store)(w, s) for w, s in zip(writes, stores)] + ) + + +# compute=True returns ZarrStore +@overload +def to_zarr( + dataset: Dataset, + store: MutableMapping | str | os.PathLike[str] | None = None, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[True] = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> backends.ZarrStore: ... + + +# compute=False returns dask.Delayed +@overload +def to_zarr( + dataset: Dataset, + store: MutableMapping | str | os.PathLike[str] | None = None, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[False], + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> Delayed: ... + + +def to_zarr( + dataset: Dataset, + store: MutableMapping | str | os.PathLike[str] | None = None, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: bool = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> backends.ZarrStore | Delayed: + """This function creates an appropriate datastore for writing a dataset to + a zarr ztore + + See `Dataset.to_zarr` for full API docs. + """ + + # Load empty arrays to avoid bug saving zero length dimensions (Issue #5741) + for v in dataset.variables.values(): + if v.size == 0: + v.load() + + # expand str and path-like arguments + store = _normalize_path(store) + chunk_store = _normalize_path(chunk_store) + + if storage_options is None: + mapper = store + chunk_mapper = chunk_store + else: + from fsspec import get_mapper + + if not isinstance(store, str): + raise ValueError( + f"store must be a string to use storage_options. Got {type(store)}" + ) + mapper = get_mapper(store, **storage_options) + if chunk_store is not None: + chunk_mapper = get_mapper(chunk_store, **storage_options) + else: + chunk_mapper = chunk_store + + if encoding is None: + encoding = {} + + if mode is None: + if append_dim is not None: + mode = "a" + elif region is not None: + mode = "r+" + else: + mode = "w-" + + if mode not in ["a", "a-"] and append_dim is not None: + raise ValueError("cannot set append_dim unless mode='a' or mode=None") + + if mode not in ["a", "a-", "r+"] and region is not None: + raise ValueError( + "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None" + ) + + if mode not in ["w", "w-", "a", "a-", "r+"]: + raise ValueError( + "The only supported options for mode are 'w', " + f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}" + ) + + # validate Dataset keys, DataArray names + _validate_dataset_names(dataset) + + if zarr_version is None: + # default to 2 if store doesn't specify it's version (e.g. a path) + zarr_version = int(getattr(store, "_store_version", 2)) + + if consolidated is None and zarr_version > 2: + consolidated = False + + if mode == "r+": + already_consolidated = consolidated + consolidate_on_close = False + else: + already_consolidated = False + consolidate_on_close = consolidated or consolidated is None + zstore = backends.ZarrStore.open_group( + store=mapper, + mode=mode, + synchronizer=synchronizer, + group=group, + consolidated=already_consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_mapper, + append_dim=append_dim, + write_region=region, + safe_chunks=safe_chunks, + stacklevel=4, # for Dataset.to_zarr() + zarr_version=zarr_version, + write_empty=write_empty_chunks, + ) + + if region is not None: + zstore._validate_and_autodetect_region(dataset) + # can't modify indexes with region writes + dataset = dataset.drop_vars(dataset.indexes) + if append_dim is not None and append_dim in region: + raise ValueError( + f"cannot list the same dimension in both ``append_dim`` and " + f"``region`` with to_zarr(), got {append_dim} in both" + ) + + if encoding and mode in ["a", "a-", "r+"]: + existing_var_names = set(zstore.zarr_group.array_keys()) + for var_name in existing_var_names: + if var_name in encoding: + raise ValueError( + f"variable {var_name!r} already exists, but encoding was provided" + ) + + writer = ArrayWriter() + # TODO: figure out how to properly handle unlimited_dims + dump_to_store(dataset, zstore, writer, encoding=encoding) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs + ) + + if compute: + _finalize_store(writes, zstore) + else: + import dask + + return dask.delayed(_finalize_store)(writes, zstore) + + return zstore diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/common.py b/test/fixtures/whole_applications/xarray/xarray/backends/common.py new file mode 100644 index 0000000..e9bfdd9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/common.py @@ -0,0 +1,540 @@ +from __future__ import annotations + +import logging +import os +import time +import traceback +from collections.abc import Iterable +from glob import glob +from typing import TYPE_CHECKING, Any, ClassVar + +import numpy as np + +from xarray.conventions import cf_encoder +from xarray.core import indexing +from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array + +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree + from xarray.core.types import NestedSequence + +# Create a logger object, but don't add any handlers. Leave that to user code. +logger = logging.getLogger(__name__) + + +NONE_VAR_NAME = "__values__" + + +def _normalize_path(path): + """ + Normalize pathlikes to string. + + Parameters + ---------- + path : + Path to file. + + Examples + -------- + >>> from pathlib import Path + + >>> directory = Path(xr.backends.common.__file__).parent + >>> paths_path = Path(directory).joinpath("comm*n.py") + >>> paths_str = xr.backends.common._normalize_path(paths_path) + >>> print([type(p) for p in (paths_str,)]) + [] + """ + if isinstance(path, os.PathLike): + path = os.fspath(path) + + if isinstance(path, str) and not is_remote_uri(path): + path = os.path.abspath(os.path.expanduser(path)) + + return path + + +def _find_absolute_paths( + paths: str | os.PathLike | NestedSequence[str | os.PathLike], **kwargs +) -> list[str]: + """ + Find absolute paths from the pattern. + + Parameters + ---------- + paths : + Path(s) to file(s). Can include wildcards like * . + **kwargs : + Extra kwargs. Mainly for fsspec. + + Examples + -------- + >>> from pathlib import Path + + >>> directory = Path(xr.backends.common.__file__).parent + >>> paths = str(Path(directory).joinpath("comm*n.py")) # Find common with wildcard + >>> paths = xr.backends.common._find_absolute_paths(paths) + >>> [Path(p).name for p in paths] + ['common.py'] + """ + if isinstance(paths, str): + if is_remote_uri(paths) and kwargs.get("engine", None) == "zarr": + try: + from fsspec.core import get_fs_token_paths + except ImportError as e: + raise ImportError( + "The use of remote URLs for opening zarr requires the package fsspec" + ) from e + + fs, _, _ = get_fs_token_paths( + paths, + mode="rb", + storage_options=kwargs.get("backend_kwargs", {}).get( + "storage_options", {} + ), + expand=False, + ) + tmp_paths = fs.glob(fs._strip_protocol(paths)) # finds directories + paths = [fs.get_mapper(path) for path in tmp_paths] + elif is_remote_uri(paths): + raise ValueError( + "cannot do wild-card matching for paths that are remote URLs " + f"unless engine='zarr' is specified. Got paths: {paths}. " + "Instead, supply paths as an explicit list of strings." + ) + else: + paths = sorted(glob(_normalize_path(paths))) + elif isinstance(paths, os.PathLike): + paths = [os.fspath(paths)] + else: + paths = [os.fspath(p) if isinstance(p, os.PathLike) else p for p in paths] + + return paths + + +def _encode_variable_name(name): + if name is None: + name = NONE_VAR_NAME + return name + + +def _decode_variable_name(name): + if name == NONE_VAR_NAME: + name = None + return name + + +def _iter_nc_groups(root, parent="/"): + from xarray.core.treenode import NodePath + + parent = NodePath(parent) + for path, group in root.groups.items(): + gpath = parent / path + yield str(gpath) + yield from _iter_nc_groups(group, parent=gpath) + + +def find_root_and_group(ds): + """Find the root and group name of a netCDF4/h5netcdf dataset.""" + hierarchy = () + while ds.parent is not None: + hierarchy = (ds.name.split("/")[-1],) + hierarchy + ds = ds.parent + group = "/" + "/".join(hierarchy) + return ds, group + + +def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500): + """ + Robustly index an array, using retry logic with exponential backoff if any + of the errors ``catch`` are raised. The initial_delay is measured in ms. + + With the default settings, the maximum delay will be in the range of 32-64 + seconds. + """ + assert max_retries >= 0 + for n in range(max_retries + 1): + try: + return array[key] + except catch: + if n == max_retries: + raise + base_delay = initial_delay * 2**n + next_delay = base_delay + np.random.randint(base_delay) + msg = ( + f"getitem failed, waiting {next_delay} ms before trying again " + f"({max_retries - n} tries remaining). Full traceback: {traceback.format_exc()}" + ) + logger.debug(msg) + time.sleep(1e-3 * next_delay) + + +class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): + __slots__ = () + + def get_duck_array(self, dtype: np.typing.DTypeLike = None): + key = indexing.BasicIndexer((slice(None),) * self.ndim) + return self[key] # type: ignore [index] + + +class AbstractDataStore: + __slots__ = () + + def get_dimensions(self): # pragma: no cover + raise NotImplementedError() + + def get_attrs(self): # pragma: no cover + raise NotImplementedError() + + def get_variables(self): # pragma: no cover + raise NotImplementedError() + + def get_encoding(self): + return {} + + def load(self): + """ + This loads the variables and attributes simultaneously. + A centralized loading function makes it easier to create + data stores that do automatic encoding/decoding. + + For example:: + + class SuffixAppendingDataStore(AbstractDataStore): + + def load(self): + variables, attributes = AbstractDataStore.load(self) + variables = {'%s_suffix' % k: v + for k, v in variables.items()} + attributes = {'%s_suffix' % k: v + for k, v in attributes.items()} + return variables, attributes + + This function will be called anytime variables or attributes + are requested, so care should be taken to make sure its fast. + """ + variables = FrozenDict( + (_decode_variable_name(k), v) for k, v in self.get_variables().items() + ) + attributes = FrozenDict(self.get_attrs()) + return variables, attributes + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + self.close() + + +class ArrayWriter: + __slots__ = ("sources", "targets", "regions", "lock") + + def __init__(self, lock=None): + self.sources = [] + self.targets = [] + self.regions = [] + self.lock = lock + + def add(self, source, target, region=None): + if is_chunked_array(source): + self.sources.append(source) + self.targets.append(target) + self.regions.append(region) + else: + if region: + target[region] = source + else: + target[...] = source + + def sync(self, compute=True, chunkmanager_store_kwargs=None): + if self.sources: + chunkmanager = get_chunked_array_type(*self.sources) + + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernible difference in performance, e.g., + # targets = [dask.delayed(t) for t in self.targets] + + if chunkmanager_store_kwargs is None: + chunkmanager_store_kwargs = {} + + delayed_store = chunkmanager.store( + self.sources, + self.targets, + lock=self.lock, + compute=compute, + flush=True, + regions=self.regions, + **chunkmanager_store_kwargs, + ) + self.sources = [] + self.targets = [] + self.regions = [] + return delayed_store + + +class AbstractWritableDataStore(AbstractDataStore): + __slots__ = () + + def encode(self, variables, attributes): + """ + Encode the variables and attributes in this store + + Parameters + ---------- + variables : dict-like + Dictionary of key/value (variable name / xr.Variable) pairs + attributes : dict-like + Dictionary of key/value (attribute name / attribute) pairs + + Returns + ------- + variables : dict-like + attributes : dict-like + + """ + variables = {k: self.encode_variable(v) for k, v in variables.items()} + attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} + return variables, attributes + + def encode_variable(self, v): + """encode one variable""" + return v + + def encode_attribute(self, a): + """encode one attribute""" + return a + + def set_dimension(self, dim, length): # pragma: no cover + raise NotImplementedError() + + def set_attribute(self, k, v): # pragma: no cover + raise NotImplementedError() + + def set_variable(self, k, v): # pragma: no cover + raise NotImplementedError() + + def store_dataset(self, dataset): + """ + in stores, variables are all variables AND coordinates + in xarray.Dataset variables are variables NOT coordinates, + so here we pass the whole dataset in instead of doing + dataset.variables + """ + self.store(dataset, dataset.attrs) + + def store( + self, + variables, + attributes, + check_encoding_set=frozenset(), + writer=None, + unlimited_dims=None, + ): + """ + Top level method for putting data on this store, this method: + - encodes variables/attributes + - sets dimensions + - sets variables + + Parameters + ---------- + variables : dict-like + Dictionary of key/value (variable name / xr.Variable) pairs + attributes : dict-like + Dictionary of key/value (attribute name / attribute) pairs + check_encoding_set : list-like + List of variables that should be checked for invalid encoding + values + writer : ArrayWriter + unlimited_dims : list-like + List of dimension names that should be treated as unlimited + dimensions. + """ + if writer is None: + writer = ArrayWriter() + + variables, attributes = self.encode(variables, attributes) + + self.set_attributes(attributes) + self.set_dimensions(variables, unlimited_dims=unlimited_dims) + self.set_variables( + variables, check_encoding_set, writer, unlimited_dims=unlimited_dims + ) + + def set_attributes(self, attributes): + """ + This provides a centralized method to set the dataset attributes on the + data store. + + Parameters + ---------- + attributes : dict-like + Dictionary of key/value (attribute name / attribute) pairs + """ + for k, v in attributes.items(): + self.set_attribute(k, v) + + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): + """ + This provides a centralized method to set the variables on the data + store. + + Parameters + ---------- + variables : dict-like + Dictionary of key/value (variable name / xr.Variable) pairs + check_encoding_set : list-like + List of variables that should be checked for invalid encoding + values + writer : ArrayWriter + unlimited_dims : list-like + List of dimension names that should be treated as unlimited + dimensions. + """ + + for vn, v in variables.items(): + name = _encode_variable_name(vn) + check = vn in check_encoding_set + target, source = self.prepare_variable( + name, v, check, unlimited_dims=unlimited_dims + ) + + writer.add(source, target) + + def set_dimensions(self, variables, unlimited_dims=None): + """ + This provides a centralized method to set the dimensions on the data + store. + + Parameters + ---------- + variables : dict-like + Dictionary of key/value (variable name / xr.Variable) pairs + unlimited_dims : list-like + List of dimension names that should be treated as unlimited + dimensions. + """ + if unlimited_dims is None: + unlimited_dims = set() + + existing_dims = self.get_dimensions() + + dims = {} + for v in unlimited_dims: # put unlimited_dims first + dims[v] = None + for v in variables.values(): + dims.update(dict(zip(v.dims, v.shape))) + + for dim, length in dims.items(): + if dim in existing_dims and length != existing_dims[dim]: + raise ValueError( + "Unable to update size for existing dimension" + f"{dim!r} ({length} != {existing_dims[dim]})" + ) + elif dim not in existing_dims: + is_unlimited = dim in unlimited_dims + self.set_dimension(dim, length, is_unlimited) + + +class WritableCFDataStore(AbstractWritableDataStore): + __slots__ = () + + def encode(self, variables, attributes): + # All NetCDF files get CF encoded by default, without this attempting + # to write times, for example, would fail. + variables, attributes = cf_encoder(variables, attributes) + variables = {k: self.encode_variable(v) for k, v in variables.items()} + attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} + return variables, attributes + + +class BackendEntrypoint: + """ + ``BackendEntrypoint`` is a class container and it is the main interface + for the backend plugins, see :ref:`RST backend_entrypoint`. + It shall implement: + + - ``open_dataset`` method: it shall implement reading from file, variables + decoding and it returns an instance of :py:class:`~xarray.Dataset`. + It shall take in input at least ``filename_or_obj`` argument and + ``drop_variables`` keyword argument. + For more details see :ref:`RST open_dataset`. + - ``guess_can_open`` method: it shall return ``True`` if the backend is able to open + ``filename_or_obj``, ``False`` otherwise. The implementation of this + method is not mandatory. + - ``open_datatree`` method: it shall implement reading from file, variables + decoding and it returns an instance of :py:class:`~datatree.DataTree`. + It shall take in input at least ``filename_or_obj`` argument. The + implementation of this method is not mandatory. For more details see + . + + Attributes + ---------- + + open_dataset_parameters : tuple, default: None + A list of ``open_dataset`` method parameters. + The setting of this attribute is not mandatory. + description : str, default: "" + A short string describing the engine. + The setting of this attribute is not mandatory. + url : str, default: "" + A string with the URL to the backend's documentation. + The setting of this attribute is not mandatory. + """ + + open_dataset_parameters: ClassVar[tuple | None] = None + description: ClassVar[str] = "" + url: ClassVar[str] = "" + + def __repr__(self) -> str: + txt = f"<{type(self).__name__}>" + if self.description: + txt += f"\n {self.description}" + if self.url: + txt += f"\n Learn more at {self.url}" + return txt + + def open_dataset( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + drop_variables: str | Iterable[str] | None = None, + **kwargs: Any, + ) -> Dataset: + """ + Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. + """ + + raise NotImplementedError() + + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + """ + Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. + """ + + return False + + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + **kwargs: Any, + ) -> DataTree: + """ + Backend open_datatree method used by Xarray in :py:func:`~xarray.open_datatree`. + """ + + raise NotImplementedError() + + +# mapping of engine name to (module name, BackendEntrypoint Class) +BACKEND_ENTRYPOINTS: dict[str, tuple[str | None, type[BackendEntrypoint]]] = {} diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/file_manager.py b/test/fixtures/whole_applications/xarray/xarray/backends/file_manager.py new file mode 100644 index 0000000..df901f9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/file_manager.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +import atexit +import contextlib +import io +import threading +import uuid +import warnings +from collections.abc import Hashable +from typing import Any + +from xarray.backends.locks import acquire +from xarray.backends.lru_cache import LRUCache +from xarray.core import utils +from xarray.core.options import OPTIONS + +# Global cache for storing open files. +FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache( + maxsize=OPTIONS["file_cache_maxsize"], on_evict=lambda k, v: v.close() +) +assert FILE_CACHE.maxsize, "file cache must be at least size one" + +REF_COUNTS: dict[Any, int] = {} + +_DEFAULT_MODE = utils.ReprObject("") + + +class FileManager: + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. + """ + + def acquire(self, needs_lock=True): + """Acquire the file object from this manager.""" + raise NotImplementedError() + + def acquire_context(self, needs_lock=True): + """Context manager for acquiring a file. Yields a file object. + + The context manager unwinds any actions taken as part of acquisition + (i.e., removes it from any cache) if an exception is raised from the + context. It *does not* automatically close the file. + """ + raise NotImplementedError() + + def close(self, needs_lock=True): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError() + + +class CachingFileManager(FileManager): + """Wrapper for automatically opening and closing file objects. + + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. + + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + + Example usage: + + manager = FileManager(open, 'example.txt', mode='w') + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + + """ + + def __init__( + self, + opener, + *args, + mode=_DEFAULT_MODE, + kwargs=None, + lock=None, + cache=None, + manager_id: Hashable | None = None, + ref_counts=None, + ): + """Initialize a CachingFileManager. + + The cache, manager_id and ref_counts arguments exist solely to + facilitate dependency injection, and should only be set for tests. + + Parameters + ---------- + opener : callable + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. + *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). All arguments must be + hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. ``mode='w' `` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent function to + avoid overriding the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager objects will be restored with the default + cache. + manager_id : hashable, optional + Identifier for this CachingFileManager. + ref_counts : dict, optional + Optional dict to use for keeping track the number of references to + the same file. + """ + self._opener = opener + self._args = args + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + + self._use_default_lock = lock is None or lock is False + self._lock = threading.Lock() if self._use_default_lock else lock + + # cache[self._key] stores the file associated with this object. + if cache is None: + cache = FILE_CACHE + self._cache = cache + if manager_id is None: + # Each call to CachingFileManager should separately open files. + manager_id = str(uuid.uuid4()) + self._manager_id = manager_id + self._key = self._make_key() + + # ref_counts[self._key] stores the number of CachingFileManager objects + # in memory referencing this same file. We use this to know if we can + # close a file when the manager is deallocated. + if ref_counts is None: + ref_counts = REF_COUNTS + self._ref_counter = _RefCounter(ref_counts) + self._ref_counter.increment(self._key) + + def _make_key(self): + """Make a key for caching files in the LRU cache.""" + value = ( + self._opener, + self._args, + "a" if self._mode == "w" else self._mode, + tuple(sorted(self._kwargs.items())), + self._manager_id, + ) + return _HashedSequence(value) + + @contextlib.contextmanager + def _optional_lock(self, needs_lock): + """Context manager for optionally acquiring a lock.""" + if needs_lock: + with self._lock: + yield + else: + yield + + def acquire(self, needs_lock=True): + """Acquire a file object from the manager. + + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a lock, which ensures that it is thread-safe. You can + safely acquire a file in multiple threads at the same time, as long as + the underlying file object is thread-safe. + + Returns + ------- + file-like + An open file object, as returned by ``opener(*args, **kwargs)``. + """ + file, _ = self._acquire_with_cache_info(needs_lock) + return file + + @contextlib.contextmanager + def acquire_context(self, needs_lock=True): + """Context manager for acquiring a file.""" + file, cached = self._acquire_with_cache_info(needs_lock) + try: + yield file + except Exception: + if not cached: + self.close(needs_lock) + raise + + def _acquire_with_cache_info(self, needs_lock=True): + """Acquire a file, returning the file and whether it was cached.""" + with self._optional_lock(needs_lock): + try: + file = self._cache[self._key] + except KeyError: + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs["mode"] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == "w": + # ensure file doesn't get overridden when opened again + self._mode = "a" + self._cache[self._key] = file + return file, False + else: + return file, True + + def close(self, needs_lock=True): + """Explicitly close any associated file object (if necessary).""" + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + with self._optional_lock(needs_lock): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def __del__(self) -> None: + # If we're the only CachingFileManger referencing a unclosed file, + # remove it from the cache upon garbage collection. + # + # We keep track of our own reference count because we don't want to + # close files if another identical file manager needs it. This can + # happen if a CachingFileManager is pickled and unpickled without + # closing the original file. + ref_count = self._ref_counter.decrement(self._key) + + if not ref_count and self._key in self._cache: + if acquire(self._lock, blocking=False): + # Only close files if we can do so immediately. + try: + self.close(needs_lock=False) + finally: + self._lock.release() + + if OPTIONS["warn_for_unclosed_files"]: + warnings.warn( + f"deallocating {self}, but file is not already closed. " + "This may indicate a bug.", + RuntimeWarning, + stacklevel=2, + ) + + def __getstate__(self): + """State for pickling.""" + # cache is intentionally omitted: we don't want to try to serialize + # these global objects. + lock = None if self._use_default_lock else self._lock + return ( + self._opener, + self._args, + self._mode, + self._kwargs, + lock, + self._manager_id, + ) + + def __setstate__(self, state) -> None: + """Restore from a pickle.""" + opener, args, mode, kwargs, lock, manager_id = state + self.__init__( # type: ignore + opener, *args, mode=mode, kwargs=kwargs, lock=lock, manager_id=manager_id + ) + + def __repr__(self) -> str: + args_string = ", ".join(map(repr, self._args)) + if self._mode is not _DEFAULT_MODE: + args_string += f", mode={self._mode!r}" + return ( + f"{type(self).__name__}({self._opener!r}, {args_string}, " + f"kwargs={self._kwargs}, manager_id={self._manager_id!r})" + ) + + +@atexit.register +def _remove_del_method(): + # We don't need to close unclosed files at program exit, and may not be able + # to, because Python is cleaning up imports / globals. + del CachingFileManager.__del__ + + +class _RefCounter: + """Class for keeping track of reference counts.""" + + def __init__(self, counts): + self._counts = counts + self._lock = threading.Lock() + + def increment(self, name): + with self._lock: + count = self._counts[name] = self._counts.get(name, 0) + 1 + return count + + def decrement(self, name): + with self._lock: + count = self._counts[name] - 1 + if count: + self._counts[name] = count + else: + del self._counts[name] + return count + + +class _HashedSequence(list): + """Speedup repeated look-ups by caching hash values. + + Based on what Python uses internally in functools.lru_cache. + + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ + + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) + + def __hash__(self): + return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface.""" + + def __init__(self, value): + self._value = value + + def acquire(self, needs_lock=True): + del needs_lock # ignored + return self._value + + @contextlib.contextmanager + def acquire_context(self, needs_lock=True): + del needs_lock + yield self._value + + def close(self, needs_lock=True): + del needs_lock # ignored + self._value.close() diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/h5netcdf_.py b/test/fixtures/whole_applications/xarray/xarray/backends/h5netcdf_.py new file mode 100644 index 0000000..cd6bde4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/h5netcdf_.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import functools +import io +import os +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any + +from xarray.backends.common import ( + BACKEND_ENTRYPOINTS, + BackendEntrypoint, + WritableCFDataStore, + _normalize_path, + find_root_and_group, +) +from xarray.backends.file_manager import CachingFileManager, DummyFileManager +from xarray.backends.locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock +from xarray.backends.netCDF4_ import ( + BaseNetCDF4Array, + _encode_nc4_variable, + _ensure_no_forward_slash_in_name, + _extract_nc4_variable_encoding, + _get_datatype, + _nc4_require_group, +) +from xarray.backends.store import StoreBackendEntrypoint +from xarray.core import indexing +from xarray.core.utils import ( + FrozenDict, + emit_user_level_warning, + is_remote_uri, + read_magic_number_from_file, + try_read_magic_number_from_file_or_path, +) +from xarray.core.variable import Variable + +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree + + +class H5NetCDFArrayWrapper(BaseNetCDF4Array): + def get_array(self, needs_lock=True): + ds = self.datastore._acquire(needs_lock) + return ds.variables[self.variable_name] + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + + def _getitem(self, key): + with self.datastore.lock: + array = self.get_array(needs_lock=False) + return array[key] + + +def _read_attributes(h5netcdf_var): + # GH451 + # to ensure conventions decoding works properly on Python 3, decode all + # bytes attributes to strings + attrs = {} + for k, v in h5netcdf_var.attrs.items(): + if k not in ["_FillValue", "missing_value"]: + if isinstance(v, bytes): + try: + v = v.decode("utf-8") + except UnicodeDecodeError: + emit_user_level_warning( + f"'utf-8' codec can't decode bytes for attribute " + f"{k!r} of h5netcdf object {h5netcdf_var.name!r}, " + f"returning bytes undecoded.", + UnicodeWarning, + ) + attrs[k] = v + return attrs + + +_extract_h5nc_encoding = functools.partial( + _extract_nc4_variable_encoding, + lsd_okay=False, + h5py_okay=True, + backend="h5netcdf", + unlimited_dims=None, +) + + +def _h5netcdf_create_group(dataset, name): + return dataset.create_group(name) + + +class H5NetCDFStore(WritableCFDataStore): + """Store for reading and writing data via h5netcdf""" + + __slots__ = ( + "autoclose", + "format", + "is_remote", + "lock", + "_filename", + "_group", + "_manager", + "_mode", + ) + + def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): + import h5netcdf + + if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): + if group is None: + root, group = find_root_and_group(manager) + else: + if type(manager) is not h5netcdf.File: + raise ValueError( + "must supply a h5netcdf.File if the group " + "argument is provided" + ) + root = manager + manager = DummyFileManager(root) + + self._manager = manager + self._group = group + self._mode = mode + self.format = None + # todo: utilizing find_root_and_group seems a bit clunky + # making filename available on h5netcdf.Group seems better + self._filename = find_root_and_group(self.ds)[0].filename + self.is_remote = is_remote_uri(self._filename) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @classmethod + def open( + cls, + filename, + mode="r", + format=None, + group=None, + lock=None, + autoclose=False, + invalid_netcdf=None, + phony_dims=None, + decode_vlen_strings=True, + driver=None, + driver_kwds=None, + ): + import h5netcdf + + if isinstance(filename, bytes): + raise ValueError( + "can't open netCDF4/HDF5 as bytes " + "try passing a path or file-like object" + ) + elif isinstance(filename, io.IOBase): + magic_number = read_magic_number_from_file(filename) + if not magic_number.startswith(b"\211HDF\r\n\032\n"): + raise ValueError( + f"{magic_number} is not the signature of a valid netCDF4 file" + ) + + if format not in [None, "NETCDF4"]: + raise ValueError("invalid format for h5netcdf backend") + + kwargs = { + "invalid_netcdf": invalid_netcdf, + "decode_vlen_strings": decode_vlen_strings, + "driver": driver, + } + if driver_kwds is not None: + kwargs.update(driver_kwds) + if phony_dims is not None: + kwargs["phony_dims"] = phony_dims + + if lock is None: + if mode == "r": + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) + + manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) + return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) + + def _acquire(self, needs_lock=True): + with self._manager.acquire_context(needs_lock) as root: + ds = _nc4_require_group( + root, self._group, self._mode, create_group=_h5netcdf_create_group + ) + return ds + + @property + def ds(self): + return self._acquire() + + def open_store_variable(self, name, var): + import h5py + + dimensions = var.dimensions + data = indexing.LazilyIndexedArray(H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + "chunksizes": var.chunks, + "fletcher32": var.fletcher32, + "shuffle": var.shuffle, + } + if var.chunks: + encoding["preferred_chunks"] = dict(zip(var.dimensions, var.chunks)) + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == "gzip": + encoding["zlib"] = True + encoding["complevel"] = var.compression_opts + elif var.compression is not None: + encoding["compression"] = var.compression + encoding["compression_opts"] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding["source"] = self._filename + encoding["original_shape"] = data.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is str: + encoding["dtype"] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. + pass + else: + encoding["dtype"] = var.dtype + + return Variable(dimensions, data, attrs, encoding) + + def get_variables(self): + return FrozenDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) + + def get_attrs(self): + return FrozenDict(_read_attributes(self.ds)) + + def get_dimensions(self): + return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) + + def get_encoding(self): + return { + "unlimited_dims": { + k for k, v in self.ds.dimensions.items() if v.isunlimited() + } + } + + def set_dimension(self, name, length, is_unlimited=False): + _ensure_no_forward_slash_in_name(name) + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length + + def set_attribute(self, key, value): + self.ds.attrs[key] = value + + def encode_variable(self, variable): + return _encode_nc4_variable(variable) + + def prepare_variable( + self, name, variable, check_encoding=False, unlimited_dims=None + ): + import h5py + + _ensure_no_forward_slash_in_name(name) + attrs = variable.attrs.copy() + dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding) + + fillvalue = attrs.pop("_FillValue", None) + + if dtype is str: + dtype = h5py.special_dtype(vlen=str) + + encoding = _extract_h5nc_encoding(variable, raise_on_invalid=check_encoding) + kwargs = {} + + # Convert from NetCDF4-Python style compression settings to h5py style + # If both styles are used together, h5py takes precedence + # If set_encoding=True, raise ValueError in case of mismatch + if encoding.pop("zlib", False): + if check_encoding and encoding.get("compression") not in (None, "gzip"): + raise ValueError("'zlib' and 'compression' encodings mismatch") + encoding.setdefault("compression", "gzip") + + if ( + check_encoding + and "complevel" in encoding + and "compression_opts" in encoding + and encoding["complevel"] != encoding["compression_opts"] + ): + raise ValueError("'complevel' and 'compression_opts' encodings mismatch") + complevel = encoding.pop("complevel", 0) + if complevel != 0: + encoding.setdefault("compression_opts", complevel) + + encoding["chunks"] = encoding.pop("chunksizes", None) + + # Do not apply compression, filters or chunking to scalars. + if variable.shape: + for key in [ + "compression", + "compression_opts", + "shuffle", + "chunks", + "fletcher32", + ]: + if key in encoding: + kwargs[key] = encoding[key] + if name not in self.ds: + nc4_var = self.ds.create_variable( + name, + dtype=dtype, + dimensions=variable.dims, + fillvalue=fillvalue, + **kwargs, + ) + else: + nc4_var = self.ds[name] + + for k, v in attrs.items(): + nc4_var.attrs[k] = v + + target = H5NetCDFArrayWrapper(name, self) + + return target, variable.data + + def sync(self): + self.ds.sync() + + def close(self, **kwargs): + self._manager.close(**kwargs) + + +class H5netcdfBackendEntrypoint(BackendEntrypoint): + """ + Backend for netCDF files based on the h5netcdf package. + + It can open ".nc", ".nc4", ".cdf" files but will only be + selected as the default if the "netcdf4" engine is not available. + + Additionally it can open valid HDF5 files, see + https://h5netcdf.org/#invalid-netcdf-files for more info. + It will not be detected as valid backend for such files, so make + sure to specify ``engine="h5netcdf"`` in ``open_dataset``. + + For more information about the underlying library, visit: + https://h5netcdf.org + + See Also + -------- + backends.H5NetCDFStore + backends.NetCDF4BackendEntrypoint + backends.ScipyBackendEntrypoint + """ + + description = ( + "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray" + ) + url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html" + + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None: + return magic_number.startswith(b"\211HDF\r\n\032\n") + + if isinstance(filename_or_obj, (str, os.PathLike)): + _, ext = os.path.splitext(filename_or_obj) + return ext in {".nc", ".nc4", ".cdf"} + + return False + + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + format=None, + group=None, + lock=None, + invalid_netcdf=None, + phony_dims=None, + decode_vlen_strings=True, + driver=None, + driver_kwds=None, + ) -> Dataset: + filename_or_obj = _normalize_path(filename_or_obj) + store = H5NetCDFStore.open( + filename_or_obj, + format=format, + group=group, + lock=lock, + invalid_netcdf=invalid_netcdf, + phony_dims=phony_dims, + decode_vlen_strings=decode_vlen_strings, + driver=driver, + driver_kwds=driver_kwds, + ) + + store_entrypoint = StoreBackendEntrypoint() + + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group: str | Iterable[str] | Callable | None = None, + **kwargs, + ) -> DataTree: + from xarray.backends.api import open_dataset + from xarray.backends.common import _iter_nc_groups + from xarray.core.datatree import DataTree + from xarray.core.treenode import NodePath + from xarray.core.utils import close_on_error + + filename_or_obj = _normalize_path(filename_or_obj) + store = H5NetCDFStore.open( + filename_or_obj, + group=group, + ) + if group: + parent = NodePath("/") / NodePath(group) + else: + parent = NodePath("/") + + manager = store._manager + ds = open_dataset(store, **kwargs) + tree_root = DataTree.from_dict({str(parent): ds}) + for path_group in _iter_nc_groups(store.ds, parent=parent): + group_store = H5NetCDFStore(manager, group=path_group, **kwargs) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(group_store): + ds = store_entrypoint.open_dataset( + group_store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds) + tree_root._set_item( + path_group, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) + return tree_root + + +BACKEND_ENTRYPOINTS["h5netcdf"] = ("h5netcdf", H5netcdfBackendEntrypoint) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/locks.py b/test/fixtures/whole_applications/xarray/xarray/backends/locks.py new file mode 100644 index 0000000..69cef30 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/locks.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import multiprocessing +import threading +import uuid +import weakref +from collections.abc import Hashable, MutableMapping +from typing import Any, ClassVar +from weakref import WeakValueDictionary + + +# SerializableLock is adapted from Dask: +# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224 +# Used under the terms of Dask's license, see licenses/DASK_LICENSE. +class SerializableLock: + """A Serializable per-process Lock + + This wraps a normal ``threading.Lock`` object and satisfies the same + interface. However, this lock can also be serialized and sent to different + processes. It will not block concurrent operations between processes (for + this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file`` + but will consistently deserialize into the same lock. + + So if we make a lock in one process:: + + lock = SerializableLock() + + And then send it over to another process multiple times:: + + bytes = pickle.dumps(lock) + a = pickle.loads(bytes) + b = pickle.loads(bytes) + + Then the deserialized objects will operate as though they were the same + lock, and collide as appropriate. + + This is useful for consistently protecting resources on a per-process + level. + + The creation of locks is itself not threadsafe. + """ + + _locks: ClassVar[WeakValueDictionary[Hashable, threading.Lock]] = ( + WeakValueDictionary() + ) + token: Hashable + lock: threading.Lock + + def __init__(self, token: Hashable | None = None): + self.token = token or str(uuid.uuid4()) + if self.token in SerializableLock._locks: + self.lock = SerializableLock._locks[self.token] + else: + self.lock = threading.Lock() + SerializableLock._locks[self.token] = self.lock + + def acquire(self, *args, **kwargs): + return self.lock.acquire(*args, **kwargs) + + def release(self, *args, **kwargs): + return self.lock.release(*args, **kwargs) + + def __enter__(self): + self.lock.__enter__() + + def __exit__(self, *args): + self.lock.__exit__(*args) + + def locked(self): + return self.lock.locked() + + def __getstate__(self): + return self.token + + def __setstate__(self, token): + self.__init__(token) + + def __str__(self): + return f"<{self.__class__.__name__}: {self.token}>" + + __repr__ = __str__ + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? + # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + + if scheduler is None: + return _get_threaded_lock + elif scheduler == "threaded": + return _get_threaded_lock + elif scheduler == "multiprocessing": + return _get_multiprocessing_lock + elif scheduler == "distributed": + # Lazy import distributed since it is can add a significant + # amount of time to import + try: + from dask.distributed import Lock as DistributedLock + except ImportError: + DistributedLock = None + return DistributedLock + else: + raise KeyError(scheduler) + + +def _get_scheduler(get=None, collection=None) -> str | None: + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. + + See Also + -------- + dask.base.get_scheduler + """ + try: + # Fix for bug caused by dask installation that doesn't involve the toolz library + # Issue: 4164 + import dask + from dask.base import get_scheduler # noqa: F401 + + actual_get = get_scheduler(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + + if isinstance(actual_get.__self__, Client): + return "distributed" + except (ImportError, AttributeError): + pass + + try: + # As of dask=2.6, dask.multiprocessing requires cloudpickle to be installed + # Dependency removed in https://github.com/dask/dask/pull/5511 + if actual_get is dask.multiprocessing.get: + return "multiprocessing" + except AttributeError: + pass + + return "threaded" + + +def get_write_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +def acquire(lock, blocking=True): + """Acquire a lock, possibly in a non-blocking fashion. + + Includes backwards compatibility hacks for old versions of Python, dask + and dask-distributed. + """ + if blocking: + # no arguments needed + return lock.acquire() + else: + # "blocking" keyword argument not supported for: + # - threading.Lock on Python 2. + # - dask.SerializableLock with dask v1.0.0 or earlier. + # - multiprocessing.Lock calls the argument "block" instead. + # - dask.distributed.Lock uses the blocking argument as the first one + return lock.acquire(blocking) + + +class CombinedLock: + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. + """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, blocking=True): + return all(acquire(lock, blocking=blocking) for lock in self.locks) + + def release(self): + for lock in self.locks: + lock.release() + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return f"CombinedLock({list(self.locks)!r})" + + +class DummyLock: + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, blocking=True): + pass + + def release(self): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/lru_cache.py b/test/fixtures/whole_applications/xarray/xarray/backends/lru_cache.py new file mode 100644 index 0000000..c09bcb1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/lru_cache.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import threading +from collections import OrderedDict +from collections.abc import Iterator, MutableMapping +from typing import Any, Callable, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class LRUCache(MutableMapping[K, V]): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. + """ + + _cache: OrderedDict[K, V] + _maxsize: int + _lock: threading.RLock + _on_evict: Callable[[K, V], Any] | None + + __slots__ = ("_cache", "_lock", "_maxsize", "_on_evict") + + def __init__(self, maxsize: int, on_evict: Callable[[K, V], Any] | None = None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict : callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. + """ + if not isinstance(maxsize, int): + raise TypeError("maxsize must be an integer") + if maxsize < 0: + raise ValueError("maxsize must be non-negative") + self._maxsize = maxsize + self._cache = OrderedDict() + self._lock = threading.RLock() + self._on_evict = on_evict + + def __getitem__(self, key: K) -> V: + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + self._cache.move_to_end(key) + return value + + def _enforce_size_limit(self, capacity: int) -> None: + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key: K, value: V) -> None: + with self._lock: + if key in self._cache: + # insert the new value at the end + del self._cache[key] + self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._enforce_size_limit(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) + + def __delitem__(self, key: K) -> None: + del self._cache[key] + + def __iter__(self) -> Iterator[K]: + # create a list, so accessing the cache during iteration cannot change + # the iteration order + return iter(list(self._cache)) + + def __len__(self) -> int: + return len(self._cache) + + @property + def maxsize(self) -> int: + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size: int) -> None: + """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError("maxsize must be non-negative") + with self._lock: + self._enforce_size_limit(size) + self._maxsize = size diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/memory.py b/test/fixtures/whole_applications/xarray/xarray/backends/memory.py new file mode 100644 index 0000000..9df6701 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/memory.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import copy + +import numpy as np + +from xarray.backends.common import AbstractWritableDataStore +from xarray.core.variable import Variable + + +class InMemoryDataStore(AbstractWritableDataStore): + """ + Stores dimensions, variables and attributes in ordered dictionaries, making + this store fast compared to stores which save to disk. + + This store exists purely for internal testing purposes. + """ + + def __init__(self, variables=None, attributes=None): + self._variables = {} if variables is None else variables + self._attributes = {} if attributes is None else attributes + + def get_attrs(self): + return self._attributes + + def get_variables(self): + return self._variables + + def get_dimensions(self): + dims = {} + for v in self._variables.values(): + for d, s in v.dims.items(): + dims[d] = s + return dims + + def prepare_variable(self, k, v, *args, **kwargs): + new_var = Variable(v.dims, np.empty_like(v), v.attrs) + self._variables[k] = new_var + return new_var, v.data + + def set_attribute(self, k, v): + # copy to imitate writing to disk. + self._attributes[k] = copy.deepcopy(v) + + def set_dimension(self, dim, length, unlimited_dims=None): + # in this model, dimensions are accounted for in the variables + pass diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/netCDF4_.py b/test/fixtures/whole_applications/xarray/xarray/backends/netCDF4_.py new file mode 100644 index 0000000..f8dd1c9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/netCDF4_.py @@ -0,0 +1,727 @@ +from __future__ import annotations + +import functools +import operator +import os +from collections.abc import Callable, Iterable +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +import numpy as np + +from xarray import coding +from xarray.backends.common import ( + BACKEND_ENTRYPOINTS, + BackendArray, + BackendEntrypoint, + WritableCFDataStore, + _normalize_path, + find_root_and_group, + robust_getitem, +) +from xarray.backends.file_manager import CachingFileManager, DummyFileManager +from xarray.backends.locks import ( + HDF5_LOCK, + NETCDFC_LOCK, + combine_locks, + ensure_lock, + get_write_lock, +) +from xarray.backends.netcdf3 import encode_nc3_attr_value, encode_nc3_variable +from xarray.backends.store import StoreBackendEntrypoint +from xarray.coding.variables import pop_to +from xarray.core import indexing +from xarray.core.utils import ( + FrozenDict, + close_on_error, + is_remote_uri, + try_read_magic_number_from_path, +) +from xarray.core.variable import Variable + +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree + +# This lookup table maps from dtype.byteorder to a readable endian +# string used by netCDF4. +_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} + +NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + +class BaseNetCDF4Array(BackendArray): + __slots__ = ("datastore", "dtype", "shape", "variable_name") + + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + + array = self.get_array() + self.shape = array.shape + + dtype = array.dtype + if dtype is str: + # use object dtype (with additional vlen string metadata) because that's + # the only way in numpy to represent variable length strings and to + # check vlen string dtype in further steps + # it also prevents automatic string concatenation via + # conventions.decode_cf_variable + dtype = coding.strings.create_vlen_dtype(str) + self.dtype = dtype + + def __setitem__(self, key, value): + with self.datastore.lock: + data = self.get_array(needs_lock=False) + data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) + + def get_array(self, needs_lock=True): + raise NotImplementedError("Virtual Method") + + +class NetCDF4ArrayWrapper(BaseNetCDF4Array): + __slots__ = () + + def get_array(self, needs_lock=True): + ds = self.datastore._acquire(needs_lock) + variable = ds.variables[self.variable_name] + variable.set_auto_maskandscale(False) + # only added in netCDF4-python v1.2.8 + with suppress(AttributeError): + variable.set_auto_chartostring(False) + return variable + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + ) + + def _getitem(self, key): + if self.datastore.is_remote: # pragma: no cover + getitem = functools.partial(robust_getitem, catch=RuntimeError) + else: + getitem = operator.getitem + + try: + with self.datastore.lock: + original_array = self.get_array(needs_lock=False) + array = getitem(original_array, key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ( + "The indexing operation you are attempting to perform " + "is not valid on netCDF4.Variable object. Try loading " + "your data into memory first by calling .load()." + ) + raise IndexError(msg) + return array + + +def _encode_nc4_variable(var): + for coder in [ + coding.strings.EncodedStringCoder(allows_unicode=True), + coding.strings.CharacterArrayCoder(), + ]: + var = coder.encode(var) + return var + + +def _check_encoding_dtype_is_vlen_string(dtype): + if dtype is not str: + raise AssertionError( # pragma: no cover + f"unexpected dtype encoding {dtype!r}. This shouldn't happen: please " + "file a bug report at github.com/pydata/xarray" + ) + + +def _get_datatype( + var, nc_format="NETCDF4", raise_on_invalid_encoding=False +) -> np.dtype: + if nc_format == "NETCDF4": + return _nc4_dtype(var) + if "dtype" in var.encoding: + encoded_dtype = var.encoding["dtype"] + _check_encoding_dtype_is_vlen_string(encoded_dtype) + if raise_on_invalid_encoding: + raise ValueError( + "encoding dtype=str for vlen strings is only supported " + "with format='NETCDF4'." + ) + return var.dtype + + +def _nc4_dtype(var): + if "dtype" in var.encoding: + dtype = var.encoding.pop("dtype") + _check_encoding_dtype_is_vlen_string(dtype) + elif coding.strings.is_unicode_dtype(var.dtype): + dtype = str + elif var.dtype.kind in ["i", "u", "f", "c", "S"]: + dtype = var.dtype + else: + raise ValueError(f"unsupported dtype for netCDF4 variable: {var.dtype}") + return dtype + + +def _netcdf4_create_group(dataset, name): + return dataset.createGroup(name) + + +def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): + if group in {None, "", "/"}: + # use the root group + return ds + else: + # make sure it's a string + if not isinstance(group, str): + raise ValueError("group must be a string or None") + # support path-like syntax + path = group.strip("/").split("/") + for key in path: + try: + ds = ds.groups[key] + except KeyError as e: + if mode != "r": + ds = create_group(ds, key) + else: + # wrap error to provide slightly more helpful message + raise OSError(f"group not found: {key}", e) + return ds + + +def _ensure_no_forward_slash_in_name(name): + if "/" in name: + raise ValueError( + f"Forward slashes '/' are not allowed in variable and dimension names (got {name!r}). " + "Forward slashes are used as hierarchy-separators for " + "HDF5-based files ('netcdf4'/'h5netcdf')." + ) + + +def _ensure_fill_value_valid(data, attributes): + # work around for netCDF4/scipy issue where _FillValue has the wrong type: + # https://github.com/Unidata/netcdf4-python/issues/271 + if data.dtype.kind == "S" and "_FillValue" in attributes: + attributes["_FillValue"] = np.bytes_(attributes["_FillValue"]) + + +def _force_native_endianness(var): + # possible values for byteorder are: + # = native + # < little-endian + # > big-endian + # | not applicable + # Below we check if the data type is not native or NA + if var.dtype.byteorder not in ["=", "|"]: + # if endianness is specified explicitly, convert to the native type + data = var.data.astype(var.dtype.newbyteorder("=")) + var = Variable(var.dims, data, var.attrs, var.encoding) + # if endian exists, remove it from the encoding. + var.encoding.pop("endian", None) + # check to see if encoding has a value for endian its 'native' + if var.encoding.get("endian", "native") != "native": + raise NotImplementedError( + "Attempt to write non-native endian type, " + "this is not supported by the netCDF4 " + "python library." + ) + return var + + +def _extract_nc4_variable_encoding( + variable: Variable, + raise_on_invalid=False, + lsd_okay=True, + h5py_okay=False, + backend="netCDF4", + unlimited_dims=None, +) -> dict[str, Any]: + if unlimited_dims is None: + unlimited_dims = () + + encoding = variable.encoding.copy() + + safe_to_drop = {"source", "original_shape"} + valid_encodings = { + "zlib", + "complevel", + "fletcher32", + "contiguous", + "chunksizes", + "shuffle", + "_FillValue", + "dtype", + "compression", + "significant_digits", + "quantize_mode", + "blosc_shuffle", + "szip_coding", + "szip_pixels_per_block", + "endian", + } + if lsd_okay: + valid_encodings.add("least_significant_digit") + if h5py_okay: + valid_encodings.add("compression_opts") + + if not raise_on_invalid and encoding.get("chunksizes") is not None: + # It's possible to get encoded chunksizes larger than a dimension size + # if the original file had an unlimited dimension. This is problematic + # if the new file no longer has an unlimited dimension. + chunksizes = encoding["chunksizes"] + chunks_too_big = any( + c > d and dim not in unlimited_dims + for c, d, dim in zip(chunksizes, variable.shape, variable.dims) + ) + has_original_shape = "original_shape" in encoding + changed_shape = ( + has_original_shape and encoding.get("original_shape") != variable.shape + ) + if chunks_too_big or changed_shape: + del encoding["chunksizes"] + + var_has_unlim_dim = any(dim in unlimited_dims for dim in variable.dims) + if not raise_on_invalid and var_has_unlim_dim and "contiguous" in encoding.keys(): + del encoding["contiguous"] + + for k in safe_to_drop: + if k in encoding: + del encoding[k] + + if raise_on_invalid: + invalid = [k for k in encoding if k not in valid_encodings] + if invalid: + raise ValueError( + f"unexpected encoding parameters for {backend!r} backend: {invalid!r}. Valid " + f"encodings are: {valid_encodings!r}" + ) + else: + for k in list(encoding): + if k not in valid_encodings: + del encoding[k] + + return encoding + + +def _is_list_of_strings(value) -> bool: + arr = np.asarray(value) + return arr.dtype.kind in ["U", "S"] and arr.size > 1 + + +class NetCDF4DataStore(WritableCFDataStore): + """Store for reading and writing data via the Python-NetCDF4 library. + + This store supports NetCDF3, NetCDF4 and OpenDAP datasets. + """ + + __slots__ = ( + "autoclose", + "format", + "is_remote", + "lock", + "_filename", + "_group", + "_manager", + "_mode", + ) + + def __init__( + self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False + ): + import netCDF4 + + if isinstance(manager, netCDF4.Dataset): + if group is None: + root, group = find_root_and_group(manager) + else: + if type(manager) is not netCDF4.Dataset: + raise ValueError( + "must supply a root netCDF4.Dataset if the group " + "argument is provided" + ) + root = manager + manager = DummyFileManager(root) + + self._manager = manager + self._group = group + self._mode = mode + self.format = self.ds.data_model + self._filename = self.ds.filepath() + self.is_remote = is_remote_uri(self._filename) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @classmethod + def open( + cls, + filename, + mode="r", + format="NETCDF4", + group=None, + clobber=True, + diskless=False, + persist=False, + lock=None, + lock_maker=None, + autoclose=False, + ): + import netCDF4 + + if isinstance(filename, os.PathLike): + filename = os.fspath(filename) + + if not isinstance(filename, str): + raise ValueError( + "can only read bytes or file-like objects " + "with engine='scipy' or 'h5netcdf'" + ) + + if format is None: + format = "NETCDF4" + + if lock is None: + if mode == "r": + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith("NETCDF4"): + base_lock = NETCDF4_PYTHON_LOCK + else: + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_write_lock(filename)]) + + kwargs = dict( + clobber=clobber, diskless=diskless, persist=persist, format=format + ) + manager = CachingFileManager( + netCDF4.Dataset, filename, mode=mode, kwargs=kwargs + ) + return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) + + def _acquire(self, needs_lock=True): + with self._manager.acquire_context(needs_lock) as root: + ds = _nc4_require_group(root, self._group, self._mode) + return ds + + @property + def ds(self): + return self._acquire() + + def open_store_variable(self, name: str, var): + import netCDF4 + + dimensions = var.dimensions + attributes = {k: var.getncattr(k) for k in var.ncattrs()} + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + encoding: dict[str, Any] = {} + if isinstance(var.datatype, netCDF4.EnumType): + encoding["dtype"] = np.dtype( + data.dtype, + metadata={ + "enum": var.datatype.enum_dict, + "enum_name": var.datatype.name, + }, + ) + else: + encoding["dtype"] = var.dtype + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == "contiguous": + encoding["contiguous"] = True + encoding["chunksizes"] = None + else: + encoding["contiguous"] = False + encoding["chunksizes"] = tuple(chunking) + encoding["preferred_chunks"] = dict(zip(var.dimensions, chunking)) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, "least_significant_digit") + # save source so __repr__ can detect if it's local or not + encoding["source"] = self._filename + encoding["original_shape"] = data.shape + + return Variable(dimensions, data, attributes, encoding) + + def get_variables(self): + return FrozenDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) + + def get_attrs(self): + return FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs()) + + def get_dimensions(self): + return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) + + def get_encoding(self): + return { + "unlimited_dims": { + k for k, v in self.ds.dimensions.items() if v.isunlimited() + } + } + + def set_dimension(self, name, length, is_unlimited=False): + _ensure_no_forward_slash_in_name(name) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) + + def set_attribute(self, key, value): + if self.format != "NETCDF4": + value = encode_nc3_attr_value(value) + if _is_list_of_strings(value): + # encode as NC_STRING if attr is list of strings + self.ds.setncattr_string(key, value) + else: + self.ds.setncattr(key, value) + + def encode_variable(self, variable): + variable = _force_native_endianness(variable) + if self.format == "NETCDF4": + variable = _encode_nc4_variable(variable) + else: + variable = encode_nc3_variable(variable) + return variable + + def prepare_variable( + self, name, variable: Variable, check_encoding=False, unlimited_dims=None + ): + _ensure_no_forward_slash_in_name(name) + attrs = variable.attrs.copy() + fill_value = attrs.pop("_FillValue", None) + datatype = _get_datatype( + variable, self.format, raise_on_invalid_encoding=check_encoding + ) + # check enum metadata and use netCDF4.EnumType + if ( + (meta := np.dtype(datatype).metadata) + and (e_name := meta.get("enum_name")) + and (e_dict := meta.get("enum")) + ): + datatype = self._build_and_get_enum(name, datatype, e_name, e_dict) + encoding = _extract_nc4_variable_encoding( + variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims + ) + if name in self.ds.variables: + nc4_var = self.ds.variables[name] + else: + default_args = dict( + varname=name, + datatype=datatype, + dimensions=variable.dims, + zlib=False, + complevel=4, + shuffle=True, + fletcher32=False, + contiguous=False, + chunksizes=None, + endian="native", + least_significant_digit=None, + fill_value=fill_value, + ) + default_args.update(encoding) + default_args.pop("_FillValue", None) + nc4_var = self.ds.createVariable(**default_args) + + nc4_var.setncatts(attrs) + + target = NetCDF4ArrayWrapper(name, self) + + return target, variable.data + + def _build_and_get_enum( + self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int] + ) -> Any: + """ + Add or get the netCDF4 Enum based on the dtype in encoding. + The return type should be ``netCDF4.EnumType``, + but we avoid importing netCDF4 globally for performances. + """ + if enum_name not in self.ds.enumtypes: + return self.ds.createEnumType( + dtype, + enum_name, + enum_dict, + ) + datatype = self.ds.enumtypes[enum_name] + if datatype.enum_dict != enum_dict: + error_msg = ( + f"Cannot save variable `{var_name}` because an enum" + f" `{enum_name}` already exists in the Dataset but have" + " a different definition. To fix this error, make sure" + " each variable have a uniquely named enum in their" + " `encoding['dtype'].metadata` or, if they should share" + " the same enum type, make sure the enums are identical." + ) + raise ValueError(error_msg) + return datatype + + def sync(self): + self.ds.sync() + + def close(self, **kwargs): + self._manager.close(**kwargs) + + +class NetCDF4BackendEntrypoint(BackendEntrypoint): + """ + Backend for netCDF files based on the netCDF4 package. + + It can open ".nc", ".nc4", ".cdf" files and will be chosen + as default for these files. + + Additionally it can open valid HDF5 files, see + https://h5netcdf.org/#invalid-netcdf-files for more info. + It will not be detected as valid backend for such files, so make + sure to specify ``engine="netcdf4"`` in ``open_dataset``. + + For more information about the underlying library, visit: + https://unidata.github.io/netcdf4-python + + See Also + -------- + backends.NetCDF4DataStore + backends.H5netcdfBackendEntrypoint + backends.ScipyBackendEntrypoint + """ + + description = ( + "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray" + ) + url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html" + + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): + return True + magic_number = try_read_magic_number_from_path(filename_or_obj) + if magic_number is not None: + # netcdf 3 or HDF5 + return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) + + if isinstance(filename_or_obj, (str, os.PathLike)): + _, ext = os.path.splitext(filename_or_obj) + return ext in {".nc", ".nc4", ".cdf"} + + return False + + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + format="NETCDF4", + clobber=True, + diskless=False, + persist=False, + lock=None, + autoclose=False, + ) -> Dataset: + filename_or_obj = _normalize_path(filename_or_obj) + store = NetCDF4DataStore.open( + filename_or_obj, + mode=mode, + format=format, + group=group, + clobber=clobber, + diskless=diskless, + persist=persist, + lock=lock, + autoclose=autoclose, + ) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group: str | Iterable[str] | Callable | None = None, + **kwargs, + ) -> DataTree: + from xarray.backends.api import open_dataset + from xarray.backends.common import _iter_nc_groups + from xarray.core.datatree import DataTree + from xarray.core.treenode import NodePath + + filename_or_obj = _normalize_path(filename_or_obj) + store = NetCDF4DataStore.open( + filename_or_obj, + group=group, + ) + if group: + parent = NodePath("/") / NodePath(group) + else: + parent = NodePath("/") + + manager = store._manager + ds = open_dataset(store, **kwargs) + tree_root = DataTree.from_dict({str(parent): ds}) + for path_group in _iter_nc_groups(store.ds, parent=parent): + group_store = NetCDF4DataStore(manager, group=path_group, **kwargs) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(group_store): + ds = store_entrypoint.open_dataset( + group_store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds) + tree_root._set_item( + path_group, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) + return tree_root + + +BACKEND_ENTRYPOINTS["netcdf4"] = ("netCDF4", NetCDF4BackendEntrypoint) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/netcdf3.py b/test/fixtures/whole_applications/xarray/xarray/backends/netcdf3.py new file mode 100644 index 0000000..70ddbdd --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/netcdf3.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import unicodedata + +import numpy as np + +from xarray import coding +from xarray.core.variable import Variable + +# Special characters that are permitted in netCDF names except in the +# 0th position of the string +_specialchars = '_.@+- !"#$%&\\()*,:;<=>?[]^`{|}~' + +# The following are reserved names in CDL and may not be used as names of +# variables, dimension, attributes +_reserved_names = { + "byte", + "char", + "short", + "ushort", + "int", + "uint", + "int64", + "uint64", + "float", + "real", + "double", + "bool", + "string", +} + +# These data-types aren't supported by netCDF3, so they are automatically +# coerced instead as indicated by the "coerce_nc3_dtype" function +_nc3_dtype_coercions = { + "int64": "int32", + "uint64": "int32", + "uint32": "int32", + "uint16": "int16", + "uint8": "int8", + "bool": "int8", +} + +# encode all strings as UTF-8 +STRING_ENCODING = "utf-8" +COERCION_VALUE_ERROR = ( + "could not safely cast array from {dtype} to {new_dtype}. While it is not " + "always the case, a common reason for this is that xarray has deemed it " + "safest to encode np.datetime64[ns] or np.timedelta64[ns] values with " + "int64 values representing units of 'nanoseconds'. This is either due to " + "the fact that the times are known to require nanosecond precision for an " + "accurate round trip, or that the times are unknown prior to writing due " + "to being contained in a chunked array. Ways to work around this are " + "either to use a backend that supports writing int64 values, or to " + "manually specify the encoding['units'] and encoding['dtype'] (e.g. " + "'seconds since 1970-01-01' and np.dtype('int32')) on the time " + "variable(s) such that the times can be serialized in a netCDF3 file " + "(note that depending on the situation, however, this latter option may " + "result in an inaccurate round trip)." +) + + +def coerce_nc3_dtype(arr): + """Coerce an array to a data type that can be stored in a netCDF-3 file + + This function performs the dtype conversions as specified by the + ``_nc3_dtype_coercions`` mapping: + int64 -> int32 + uint64 -> int32 + uint32 -> int32 + uint16 -> int16 + uint8 -> int8 + bool -> int8 + + Data is checked for equality, or equivalence (non-NaN values) using the + ``(cast_array == original_array).all()``. + """ + dtype = str(arr.dtype) + if dtype in _nc3_dtype_coercions: + new_dtype = _nc3_dtype_coercions[dtype] + # TODO: raise a warning whenever casting the data-type instead? + cast_arr = arr.astype(new_dtype) + if not (cast_arr == arr).all(): + raise ValueError( + COERCION_VALUE_ERROR.format(dtype=dtype, new_dtype=new_dtype) + ) + arr = cast_arr + return arr + + +def encode_nc3_attr_value(value): + if isinstance(value, bytes): + pass + elif isinstance(value, str): + value = value.encode(STRING_ENCODING) + else: + value = coerce_nc3_dtype(np.atleast_1d(value)) + if value.ndim > 1: + raise ValueError("netCDF attributes must be 1-dimensional") + return value + + +def encode_nc3_attrs(attrs): + return {k: encode_nc3_attr_value(v) for k, v in attrs.items()} + + +def _maybe_prepare_times(var): + # checks for integer-based time-like and + # replaces np.iinfo(np.int64).min with _FillValue or np.nan + # this keeps backwards compatibility + + data = var.data + if data.dtype.kind in "iu": + units = var.attrs.get("units", None) + if units is not None: + if coding.variables._is_time_like(units): + mask = data == np.iinfo(np.int64).min + if mask.any(): + data = np.where(mask, var.attrs.get("_FillValue", np.nan), data) + return data + + +def encode_nc3_variable(var): + for coder in [ + coding.strings.EncodedStringCoder(allows_unicode=False), + coding.strings.CharacterArrayCoder(), + ]: + var = coder.encode(var) + data = _maybe_prepare_times(var) + data = coerce_nc3_dtype(data) + attrs = encode_nc3_attrs(var.attrs) + return Variable(var.dims, data, attrs, var.encoding) + + +def _isalnumMUTF8(c): + """Return True if the given UTF-8 encoded character is alphanumeric + or multibyte. + + Input is not checked! + """ + return c.isalnum() or (len(c.encode("utf-8")) > 1) + + +def is_valid_nc3_name(s): + """Test whether an object can be validly converted to a netCDF-3 + dimension, variable or attribute name + + Earlier versions of the netCDF C-library reference implementation + enforced a more restricted set of characters in creating new names, + but permitted reading names containing arbitrary bytes. This + specification extends the permitted characters in names to include + multi-byte UTF-8 encoded Unicode and additional printing characters + from the US-ASCII alphabet. The first character of a name must be + alphanumeric, a multi-byte UTF-8 character, or '_' (reserved for + special names with meaning to implementations, such as the + "_FillValue" attribute). Subsequent characters may also include + printing special characters, except for '/' which is not allowed in + names. Names that have trailing space characters are also not + permitted. + """ + if not isinstance(s, str): + return False + num_bytes = len(s.encode("utf-8")) + return ( + (unicodedata.normalize("NFC", s) == s) + and (s not in _reserved_names) + and (num_bytes >= 0) + and ("/" not in s) + and (s[-1] != " ") + and (_isalnumMUTF8(s[0]) or (s[0] == "_")) + and all(_isalnumMUTF8(c) or c in _specialchars for c in s) + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/plugins.py b/test/fixtures/whole_applications/xarray/xarray/backends/plugins.py new file mode 100644 index 0000000..a62ca6c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/plugins.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import functools +import inspect +import itertools +import sys +import warnings +from importlib.metadata import entry_points +from typing import TYPE_CHECKING, Any, Callable + +from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint +from xarray.core.utils import module_available + +if TYPE_CHECKING: + import os + from importlib.metadata import EntryPoint + + if sys.version_info >= (3, 10): + from importlib.metadata import EntryPoints + else: + EntryPoints = list[EntryPoint] + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + +STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"] + + +def remove_duplicates(entrypoints: EntryPoints) -> list[EntryPoint]: + # sort and group entrypoints by name + entrypoints_sorted = sorted(entrypoints, key=lambda ep: ep.name) + entrypoints_grouped = itertools.groupby(entrypoints_sorted, key=lambda ep: ep.name) + # check if there are multiple entrypoints for the same name + unique_entrypoints = [] + for name, _matches in entrypoints_grouped: + # remove equal entrypoints + matches = list(set(_matches)) + unique_entrypoints.append(matches[0]) + matches_len = len(matches) + if matches_len > 1: + all_module_names = [e.value.split(":")[0] for e in matches] + selected_module_name = all_module_names[0] + warnings.warn( + f"Found {matches_len} entrypoints for the engine name {name}:" + f"\n {all_module_names}.\n " + f"The entrypoint {selected_module_name} will be used.", + RuntimeWarning, + ) + return unique_entrypoints + + +def detect_parameters(open_dataset: Callable) -> tuple[str, ...]: + signature = inspect.signature(open_dataset) + parameters = signature.parameters + parameters_list = [] + for name, param in parameters.items(): + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ): + raise TypeError( + f"All the parameters in {open_dataset!r} signature should be explicit. " + "*args and **kwargs is not supported" + ) + if name != "self": + parameters_list.append(name) + return tuple(parameters_list) + + +def backends_dict_from_pkg( + entrypoints: list[EntryPoint], +) -> dict[str, type[BackendEntrypoint]]: + backend_entrypoints = {} + for entrypoint in entrypoints: + name = entrypoint.name + try: + backend = entrypoint.load() + backend_entrypoints[name] = backend + except Exception as ex: + warnings.warn(f"Engine {name!r} loading failed:\n{ex}", RuntimeWarning) + return backend_entrypoints + + +def set_missing_parameters( + backend_entrypoints: dict[str, type[BackendEntrypoint]] +) -> None: + for _, backend in backend_entrypoints.items(): + if backend.open_dataset_parameters is None: + open_dataset = backend.open_dataset + backend.open_dataset_parameters = detect_parameters(open_dataset) + + +def sort_backends( + backend_entrypoints: dict[str, type[BackendEntrypoint]] +) -> dict[str, type[BackendEntrypoint]]: + ordered_backends_entrypoints = {} + for be_name in STANDARD_BACKENDS_ORDER: + if backend_entrypoints.get(be_name, None) is not None: + ordered_backends_entrypoints[be_name] = backend_entrypoints.pop(be_name) + ordered_backends_entrypoints.update( + {name: backend_entrypoints[name] for name in sorted(backend_entrypoints)} + ) + return ordered_backends_entrypoints + + +def build_engines(entrypoints: EntryPoints) -> dict[str, BackendEntrypoint]: + backend_entrypoints: dict[str, type[BackendEntrypoint]] = {} + for backend_name, (module_name, backend) in BACKEND_ENTRYPOINTS.items(): + if module_name is None or module_available(module_name): + backend_entrypoints[backend_name] = backend + entrypoints_unique = remove_duplicates(entrypoints) + external_backend_entrypoints = backends_dict_from_pkg(entrypoints_unique) + backend_entrypoints.update(external_backend_entrypoints) + backend_entrypoints = sort_backends(backend_entrypoints) + set_missing_parameters(backend_entrypoints) + return {name: backend() for name, backend in backend_entrypoints.items()} + + +@functools.lru_cache(maxsize=1) +def list_engines() -> dict[str, BackendEntrypoint]: + """ + Return a dictionary of available engines and their BackendEntrypoint objects. + + Returns + ------- + dictionary + + Notes + ----- + This function lives in the backends namespace (``engs=xr.backends.list_engines()``). + If available, more information is available about each backend via ``engs["eng_name"]``. + + # New selection mechanism introduced with Python 3.10. See GH6514. + """ + if sys.version_info >= (3, 10): + entrypoints = entry_points(group="xarray.backends") + else: + entrypoints = entry_points().get("xarray.backends", []) + return build_engines(entrypoints) + + +def refresh_engines() -> None: + """Refreshes the backend engines based on installed packages.""" + list_engines.cache_clear() + + +def guess_engine( + store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, +) -> str | type[BackendEntrypoint]: + engines = list_engines() + + for engine, backend in engines.items(): + try: + if backend.guess_can_open(store_spec): + return engine + except PermissionError: + raise + except Exception: + warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning) + + compatible_engines = [] + for engine, (_, backend_cls) in BACKEND_ENTRYPOINTS.items(): + try: + backend = backend_cls() + if backend.guess_can_open(store_spec): + compatible_engines.append(engine) + except Exception: + warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning) + + installed_engines = [k for k in engines if k != "store"] + if not compatible_engines: + if installed_engines: + error_msg = ( + "did not find a match in any of xarray's currently installed IO " + f"backends {installed_engines}. Consider explicitly selecting one of the " + "installed engines via the ``engine`` parameter, or installing " + "additional IO dependencies, see:\n" + "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html\n" + "https://docs.xarray.dev/en/stable/user-guide/io.html" + ) + else: + error_msg = ( + "xarray is unable to open this file because it has no currently " + "installed IO backends. Xarray's read/write support requires " + "installing optional IO dependencies, see:\n" + "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html\n" + "https://docs.xarray.dev/en/stable/user-guide/io" + ) + else: + error_msg = ( + "found the following matches with the input file in xarray's IO " + f"backends: {compatible_engines}. But their dependencies may not be installed, see:\n" + "https://docs.xarray.dev/en/stable/user-guide/io.html \n" + "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html" + ) + + raise ValueError(error_msg) + + +def get_backend(engine: str | type[BackendEntrypoint]) -> BackendEntrypoint: + """Select open_dataset method based on current engine.""" + if isinstance(engine, str): + engines = list_engines() + if engine not in engines: + raise ValueError( + f"unrecognized engine {engine} must be one of: {list(engines)}" + ) + backend = engines[engine] + elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint): + backend = engine() + else: + raise TypeError( + "engine must be a string or a subclass of " + f"xarray.backends.BackendEntrypoint: {engine}" + ) + + return backend diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/pydap_.py b/test/fixtures/whole_applications/xarray/xarray/backends/pydap_.py new file mode 100644 index 0000000..5a475a7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/pydap_.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +import numpy as np + +from xarray.backends.common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, + robust_getitem, +) +from xarray.backends.store import StoreBackendEntrypoint +from xarray.core import indexing +from xarray.core.utils import ( + Frozen, + FrozenDict, + close_on_error, + is_dict_like, + is_remote_uri, +) +from xarray.core.variable import Variable +from xarray.namedarray.pycompat import integer_types + +if TYPE_CHECKING: + import os + from io import BufferedIOBase + + from xarray.core.dataset import Dataset + + +class PydapArrayWrapper(BackendArray): + def __init__(self, array): + self.array = array + + @property + def shape(self) -> tuple[int, ...]: + return self.array.shape + + @property + def dtype(self): + return self.array.dtype + + def __getitem__(self, key): + return indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem + ) + + def _getitem(self, key): + # pull the data from the array attribute if possible, to avoid + # downloading coordinate data twice + array = getattr(self.array, "array", self.array) + result = robust_getitem(array, key, catch=ValueError) + result = np.asarray(result) + # in some cases, pydap doesn't squeeze axes automatically like numpy + axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types)) + if result.ndim + len(axis) != array.ndim and axis: + result = np.squeeze(result, axis) + + return result + + +def _fix_attributes(attributes): + attributes = dict(attributes) + for k in list(attributes): + if k.lower() == "global" or k.lower().endswith("_global"): + # move global attributes to the top level, like the netcdf-C + # DAP client + attributes.update(attributes.pop(k)) + elif is_dict_like(attributes[k]): + # Make Hierarchical attributes to a single level with a + # dot-separated key + attributes.update( + { + f"{k}.{k_child}": v_child + for k_child, v_child in attributes.pop(k).items() + } + ) + return attributes + + +class PydapDataStore(AbstractDataStore): + """Store for accessing OpenDAP datasets with pydap. + + This store provides an alternative way to access OpenDAP datasets that may + be useful if the netCDF4 library is not available. + """ + + def __init__(self, ds): + """ + Parameters + ---------- + ds : pydap DatasetType + """ + self.ds = ds + + @classmethod + def open( + cls, + url, + application=None, + session=None, + output_grid=None, + timeout=None, + verify=None, + user_charset=None, + ): + import pydap.client + import pydap.lib + + if timeout is None: + from pydap.lib import DEFAULT_TIMEOUT + + timeout = DEFAULT_TIMEOUT + + kwargs = { + "url": url, + "application": application, + "session": session, + "output_grid": output_grid or True, + "timeout": timeout, + } + if verify is not None: + kwargs.update({"verify": verify}) + if user_charset is not None: + kwargs.update({"user_charset": user_charset}) + ds = pydap.client.open_url(**kwargs) + return cls(ds) + + def open_store_variable(self, var): + data = indexing.LazilyIndexedArray(PydapArrayWrapper(var)) + return Variable(var.dimensions, data, _fix_attributes(var.attributes)) + + def get_variables(self): + return FrozenDict( + (k, self.open_store_variable(self.ds[k])) for k in self.ds.keys() + ) + + def get_attrs(self): + return Frozen(_fix_attributes(self.ds.attributes)) + + def get_dimensions(self): + return Frozen(self.ds.dimensions) + + +class PydapBackendEntrypoint(BackendEntrypoint): + """ + Backend for steaming datasets over the internet using + the Data Access Protocol, also known as DODS or OPeNDAP + based on the pydap package. + + This backend is selected by default for urls. + + For more information about the underlying library, visit: + https://www.pydap.org + + See Also + -------- + backends.PydapDataStore + """ + + description = "Open remote datasets via OPeNDAP using pydap in Xarray" + url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html" + + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + return isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj) + + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + application=None, + session=None, + output_grid=None, + timeout=None, + verify=None, + user_charset=None, + ) -> Dataset: + store = PydapDataStore.open( + url=filename_or_obj, + application=application, + session=session, + output_grid=output_grid, + timeout=timeout, + verify=verify, + user_charset=user_charset, + ) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +BACKEND_ENTRYPOINTS["pydap"] = ("pydap", PydapBackendEntrypoint) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/scipy_.py b/test/fixtures/whole_applications/xarray/xarray/backends/scipy_.py new file mode 100644 index 0000000..f8c486e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/scipy_.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import gzip +import io +import os +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +import numpy as np + +from xarray.backends.common import ( + BACKEND_ENTRYPOINTS, + BackendArray, + BackendEntrypoint, + WritableCFDataStore, + _normalize_path, +) +from xarray.backends.file_manager import CachingFileManager, DummyFileManager +from xarray.backends.locks import ensure_lock, get_write_lock +from xarray.backends.netcdf3 import ( + encode_nc3_attr_value, + encode_nc3_variable, + is_valid_nc3_name, +) +from xarray.backends.store import StoreBackendEntrypoint +from xarray.core import indexing +from xarray.core.utils import ( + Frozen, + FrozenDict, + close_on_error, + module_available, + try_read_magic_number_from_file_or_path, +) +from xarray.core.variable import Variable + +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + + +HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0") + + +def _decode_string(s): + if isinstance(s, bytes): + return s.decode("utf-8", "replace") + return s + + +def _decode_attrs(d): + # don't decode _FillValue from bytes -> unicode, because we want to ensure + # that its type matches the data exactly + return {k: v if k == "_FillValue" else _decode_string(v) for (k, v) in d.items()} + + +class ScipyArrayWrapper(BackendArray): + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + array = self.get_variable().data + self.shape = array.shape + self.dtype = np.dtype(array.dtype.kind + str(array.dtype.itemsize)) + + def get_variable(self, needs_lock=True): + ds = self.datastore._manager.acquire(needs_lock) + return ds.variables[self.variable_name] + + def _getitem(self, key): + with self.datastore.lock: + data = self.get_variable(needs_lock=False).data + return data[key] + + def __getitem__(self, key): + data = indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even + # after closing associated files. + copy = self.datastore.ds.use_mmap + + # adapt handling of copy-kwarg to numpy 2.0 + # see https://github.com/numpy/numpy/issues/25916 + # and https://github.com/numpy/numpy/pull/25922 + copy = None if HAS_NUMPY_2_0 and copy is False else copy + + return np.array(data, dtype=self.dtype, copy=copy) + + def __setitem__(self, key, value): + with self.datastore.lock: + data = self.get_variable(needs_lock=False) + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise + + +def _open_scipy_netcdf(filename, mode, mmap, version): + import scipy.io + + # if the string ends with .gz, then gunzip and open as netcdf file + if isinstance(filename, str) and filename.endswith(".gz"): + try: + return scipy.io.netcdf_file( + gzip.open(filename), mode=mode, mmap=mmap, version=version + ) + except TypeError as e: + # TODO: gzipped loading only works with NetCDF3 files. + errmsg = e.args[0] + if "is not a valid NetCDF 3 file" in errmsg: + raise ValueError("gzipped file loading only supports NetCDF 3 files.") + else: + raise + + if isinstance(filename, bytes) and filename.startswith(b"CDF"): + # it's a NetCDF3 bytestring + filename = io.BytesIO(filename) + + try: + return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version) + except TypeError as e: # netcdf3 message is obscure in this case + errmsg = e.args[0] + if "is not a valid NetCDF 3 file" in errmsg: + msg = """ + If this is a NetCDF4 file, you may need to install the + netcdf4 library, e.g., + + $ pip install netcdf4 + """ + errmsg += msg + raise TypeError(errmsg) + else: + raise + + +class ScipyDataStore(WritableCFDataStore): + """Store for reading and writing data via scipy.io.netcdf. + + This store has the advantage of being able to be initialized with a + StringIO object, allow for serialization without writing to disk. + + It only supports the NetCDF3 file-format. + """ + + def __init__( + self, filename_or_obj, mode="r", format=None, group=None, mmap=None, lock=None + ): + if group is not None: + raise ValueError("cannot save to a group with the scipy.io.netcdf backend") + + if format is None or format == "NETCDF3_64BIT": + version = 2 + elif format == "NETCDF3_CLASSIC": + version = 1 + else: + raise ValueError(f"invalid format for scipy.io.netcdf backend: {format!r}") + + if lock is None and mode != "r" and isinstance(filename_or_obj, str): + lock = get_write_lock(filename_or_obj) + + self.lock = ensure_lock(lock) + + if isinstance(filename_or_obj, str): + manager = CachingFileManager( + _open_scipy_netcdf, + filename_or_obj, + mode=mode, + lock=lock, + kwargs=dict(mmap=mmap, version=version), + ) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, mode=mode, mmap=mmap, version=version + ) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager + + @property + def ds(self): + return self._manager.acquire() + + def open_store_variable(self, name, var): + return Variable( + var.dimensions, + ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes), + ) + + def get_variables(self): + return FrozenDict( + (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() + ) + + def get_attrs(self): + return Frozen(_decode_attrs(self.ds._attributes)) + + def get_dimensions(self): + return Frozen(self.ds.dimensions) + + def get_encoding(self): + return { + "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None} + } + + def set_dimension(self, name, length, is_unlimited=False): + if name in self.ds.dimensions: + raise ValueError( + f"{type(self).__name__} does not support modifying dimensions" + ) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) + + def _validate_attr_key(self, key): + if not is_valid_nc3_name(key): + raise ValueError("Not a valid attribute name") + + def set_attribute(self, key, value): + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) + + def encode_variable(self, variable): + variable = encode_nc3_variable(variable) + return variable + + def prepare_variable( + self, name, variable, check_encoding=False, unlimited_dims=None + ): + if ( + check_encoding + and variable.encoding + and variable.encoding != {"_FillValue": None} + ): + raise ValueError( + f"unexpected encoding for scipy backend: {list(variable.encoding)}" + ) + + data = variable.data + # nb. this still creates a numpy array in all memory, even though we + # don't write the data yet; scipy.io.netcdf does not not support + # incremental writes. + if name not in self.ds.variables: + self.ds.createVariable(name, data.dtype, variable.dims) + scipy_var = self.ds.variables[name] + for k, v in variable.attrs.items(): + self._validate_attr_key(k) + setattr(scipy_var, k, v) + + target = ScipyArrayWrapper(name, self) + + return target, data + + def sync(self): + self.ds.sync() + + def close(self): + self._manager.close() + + +class ScipyBackendEntrypoint(BackendEntrypoint): + """ + Backend for netCDF files based on the scipy package. + + It can open ".nc", ".nc4", ".cdf" and ".gz" files but will only be + selected as the default if the "netcdf4" and "h5netcdf" engines are + not available. It has the advantage that is is a lightweight engine + that has no system requirements (unlike netcdf4 and h5netcdf). + + Additionally it can open gizp compressed (".gz") files. + + For more information about the underlying library, visit: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.netcdf_file.html + + See Also + -------- + backends.ScipyDataStore + backends.NetCDF4BackendEntrypoint + backends.H5netcdfBackendEntrypoint + """ + + description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray" + url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html" + + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None and magic_number.startswith(b"\x1f\x8b"): + with gzip.open(filename_or_obj) as f: # type: ignore[arg-type] + magic_number = try_read_magic_number_from_file_or_path(f) + if magic_number is not None: + return magic_number.startswith(b"CDF") + + if isinstance(filename_or_obj, (str, os.PathLike)): + _, ext = os.path.splitext(filename_or_obj) + return ext in {".nc", ".nc4", ".cdf", ".gz"} + + return False + + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + mode="r", + format=None, + group=None, + mmap=None, + lock=None, + ) -> Dataset: + filename_or_obj = _normalize_path(filename_or_obj) + store = ScipyDataStore( + filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock + ) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +BACKEND_ENTRYPOINTS["scipy"] = ("scipy", ScipyBackendEntrypoint) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/store.py b/test/fixtures/whole_applications/xarray/xarray/backends/store.py new file mode 100644 index 0000000..a507ee3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/store.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +from xarray import conventions +from xarray.backends.common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendEntrypoint, +) +from xarray.core.dataset import Dataset + +if TYPE_CHECKING: + import os + from io import BufferedIOBase + + +class StoreBackendEntrypoint(BackendEntrypoint): + description = "Open AbstractDataStore instances in Xarray" + url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html" + + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + return isinstance(filename_or_obj, AbstractDataStore) + + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + ) -> Dataset: + assert isinstance(filename_or_obj, AbstractDataStore) + + vars, attrs = filename_or_obj.load() + encoding = filename_or_obj.get_encoding() + + vars, attrs, coord_names = conventions.decode_cf_variables( + vars, + attrs, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + + ds = Dataset(vars, attrs=attrs) + ds = ds.set_coords(coord_names.intersection(vars)) + ds.set_close(filename_or_obj.close) + ds.encoding = encoding + + return ds + + +BACKEND_ENTRYPOINTS["store"] = (None, StoreBackendEntrypoint) diff --git a/test/fixtures/whole_applications/xarray/xarray/backends/zarr.py b/test/fixtures/whole_applications/xarray/xarray/backends/zarr.py new file mode 100644 index 0000000..5f6aa0f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/backends/zarr.py @@ -0,0 +1,1337 @@ +from __future__ import annotations + +import json +import os +import warnings +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd + +from xarray import coding, conventions +from xarray.backends.common import ( + BACKEND_ENTRYPOINTS, + AbstractWritableDataStore, + BackendArray, + BackendEntrypoint, + _encode_variable_name, + _normalize_path, +) +from xarray.backends.store import StoreBackendEntrypoint +from xarray.core import indexing +from xarray.core.types import ZarrWriteModes +from xarray.core.utils import ( + FrozenDict, + HiddenKeyDict, + close_on_error, +) +from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import guess_chunkmanager +from xarray.namedarray.pycompat import integer_types + +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree + +# need some special secret attributes to tell us the dimensions +DIMENSION_KEY = "_ARRAY_DIMENSIONS" + + +def encode_zarr_attr_value(value): + """ + Encode a attribute value as something that can be serialized as json + + Many xarray datasets / variables have numpy arrays and values. This + function handles encoding / decoding of such items. + + ndarray -> list + scalar array -> scalar + other -> other (no change) + """ + if isinstance(value, np.ndarray): + encoded = value.tolist() + # this checks if it's a scalar number + elif isinstance(value, np.generic): + encoded = value.item() + else: + encoded = value + return encoded + + +class ZarrArrayWrapper(BackendArray): + __slots__ = ("dtype", "shape", "_array") + + def __init__(self, zarr_array): + # some callers attempt to evaluate an array if an `array` property exists on the object. + # we prefix with _ to avoid this inference. + self._array = zarr_array + self.shape = self._array.shape + + # preserve vlen string object dtype (GH 7328) + if self._array.filters is not None and any( + [filt.codec_id == "vlen-utf8" for filt in self._array.filters] + ): + dtype = coding.strings.create_vlen_dtype(str) + else: + dtype = self._array.dtype + + self.dtype = dtype + + def get_array(self): + return self._array + + def _oindex(self, key): + return self._array.oindex[key] + + def _vindex(self, key): + return self._array.vindex[key] + + def _getitem(self, key): + return self._array[key] + + def __getitem__(self, key): + array = self._array + if isinstance(key, indexing.BasicIndexer): + method = self._getitem + elif isinstance(key, indexing.VectorizedIndexer): + method = self._vindex + elif isinstance(key, indexing.OuterIndexer): + method = self._oindex + return indexing.explicit_indexing_adapter( + key, array.shape, indexing.IndexingSupport.VECTORIZED, method + ) + + # if self.ndim == 0: + # could possibly have a work-around for 0d data here + + +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): + """ + Given encoding chunks (possibly None or []) and variable chunks + (possibly None or []). + """ + + # zarr chunk spec: + # chunks : int or tuple of ints, optional + # Chunk shape. If not provided, will be guessed from shape and dtype. + + # if there are no chunks in encoding and the variable data is a numpy + # array, then we let zarr use its own heuristics to pick the chunks + if not var_chunks and not enc_chunks: + return None + + # if there are no chunks in encoding but there are dask chunks, we try to + # use the same chunks in zarr + # However, zarr chunks needs to be uniform for each array + # http://zarr.readthedocs.io/en/latest/spec/v1.html#chunks + # while dask chunks can be variable sized + # http://dask.pydata.org/en/latest/array-design.html#chunks + if var_chunks and not enc_chunks: + if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): + raise ValueError( + "Zarr requires uniform chunk sizes except for final chunk. " + f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. " + "Consider rechunking using `chunk()`." + ) + if any((chunks[0] < chunks[-1]) for chunks in var_chunks): + raise ValueError( + "Final chunk of Zarr array must be the same size or smaller " + f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}." + "Consider either rechunking using `chunk()` or instead deleting " + "or modifying `encoding['chunks']`." + ) + # return the first chunk for each dimension + return tuple(chunk[0] for chunk in var_chunks) + + # from here on, we are dealing with user-specified chunks in encoding + # zarr allows chunks to be an integer, in which case it uses the same chunk + # size on each dimension. + # Here we re-implement this expansion ourselves. That makes the logic of + # checking chunk compatibility easier + + if isinstance(enc_chunks, integer_types): + enc_chunks_tuple = ndim * (enc_chunks,) + else: + enc_chunks_tuple = tuple(enc_chunks) + + if len(enc_chunks_tuple) != ndim: + # throw away encoding chunks, start over + return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks) + + for x in enc_chunks_tuple: + if not isinstance(x, int): + raise TypeError( + "zarr chunk sizes specified in `encoding['chunks']` " + "must be an int or a tuple of ints. " + f"Instead found encoding['chunks']={enc_chunks_tuple!r} " + f"for variable named {name!r}." + ) + + # if there are chunks in encoding and the variable data is a numpy array, + # we use the specified chunks + if not var_chunks: + return enc_chunks_tuple + + # the hard case + # DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk + # this avoids the need to get involved in zarr synchronization / locking + # From zarr docs: + # "If each worker in a parallel computation is writing to a + # separate region of the array, and if region boundaries are perfectly aligned + # with chunk boundaries, then no synchronization is required." + # TODO: incorporate synchronizer to allow writes from multiple dask + # threads + if var_chunks and enc_chunks_tuple: + for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks): + for dchunk in dchunks[:-1]: + if dchunk % zchunk: + base_error = ( + f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " + f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. " + f"Writing this array in parallel with dask could lead to corrupted data." + ) + if safe_chunks: + raise ValueError( + base_error + + " Consider either rechunking using `chunk()`, deleting " + "or modifying `encoding['chunks']`, or specify `safe_chunks=False`." + ) + return enc_chunks_tuple + + raise AssertionError("We should never get here. Function logic must be wrong.") + + +def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr): + # Zarr arrays do not have dimensions. To get around this problem, we add + # an attribute that specifies the dimension. We have to hide this attribute + # when we send the attributes to the user. + # zarr_obj can be either a zarr group or zarr array + try: + # Xarray-Zarr + dimensions = zarr_obj.attrs[dimension_key] + except KeyError as e: + if not try_nczarr: + raise KeyError( + f"Zarr object is missing the attribute `{dimension_key}`, which is " + "required for xarray to determine variable dimensions." + ) from e + + # NCZarr defines dimensions through metadata in .zarray + zarray_path = os.path.join(zarr_obj.path, ".zarray") + zarray = json.loads(zarr_obj.store[zarray_path]) + try: + # NCZarr uses Fully Qualified Names + dimensions = [ + os.path.basename(dim) for dim in zarray["_NCZARR_ARRAY"]["dimrefs"] + ] + except KeyError as e: + raise KeyError( + f"Zarr object is missing the attribute `{dimension_key}` and the NCZarr metadata, " + "which are required for xarray to determine variable dimensions." + ) from e + + nc_attrs = [attr for attr in zarr_obj.attrs if attr.lower().startswith("_nc")] + attributes = HiddenKeyDict(zarr_obj.attrs, [dimension_key] + nc_attrs) + return dimensions, attributes + + +def extract_zarr_variable_encoding( + variable, raise_on_invalid=False, name=None, safe_chunks=True +): + """ + Extract zarr encoding dictionary from xarray Variable + + Parameters + ---------- + variable : Variable + raise_on_invalid : bool, optional + + Returns + ------- + encoding : dict + Zarr encoding for `variable` + """ + encoding = variable.encoding.copy() + + safe_to_drop = {"source", "original_shape"} + valid_encodings = { + "chunks", + "compressor", + "filters", + "cache_metadata", + "write_empty_chunks", + } + + for k in safe_to_drop: + if k in encoding: + del encoding[k] + + if raise_on_invalid: + invalid = [k for k in encoding if k not in valid_encodings] + if invalid: + raise ValueError( + f"unexpected encoding parameters for zarr backend: {invalid!r}" + ) + else: + for k in list(encoding): + if k not in valid_encodings: + del encoding[k] + + chunks = _determine_zarr_chunks( + encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks + ) + encoding["chunks"] = chunks + return encoding + + +# Function below is copied from conventions.encode_cf_variable. +# The only change is to raise an error for object dtypes. +def encode_zarr_variable(var, needs_copy=True, name=None): + """ + Converts an Variable into an Variable which follows some + of the CF conventions: + + - Nans are masked using _FillValue (or the deprecated missing_value) + - Rescaling via: scale_factor and add_offset + - datetimes are converted to the CF 'units since time' format + - dtype encodings are enforced. + + Parameters + ---------- + var : Variable + A variable holding un-encoded data. + + Returns + ------- + out : Variable + A variable which has been encoded as described above. + """ + + var = conventions.encode_cf_variable(var, name=name) + + # zarr allows unicode, but not variable-length strings, so it's both + # simpler and more compact to always encode as UTF-8 explicitly. + # TODO: allow toggling this explicitly via dtype in encoding. + coder = coding.strings.EncodedStringCoder(allows_unicode=True) + var = coder.encode(var, name=name) + var = coding.strings.ensure_fixed_length_bytes(var) + + return var + + +def _validate_datatypes_for_zarr_append(vname, existing_var, new_var): + """If variable exists in the store, confirm dtype of the data to append is compatible with + existing dtype. + """ + if ( + np.issubdtype(new_var.dtype, np.number) + or np.issubdtype(new_var.dtype, np.datetime64) + or np.issubdtype(new_var.dtype, np.bool_) + or new_var.dtype == object + ): + # We can skip dtype equality checks under two conditions: (1) if the var to append is + # new to the dataset, because in this case there is no existing var to compare it to; + # or (2) if var to append's dtype is known to be easy-to-append, because in this case + # we can be confident appending won't cause problems. Examples of dtypes which are not + # easy-to-append include length-specified strings of type `|S*` or `
{escape(repr(self))}
" + return formatting_html.array_repr(self) + + def __format__(self: Any, format_spec: str = "") -> str: + if format_spec != "": + if self.shape == (): + # Scalar values might be ok use format_spec with instead of repr: + return self.data.__format__(format_spec) + else: + # TODO: If it's an array the formatting.array_repr(self) should + # take format_spec as an input. If we'd only use self.data we + # lose all the information about coords for example which is + # important information: + raise NotImplementedError( + "Using format_spec is only supported" + f" when shape is (). Got shape = {self.shape}." + ) + else: + return self.__repr__() + + def _iter(self: Any) -> Iterator[Any]: + for n in range(len(self)): + yield self[n] + + def __iter__(self: Any) -> Iterator[Any]: + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") + return self._iter() + + @overload + def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... + + @overload + def get_axis_num(self, dim: Hashable) -> int: ... + + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + """Return axis number(s) corresponding to dimension(s) in this array. + + Parameters + ---------- + dim : str or iterable of str + Dimension name(s) for which to lookup axes. + + Returns + ------- + int or tuple of int + Axis number or numbers corresponding to the given dimensions. + """ + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) + + def _get_axis_num(self: Any, dim: Hashable) -> int: + _raise_if_any_duplicate_dimensions(self.dims) + try: + return self.dims.index(dim) + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + + @property + def sizes(self: Any) -> Mapping[Hashable, int]: + """Ordered mapping from dimension names to lengths. + + Immutable. + + See Also + -------- + Dataset.sizes + """ + return Frozen(dict(zip(self.dims, self.shape))) + + +class AttrAccessMixin: + """Mixin class that allows getting keys with attribute access""" + + __slots__ = () + + def __init_subclass__(cls, **kwargs): + """Verify that all subclasses explicitly define ``__slots__``. If they don't, + raise error in the core xarray module and a FutureWarning in third-party + extensions. + """ + if not hasattr(object.__new__(cls), "__dict__"): + pass + elif cls.__module__.startswith("xarray."): + raise AttributeError(f"{cls.__name__} must explicitly define __slots__") + else: + cls.__setattr__ = cls._setattr_dict + warnings.warn( + f"xarray subclass {cls.__name__} should explicitly define __slots__", + FutureWarning, + stacklevel=2, + ) + super().__init_subclass__(**kwargs) + + @property + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from () + + @property + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-autocompletion""" + yield from () + + def __getattr__(self, name: str) -> Any: + if name not in {"__dict__", "__setstate__"}: + # this avoids an infinite loop when pickle looks for the + # __setstate__ attribute before the xarray object is initialized + for source in self._attr_sources: + with suppress(KeyError): + return source[name] + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {name!r}" + ) + + # This complicated two-method design boosts overall performance of simple operations + # - particularly DataArray methods that perform a _to_temp_dataset() round-trip - by + # a whopping 8% compared to a single method that checks hasattr(self, "__dict__") at + # runtime before every single assignment. All of this is just temporary until the + # FutureWarning can be changed into a hard crash. + def _setattr_dict(self, name: str, value: Any) -> None: + """Deprecated third party subclass (see ``__init_subclass__`` above)""" + object.__setattr__(self, name, value) + if name in self.__dict__: + # Custom, non-slotted attr, or improperly assigned variable? + warnings.warn( + f"Setting attribute {name!r} on a {type(self).__name__!r} object. Explicitly define __slots__ " + "to suppress this warning for legitimate custom attributes and " + "raise an error when attempting variables assignments.", + FutureWarning, + stacklevel=2, + ) + + def __setattr__(self, name: str, value: Any) -> None: + """Objects with ``__slots__`` raise AttributeError if you try setting an + undeclared attribute. This is desirable, but the error message could use some + improvement. + """ + try: + object.__setattr__(self, name, value) + except AttributeError as e: + # Don't accidentally shadow custom AttributeErrors, e.g. + # DataArray.dims.setter + if str(e) != f"{type(self).__name__!r} object has no attribute {name!r}": + raise + raise AttributeError( + f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" + "assignment (e.g., `ds['name'] = ...`) instead of assigning variables." + ) from e + + def __dir__(self) -> list[str]: + """Provide method name lookup and completion. Only provide 'public' + methods. + """ + extra_attrs = { + item + for source in self._attr_sources + for item in source + if isinstance(item, str) + } + return sorted(set(dir(type(self))) | extra_attrs) + + def _ipython_key_completions_(self) -> list[str]: + """Provide method for the key-autocompletions in IPython. + See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion + For the details. + """ + items = { + item + for source in self._item_sources + for item in source + if isinstance(item, str) + } + return list(items) + + +class TreeAttrAccessMixin(AttrAccessMixin): + """Mixin class that allows getting keys with attribute access""" + + # TODO: Ensure ipython tab completion can include both child datatrees and + # variables from Dataset objects on relevant nodes. + + __slots__ = () + + def __init_subclass__(cls, **kwargs): + """This method overrides the check from ``AttrAccessMixin`` that ensures + ``__dict__`` is absent in a class, with ``__slots__`` used instead. + ``DataTree`` has some dynamically defined attributes in addition to those + defined in ``__slots__``. (GH9068) + """ + if not hasattr(object.__new__(cls), "__dict__"): + pass + + +def get_squeeze_dims( + xarray_obj, + dim: Hashable | Iterable[Hashable] | None = None, + axis: int | Iterable[int] | None = None, +) -> list[Hashable]: + """Get a list of dimensions to squeeze out.""" + if dim is not None and axis is not None: + raise ValueError("cannot use both parameters `axis` and `dim`") + if dim is None and axis is None: + return [d for d, s in xarray_obj.sizes.items() if s == 1] + + if isinstance(dim, Iterable) and not isinstance(dim, str): + dim = list(dim) + elif dim is not None: + dim = [dim] + else: + assert axis is not None + if isinstance(axis, int): + axis = [axis] + axis = list(axis) + if any(not isinstance(a, int) for a in axis): + raise TypeError("parameter `axis` must be int or iterable of int.") + alldims = list(xarray_obj.sizes.keys()) + dim = [alldims[a] for a in axis] + + if any(xarray_obj.sizes[k] > 1 for k in dim): + raise ValueError( + "cannot select a dimension to squeeze out " + "which has length greater than one" + ) + return dim + + +class DataWithCoords(AttrAccessMixin): + """Shared base class for Dataset and DataArray.""" + + _close: Callable[[], None] | None + _indexes: dict[Hashable, Index] + + __slots__ = ("_close",) + + def squeeze( + self, + dim: Hashable | Iterable[Hashable] | None = None, + drop: bool = False, + axis: int | Iterable[int] | None = None, + ) -> Self: + """Return a new object with squeezed data. + + Parameters + ---------- + dim : None or Hashable or iterable of Hashable, optional + Selects a subset of the length one dimensions. If a dimension is + selected with length greater than one, an error is raised. If + None, all length one dimensions are squeezed. + drop : bool, default: False + If ``drop=True``, drop squeezed coordinates instead of making them + scalar. + axis : None or int or iterable of int, optional + Like dim, but positional. + + Returns + ------- + squeezed : same type as caller + This object, but with with all or a subset of the dimensions of + length 1 removed. + + See Also + -------- + numpy.squeeze + """ + dims = get_squeeze_dims(self, dim, axis) + return self.isel(drop=drop, **{d: 0 for d in dims}) + + def clip( + self, + min: ScalarOrArray | None = None, + max: ScalarOrArray | None = None, + *, + keep_attrs: bool | None = None, + ) -> Self: + """ + Return an array whose values are limited to ``[min, max]``. + At least one of max or min must be given. + + Parameters + ---------- + min : None or Hashable, optional + Minimum value. If None, no lower clipping is performed. + max : None or Hashable, optional + Maximum value. If None, no upper clipping is performed. + keep_attrs : bool or None, optional + If True, the attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + + Returns + ------- + clipped : same type as caller + This object, but with with values < min are replaced with min, + and those > max with max. + + See Also + -------- + numpy.clip : equivalent function + """ + from xarray.core.computation import apply_ufunc + + if keep_attrs is None: + # When this was a unary func, the default was True, so retaining the + # default. + keep_attrs = _get_keep_attrs(default=True) + + return apply_ufunc( + np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + ) + + def get_index(self, key: Hashable) -> pd.Index: + """Get an index for a dimension, with fall-back to a default RangeIndex""" + if key not in self.dims: + raise KeyError(key) + + try: + return self._indexes[key].to_pandas_index() + except KeyError: + return pd.Index(range(self.sizes[key]), name=key) + + def _calc_assign_results( + self: C, kwargs: Mapping[Any, T | Callable[[C], T]] + ) -> dict[Hashable, T]: + return {k: v(self) if callable(v) else v for k, v in kwargs.items()} + + def assign_coords( + self, + coords: Mapping | None = None, + **coords_kwargs: Any, + ) -> Self: + """Assign new coordinates to this object. + + Returns a new object with all the original data in addition to the new + coordinates. + + Parameters + ---------- + coords : mapping of dim to coord, optional + A mapping whose keys are the names of the coordinates and values are the + coordinates to assign. The mapping will generally be a dict or + :class:`Coordinates`. + + * If a value is a standard data value — for example, a ``DataArray``, + scalar, or array — the data is simply assigned as a coordinate. + + * If a value is callable, it is called with this object as the only + parameter, and the return value is used as new coordinate variables. + + * A coordinate can also be defined and attached to an existing dimension + using a tuple with the first element the dimension name and the second + element the values for this new coordinate. + + **coords_kwargs : optional + The keyword arguments form of ``coords``. + One of ``coords`` or ``coords_kwargs`` must be provided. + + Returns + ------- + assigned : same type as caller + A new object with the new coordinates in addition to the existing + data. + + Examples + -------- + Convert `DataArray` longitude coordinates from 0-359 to -180-179: + + >>> da = xr.DataArray( + ... np.random.rand(4), + ... coords=[np.array([358, 359, 0, 1])], + ... dims="lon", + ... ) + >>> da + Size: 32B + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) + Coordinates: + * lon (lon) int64 32B 358 359 0 1 + >>> da.assign_coords(lon=(((da.lon + 180) % 360) - 180)) + Size: 32B + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) + Coordinates: + * lon (lon) int64 32B -2 -1 0 1 + + The function also accepts dictionary arguments: + + >>> da.assign_coords({"lon": (((da.lon + 180) % 360) - 180)}) + Size: 32B + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) + Coordinates: + * lon (lon) int64 32B -2 -1 0 1 + + New coordinate can also be attached to an existing dimension: + + >>> lon_2 = np.array([300, 289, 0, 1]) + >>> da.assign_coords(lon_2=("lon", lon_2)) + Size: 32B + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) + Coordinates: + * lon (lon) int64 32B 358 359 0 1 + lon_2 (lon) int64 32B 300 289 0 1 + + Note that the same result can also be obtained with a dict e.g. + + >>> _ = da.assign_coords({"lon_2": ("lon", lon_2)}) + + Note the same method applies to `Dataset` objects. + + Convert `Dataset` longitude coordinates from 0-359 to -180-179: + + >>> temperature = np.linspace(20, 32, num=16).reshape(2, 2, 4) + >>> precipitation = 2 * np.identity(4).reshape(2, 2, 4) + >>> ds = xr.Dataset( + ... data_vars=dict( + ... temperature=(["x", "y", "time"], temperature), + ... precipitation=(["x", "y", "time"], precipitation), + ... ), + ... coords=dict( + ... lon=(["x", "y"], [[260.17, 260.68], [260.21, 260.77]]), + ... lat=(["x", "y"], [[42.25, 42.21], [42.63, 42.59]]), + ... time=pd.date_range("2014-09-06", periods=4), + ... reference_time=pd.Timestamp("2014-09-05"), + ... ), + ... attrs=dict(description="Weather-related data"), + ... ) + >>> ds + Size: 360B + Dimensions: (x: 2, y: 2, time: 4) + Coordinates: + lon (x, y) float64 32B 260.2 260.7 260.2 260.8 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 32B 2014-09-06 ... 2014-09-09 + reference_time datetime64[ns] 8B 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 128B 20.0 20.8 21.6 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 128B 2.0 0.0 0.0 0.0 ... 0.0 0.0 2.0 + Attributes: + description: Weather-related data + >>> ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180)) + Size: 360B + Dimensions: (x: 2, y: 2, time: 4) + Coordinates: + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 32B 2014-09-06 ... 2014-09-09 + reference_time datetime64[ns] 8B 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 128B 20.0 20.8 21.6 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 128B 2.0 0.0 0.0 0.0 ... 0.0 0.0 2.0 + Attributes: + description: Weather-related data + + See Also + -------- + Dataset.assign + Dataset.swap_dims + Dataset.set_coords + """ + from xarray.core.coordinates import Coordinates + + coords_combined = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords") + data = self.copy(deep=False) + + results: Coordinates | dict[Hashable, Any] + if isinstance(coords, Coordinates): + results = coords + else: + results = self._calc_assign_results(coords_combined) + + data.coords.update(results) + return data + + def assign_attrs(self, *args: Any, **kwargs: Any) -> Self: + """Assign new attrs to this object. + + Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``. + + Parameters + ---------- + *args + positional arguments passed into ``attrs.update``. + **kwargs + keyword arguments passed into ``attrs.update``. + + Examples + -------- + >>> dataset = xr.Dataset({"temperature": [25, 30, 27]}) + >>> dataset + Size: 24B + Dimensions: (temperature: 3) + Coordinates: + * temperature (temperature) int64 24B 25 30 27 + Data variables: + *empty* + + >>> new_dataset = dataset.assign_attrs( + ... units="Celsius", description="Temperature data" + ... ) + >>> new_dataset + Size: 24B + Dimensions: (temperature: 3) + Coordinates: + * temperature (temperature) int64 24B 25 30 27 + Data variables: + *empty* + Attributes: + units: Celsius + description: Temperature data + + # Attributes of the new dataset + + >>> new_dataset.attrs + {'units': 'Celsius', 'description': 'Temperature data'} + + Returns + ------- + assigned : same type as caller + A new object with the new attrs in addition to the existing data. + + See Also + -------- + Dataset.assign + """ + out = self.copy(deep=False) + out.attrs.update(*args, **kwargs) + return out + + def pipe( + self, + func: Callable[..., T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + """ + Apply ``func(self, *args, **kwargs)`` + + This method replicates the pandas method of the same name. + + Parameters + ---------- + func : callable + function to apply to this xarray object (Dataset/DataArray). + ``args``, and ``kwargs`` are passed into ``func``. + Alternatively a ``(callable, data_keyword)`` tuple where + ``data_keyword`` is a string indicating the keyword of + ``callable`` that expects the xarray object. + *args + positional arguments passed into ``func``. + **kwargs + a dictionary of keyword arguments passed into ``func``. + + Returns + ------- + object : Any + the return type of ``func``. + + Notes + ----- + Use ``.pipe`` when chaining together functions that expect + xarray or pandas objects, e.g., instead of writing + + .. code:: python + + f(g(h(ds), arg1=a), arg2=b, arg3=c) + + You can write + + .. code:: python + + (ds.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c)) + + If you have a function that takes the data as (say) the second + argument, pass a tuple indicating which keyword expects the + data. For example, suppose ``f`` takes its data as ``arg2``: + + .. code:: python + + (ds.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c)) + + Examples + -------- + >>> x = xr.Dataset( + ... { + ... "temperature_c": ( + ... ("lat", "lon"), + ... 20 * np.random.rand(4).reshape(2, 2), + ... ), + ... "precipitation": (("lat", "lon"), np.random.rand(4).reshape(2, 2)), + ... }, + ... coords={"lat": [10, 20], "lon": [150, 160]}, + ... ) + >>> x + Size: 96B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + + >>> def adder(data, arg): + ... return data + arg + ... + >>> def div(data, arg): + ... return data / arg + ... + >>> def sub_mult(data, sub_arg, mult_arg): + ... return (data * mult_arg) - sub_arg + ... + >>> x.pipe(adder, 2) + Size: 96B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 12.98 16.3 14.06 12.9 + precipitation (lat, lon) float64 32B 2.424 2.646 2.438 2.892 + + >>> x.pipe(adder, arg=2) + Size: 96B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 12.98 16.3 14.06 12.9 + precipitation (lat, lon) float64 32B 2.424 2.646 2.438 2.892 + + >>> ( + ... x.pipe(adder, arg=2) + ... .pipe(div, arg=2) + ... .pipe(sub_mult, sub_arg=2, mult_arg=2) + ... ) + Size: 96B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + + See Also + -------- + pandas.DataFrame.pipe + """ + if isinstance(func, tuple): + func, target = func + if target in kwargs: + raise ValueError( + f"{target} is both the pipe target and a keyword argument" + ) + kwargs[target] = self + return func(*args, **kwargs) + else: + return func(self, *args, **kwargs) + + def rolling_exp( + self: T_DataWithCoords, + window: Mapping[Any, int] | None = None, + window_type: str = "span", + **window_kwargs, + ) -> RollingExp[T_DataWithCoords]: + """ + Exponentially-weighted moving window. + Similar to EWM in pandas + + Requires the optional Numbagg dependency. + + Parameters + ---------- + window : mapping of hashable to int, optional + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + window_type : {"span", "com", "halflife", "alpha"}, default: "span" + The format of the previously supplied window. Each is a simple + numerical transformation of the others. Described in detail: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html + **window_kwargs : optional + The keyword arguments form of ``window``. + One of window or window_kwargs must be provided. + + See Also + -------- + core.rolling_exp.RollingExp + """ + from xarray.core import rolling_exp + + if "keep_attrs" in window_kwargs: + warnings.warn( + "Passing ``keep_attrs`` to ``rolling_exp`` has no effect. Pass" + " ``keep_attrs`` directly to the applied function, e.g." + " ``rolling_exp(...).mean(keep_attrs=False)``." + ) + + window = either_dict_or_kwargs(window, window_kwargs, "rolling_exp") + + return rolling_exp.RollingExp(self, window, window_type) + + def _resample( + self, + resample_cls: type[T_Resample], + indexer: Mapping[Any, str] | None, + skipna: bool | None, + closed: SideOptions | None, + label: SideOptions | None, + base: int | None, + offset: pd.Timedelta | datetime.timedelta | str | None, + origin: str | DatetimeLike, + loffset: datetime.timedelta | str | None, + restore_coord_dims: bool | None, + **indexer_kwargs: str, + ) -> T_Resample: + """Returns a Resample object for performing resampling operations. + + Handles both downsampling and upsampling. The resampled + dimension must be a datetime-like coordinate. If any intervals + contain no values from the original object, they will be given + the value ``NaN``. + + Parameters + ---------- + indexer : {dim: freq}, optional + Mapping from the dimension name to resample frequency [1]_. The + dimension must be datetime-like. + skipna : bool, optional + Whether to skip missing values when aggregating in downsampling. + closed : {"left", "right"}, optional + Side of each interval to treat as closed. + label : {"left", "right"}, optional + Side of each interval to use for labeling. + base : int, optional + For frequencies that evenly subdivide 1 day, the "origin" of the + aggregated intervals. For example, for "24H" frequency, base could + range from 0 through 23. + + .. deprecated:: 2023.03.0 + Following pandas, the ``base`` parameter is deprecated in favor + of the ``origin`` and ``offset`` parameters, and will be removed + in a future version of xarray. + + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pd.Timestamp, datetime.datetime, np.datetime64, or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : pd.Timedelta, datetime.timedelta, or str, default is None + An offset timedelta added to the origin. + loffset : timedelta or str, optional + Offset used to adjust the resampled time labels. Some pandas date + offset strings are supported. + + .. deprecated:: 2023.03.0 + Following pandas, the ``loffset`` parameter is deprecated in favor + of using time offset arithmetic, and will be removed in a future + version of xarray. + + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + **indexer_kwargs : {dim: freq} + The keyword arguments form of ``indexer``. + One of indexer or indexer_kwargs must be provided. + + Returns + ------- + resampled : same type as caller + This object resampled. + + Examples + -------- + Downsample monthly time-series data to seasonal data: + + >>> da = xr.DataArray( + ... np.linspace(0, 11, num=12), + ... coords=[ + ... pd.date_range( + ... "1999-12-15", + ... periods=12, + ... freq=pd.DateOffset(months=1), + ... ) + ... ], + ... dims="time", + ... ) + >>> da + Size: 96B + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + >>> da.resample(time="QS-DEC").mean() + Size: 32B + array([ 1., 4., 7., 10.]) + Coordinates: + * time (time) datetime64[ns] 32B 1999-12-01 2000-03-01 ... 2000-09-01 + + Upsample monthly time-series data to daily data: + + >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS + Size: 3kB + array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, + 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, + 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , + 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, + 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, + 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, + 0.96774194, 1. , 1.03225806, 1.06451613, 1.09677419, + 1.12903226, 1.16129032, 1.19354839, 1.22580645, 1.25806452, + 1.29032258, 1.32258065, 1.35483871, 1.38709677, 1.41935484, + 1.4516129 , 1.48387097, 1.51612903, 1.5483871 , 1.58064516, + 1.61290323, 1.64516129, 1.67741935, 1.70967742, 1.74193548, + 1.77419355, 1.80645161, 1.83870968, 1.87096774, 1.90322581, + 1.93548387, 1.96774194, 2. , 2.03448276, 2.06896552, + 2.10344828, 2.13793103, 2.17241379, 2.20689655, 2.24137931, + 2.27586207, 2.31034483, 2.34482759, 2.37931034, 2.4137931 , + 2.44827586, 2.48275862, 2.51724138, 2.55172414, 2.5862069 , + 2.62068966, 2.65517241, 2.68965517, 2.72413793, 2.75862069, + 2.79310345, 2.82758621, 2.86206897, 2.89655172, 2.93103448, + 2.96551724, 3. , 3.03225806, 3.06451613, 3.09677419, + 3.12903226, 3.16129032, 3.19354839, 3.22580645, 3.25806452, + ... + 7.87096774, 7.90322581, 7.93548387, 7.96774194, 8. , + 8.03225806, 8.06451613, 8.09677419, 8.12903226, 8.16129032, + 8.19354839, 8.22580645, 8.25806452, 8.29032258, 8.32258065, + 8.35483871, 8.38709677, 8.41935484, 8.4516129 , 8.48387097, + 8.51612903, 8.5483871 , 8.58064516, 8.61290323, 8.64516129, + 8.67741935, 8.70967742, 8.74193548, 8.77419355, 8.80645161, + 8.83870968, 8.87096774, 8.90322581, 8.93548387, 8.96774194, + 9. , 9.03333333, 9.06666667, 9.1 , 9.13333333, + 9.16666667, 9.2 , 9.23333333, 9.26666667, 9.3 , + 9.33333333, 9.36666667, 9.4 , 9.43333333, 9.46666667, + 9.5 , 9.53333333, 9.56666667, 9.6 , 9.63333333, + 9.66666667, 9.7 , 9.73333333, 9.76666667, 9.8 , + 9.83333333, 9.86666667, 9.9 , 9.93333333, 9.96666667, + 10. , 10.03225806, 10.06451613, 10.09677419, 10.12903226, + 10.16129032, 10.19354839, 10.22580645, 10.25806452, 10.29032258, + 10.32258065, 10.35483871, 10.38709677, 10.41935484, 10.4516129 , + 10.48387097, 10.51612903, 10.5483871 , 10.58064516, 10.61290323, + 10.64516129, 10.67741935, 10.70967742, 10.74193548, 10.77419355, + 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, + 10.96774194, 11. ]) + Coordinates: + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 + + Limit scope of upsampling method + + >>> da.resample(time="1D").nearest(tolerance="1D") + Size: 3kB + array([ 0., 0., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 1., 1., 1., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 2., 2., 2., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 3., + 3., 3., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, 5., 5., 5., nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + 6., 6., 6., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 7., 7., 7., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 8., 8., 8., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, 9., 9., 9., nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, 10., 10., 10., nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 11., 11.]) + Coordinates: + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 + + See Also + -------- + pandas.Series.resample + pandas.DataFrame.resample + + References + ---------- + .. [1] https://pandas.pydata.org/docs/user_guide/timeseries.html#dateoffset-objects + """ + # TODO support non-string indexer after removing the old API. + + from xarray.core.dataarray import DataArray + from xarray.core.groupby import ResolvedGrouper, TimeResampler + from xarray.core.resample import RESAMPLE_DIM + + # note: the second argument (now 'skipna') use to be 'dim' + if ( + (skipna is not None and not isinstance(skipna, bool)) + or ("how" in indexer_kwargs and "how" not in self.dims) + or ("dim" in indexer_kwargs and "dim" not in self.dims) + ): + raise TypeError( + "resample() no longer supports the `how` or " + "`dim` arguments. Instead call methods on resample " + "objects, e.g., data.resample(time='1D').mean()" + ) + + indexer = either_dict_or_kwargs(indexer, indexer_kwargs, "resample") + if len(indexer) != 1: + raise ValueError("Resampling only supported along single dimensions.") + dim, freq = next(iter(indexer.items())) + + dim_name: Hashable = dim + dim_coord = self[dim] + + group = DataArray( + dim_coord, + coords=dim_coord.coords, + dims=dim_coord.dims, + name=RESAMPLE_DIM, + ) + + grouper = TimeResampler( + freq=freq, + closed=closed, + label=label, + origin=origin, + offset=offset, + loffset=loffset, + base=base, + ) + + rgrouper = ResolvedGrouper(grouper, group, self) + + return resample_cls( + self, + (rgrouper,), + dim=dim_name, + resample_dim=RESAMPLE_DIM, + restore_coord_dims=restore_coord_dims, + ) + + def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: + """Filter elements from this object according to a condition. + + Returns elements from 'DataArray', where 'cond' is True, + otherwise fill in 'other'. + + This operation follows the normal broadcasting and alignment rules that + xarray uses for binary arithmetic. + + Parameters + ---------- + cond : DataArray, Dataset, or callable + Locations at which to preserve this object's values. dtype must be `bool`. + If a callable, the callable is passed this object, and the result is used as + the value for cond. + other : scalar, DataArray, Dataset, or callable, optional + Value to use for locations in this object where ``cond`` is False. + By default, these locations are filled with NA. If a callable, it must + expect this object as its only parameter. + drop : bool, default: False + If True, coordinate labels that only correspond to False values of + the condition are dropped from the result. + + Returns + ------- + DataArray or Dataset + Same xarray type as caller, with dtype float64. + + Examples + -------- + >>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=("x", "y")) + >>> a + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4) + Size: 200B + array([[ 0., 1., 2., 3., nan], + [ 5., 6., 7., nan, nan], + [10., 11., nan, nan, nan], + [15., nan, nan, nan, nan], + [nan, nan, nan, nan, nan]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 5, -1) + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, -1], + [10, 11, 12, -1, -1], + [15, 16, -1, -1, -1], + [20, -1, -1, -1, -1]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4, drop=True) + Size: 128B + array([[ 0., 1., 2., 3.], + [ 5., 6., 7., nan], + [10., 11., nan, nan], + [15., nan, nan, nan]]) + Dimensions without coordinates: x, y + + >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x) + Size: 200B + array([[ 0, 1, 2, 3, -4], + [ 5, 6, 7, -8, -9], + [ 10, 11, -12, -13, -14], + [ 15, -16, -17, -18, -19], + [-20, -21, -22, -23, -24]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4, drop=True) + Size: 128B + array([[ 0., 1., 2., 3.], + [ 5., 6., 7., nan], + [10., 11., nan, nan], + [15., nan, nan, nan]]) + Dimensions without coordinates: x, y + + See Also + -------- + numpy.where : corresponding numpy function + where : equivalent function + """ + from xarray.core.alignment import align + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + if callable(cond): + cond = cond(self) + if callable(other): + other = other(self) + + if drop: + if not isinstance(cond, (Dataset, DataArray)): + raise TypeError( + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)." + ) + + self, cond = align(self, cond) + + def _dataarray_indexer(dim: Hashable) -> DataArray: + return cond.any(dim=(d for d in cond.dims if d != dim)) + + def _dataset_indexer(dim: Hashable) -> DataArray: + cond_wdim = cond.drop_vars( + var for var in cond if dim not in cond[var].dims + ) + keepany = cond_wdim.any(dim=(d for d in cond.dims if d != dim)) + return keepany.to_dataarray().any("variable") + + _get_indexer = ( + _dataarray_indexer if isinstance(cond, DataArray) else _dataset_indexer + ) + + indexers = {} + for dim in cond.sizes.keys(): + indexers[dim] = _get_indexer(dim) + + self = self.isel(**indexers) + cond = cond.isel(**indexers) + + return ops.where_method(self, cond, other) + + def set_close(self, close: Callable[[], None] | None) -> None: + """Register the function that releases any resources linked to this object. + + This method controls how xarray cleans up resources associated + with this object when the ``.close()`` method is called. It is mostly + intended for backend developers and it is rarely needed by regular + end-users. + + Parameters + ---------- + close : callable + The function that when called like ``close()`` releases + any resources linked to this object. + """ + self._close = close + + def close(self) -> None: + """Release any resources linked to this object.""" + if self._close is not None: + self._close() + self._close = None + + def isnull(self, keep_attrs: bool | None = None) -> Self: + """Test each value in the array for whether it is a missing value. + + Parameters + ---------- + keep_attrs : bool or None, optional + If True, the attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + + Returns + ------- + isnull : DataArray or Dataset + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.isnull + + Examples + -------- + >>> array = xr.DataArray([1, np.nan, 3], dims="x") + >>> array + Size: 24B + array([ 1., nan, 3.]) + Dimensions without coordinates: x + >>> array.isnull() + Size: 3B + array([False, True, False]) + Dimensions without coordinates: x + """ + from xarray.core.computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.isnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + + def notnull(self, keep_attrs: bool | None = None) -> Self: + """Test each value in the array for whether it is not a missing value. + + Parameters + ---------- + keep_attrs : bool or None, optional + If True, the attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + + Returns + ------- + notnull : DataArray or Dataset + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.notnull + + Examples + -------- + >>> array = xr.DataArray([1, np.nan, 3], dims="x") + >>> array + Size: 24B + array([ 1., nan, 3.]) + Dimensions without coordinates: x + >>> array.notnull() + Size: 3B + array([ True, False, True]) + Dimensions without coordinates: x + """ + from xarray.core.computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.notnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + + def isin(self, test_elements: Any) -> Self: + """Tests each value in the array for whether it is in test elements. + + Parameters + ---------- + test_elements : array_like + The values against which to test each value of `element`. + This argument is flattened if an array or array_like. + See numpy notes for behavior with non-array-like parameters. + + Returns + ------- + isin : DataArray or Dataset + Has the same type and shape as this object, but with a bool dtype. + + Examples + -------- + >>> array = xr.DataArray([1, 2, 3], dims="x") + >>> array.isin([1, 3]) + Size: 3B + array([ True, False, True]) + Dimensions without coordinates: x + + See Also + -------- + numpy.isin + """ + from xarray.core.computation import apply_ufunc + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.variable import Variable + + if isinstance(test_elements, Dataset): + raise TypeError( + f"isin() argument must be convertible to an array: {test_elements}" + ) + elif isinstance(test_elements, (Variable, DataArray)): + # need to explicitly pull out data to support dask arrays as the + # second argument + test_elements = test_elements.data + + return apply_ufunc( + duck_array_ops.isin, + self, + kwargs=dict(test_elements=test_elements), + dask="allowed", + ) + + def astype( + self, + dtype, + *, + order=None, + casting=None, + subok=None, + copy=None, + keep_attrs=True, + ) -> Self: + """ + Copy of the xarray object, with data cast to a specified type. + Leaves coordinate dtype unchanged. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + order : {'C', 'F', 'A', 'K'}, optional + Controls the memory layout order of the result. ‘C’ means C order, + ‘F’ means Fortran order, ‘A’ means ‘F’ order if all the arrays are + Fortran contiguous, ‘C’ order otherwise, and ‘K’ means as close to + the order the array elements appear in memory as possible. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + + * 'no' means the data types should not be cast at all. + * 'equiv' means only byte-order changes are allowed. + * 'safe' means only casts which can preserve values are allowed. + * 'same_kind' means only safe casts or casts within a kind, + like float64 to float32, are allowed. + * 'unsafe' means any data conversions may be done. + subok : bool, optional + If True, then sub-classes will be passed-through, otherwise the + returned array will be forced to be a base-class array. + copy : bool, optional + By default, astype always returns a newly allocated array. If this + is set to False and the `dtype` requirement is satisfied, the input + array is returned instead of a copy. + keep_attrs : bool, optional + By default, astype keeps attributes. Set to False to remove + attributes in the returned object. + + Returns + ------- + out : same as object + New object with data cast to the specified type. + + Notes + ----- + The ``order``, ``casting``, ``subok`` and ``copy`` arguments are only passed + through to the ``astype`` method of the underlying array when a value + different than ``None`` is supplied. + Make sure to only supply these arguments if the underlying array class + supports them. + + See Also + -------- + numpy.ndarray.astype + dask.array.Array.astype + sparse.COO.astype + """ + from xarray.core.computation import apply_ufunc + + kwargs = dict(order=order, casting=casting, subok=subok, copy=copy) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return apply_ufunc( + duck_array_ops.astype, + self, + dtype, + kwargs=kwargs, + keep_attrs=keep_attrs, + dask="allowed", + ) + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + + def __getitem__(self, value): + # implementations of this class should implement this method + raise NotImplementedError() + + +@overload +def full_like( + other: DataArray, + fill_value: Any, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... + + +@overload +def full_like( + other: Dataset, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... + + +@overload +def full_like( + other: Variable, + fill_value: Any, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... + + +@overload +def full_like( + other: Dataset | DataArray, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = {}, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray: ... + + +@overload +def full_like( + other: Dataset | DataArray | Variable, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: ... + + +def full_like( + other: Dataset | DataArray | Variable, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: + """ + Return a new object with the same shape and type as a given object. + + Returned object will be chunked if if the given object is chunked, or if chunks or chunked_array_type are specified. + + Parameters + ---------- + other : DataArray, Dataset or Variable + The reference object in input + fill_value : scalar or dict-like + Value to fill the new object with before returning it. If + other is a Dataset, may also be a dict-like mapping data + variables to fill values. + dtype : dtype or dict-like of dtype, optional + dtype of the new array. If a dict-like, maps dtypes to + variables. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + + Returns + ------- + out : same as object + New object with the same shape and type as other, with the data + filled with fill_value. Coords will be copied from other. + If other is based on dask, the new one will be as well, and will be + split in the same chunks. + + Examples + -------- + >>> x = xr.DataArray( + ... np.arange(6).reshape(2, 3), + ... dims=["lat", "lon"], + ... coords={"lat": [1, 2], "lon": [0, 1, 2]}, + ... ) + >>> x + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.full_like(x, 1) + Size: 48B + array([[1, 1, 1], + [1, 1, 1]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.full_like(x, 0.5) + Size: 48B + array([[0, 0, 0], + [0, 0, 0]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.full_like(x, 0.5, dtype=np.double) + Size: 48B + array([[0.5, 0.5, 0.5], + [0.5, 0.5, 0.5]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.full_like(x, np.nan, dtype=np.double) + Size: 48B + array([[nan, nan, nan], + [nan, nan, nan]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> ds = xr.Dataset( + ... {"a": ("x", [3, 5, 2]), "b": ("x", [9, 1, 0])}, coords={"x": [2, 4, 6]} + ... ) + >>> ds + Size: 72B + Dimensions: (x: 3) + Coordinates: + * x (x) int64 24B 2 4 6 + Data variables: + a (x) int64 24B 3 5 2 + b (x) int64 24B 9 1 0 + >>> xr.full_like(ds, fill_value={"a": 1, "b": 2}) + Size: 72B + Dimensions: (x: 3) + Coordinates: + * x (x) int64 24B 2 4 6 + Data variables: + a (x) int64 24B 1 1 1 + b (x) int64 24B 2 2 2 + >>> xr.full_like(ds, fill_value={"a": 1, "b": 2}, dtype={"a": bool, "b": float}) + Size: 51B + Dimensions: (x: 3) + Coordinates: + * x (x) int64 24B 2 4 6 + Data variables: + a (x) bool 3B True True True + b (x) float64 24B 2.0 2.0 2.0 + + See Also + -------- + zeros_like + ones_like + + """ + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.variable import Variable + + if not is_scalar(fill_value) and not ( + isinstance(other, Dataset) and isinstance(fill_value, dict) + ): + raise ValueError( + f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead." + ) + + if isinstance(other, Dataset): + if not isinstance(fill_value, dict): + fill_value = {k: fill_value for k in other.data_vars.keys()} + + dtype_: Mapping[Any, DTypeLikeSave] + if not isinstance(dtype, Mapping): + dtype_ = {k: dtype for k in other.data_vars.keys()} + else: + dtype_ = dtype + + data_vars = { + k: _full_like_variable( + v.variable, + fill_value.get(k, dtypes.NA), + dtype_.get(k, None), + chunks, + chunked_array_type, + from_array_kwargs, + ) + for k, v in other.data_vars.items() + } + return Dataset(data_vars, coords=other.coords, attrs=other.attrs) + elif isinstance(other, DataArray): + if isinstance(dtype, Mapping): + raise ValueError("'dtype' cannot be dict-like when passing a DataArray") + return DataArray( + _full_like_variable( + other.variable, + fill_value, + dtype, + chunks, + chunked_array_type, + from_array_kwargs, + ), + dims=other.dims, + coords=other.coords, + attrs=other.attrs, + name=other.name, + ) + elif isinstance(other, Variable): + if isinstance(dtype, Mapping): + raise ValueError("'dtype' cannot be dict-like when passing a Variable") + return _full_like_variable( + other, fill_value, dtype, chunks, chunked_array_type, from_array_kwargs + ) + else: + raise TypeError("Expected DataArray, Dataset, or Variable") + + +def _full_like_variable( + other: Variable, + fill_value: Any, + dtype: DTypeLike | None = None, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: + """Inner function of full_like, where other must be a variable""" + from xarray.core.variable import Variable + + if fill_value is dtypes.NA: + fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype) + + if ( + is_chunked_array(other.data) + or chunked_array_type is not None + or chunks is not None + ): + if chunked_array_type is None: + chunkmanager = get_chunked_array_type(other.data) + else: + chunkmanager = guess_chunkmanager(chunked_array_type) + + if dtype is None: + dtype = other.dtype + + if from_array_kwargs is None: + from_array_kwargs = {} + + data = chunkmanager.array_api.full( + other.shape, + fill_value, + dtype=dtype, + chunks=chunks if chunks else other.data.chunks, + **from_array_kwargs, + ) + else: + data = np.full_like(other.data, fill_value, dtype=dtype) + + return Variable(dims=other.dims, data=data, attrs=other.attrs) + + +@overload +def zeros_like( + other: DataArray, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... + + +@overload +def zeros_like( + other: Dataset, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... + + +@overload +def zeros_like( + other: Variable, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... + + +@overload +def zeros_like( + other: Dataset | DataArray, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray: ... + + +@overload +def zeros_like( + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: ... + + +def zeros_like( + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: + """Return a new object of zeros with the same shape and + type as a given dataarray or dataset. + + Parameters + ---------- + other : DataArray, Dataset or Variable + The reference object. The output will have the same dimensions and coordinates as this object. + dtype : dtype, optional + dtype of the new array. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + + Returns + ------- + out : DataArray, Dataset or Variable + New object of zeros with the same shape and type as other. + + Examples + -------- + >>> x = xr.DataArray( + ... np.arange(6).reshape(2, 3), + ... dims=["lat", "lon"], + ... coords={"lat": [1, 2], "lon": [0, 1, 2]}, + ... ) + >>> x + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.zeros_like(x) + Size: 48B + array([[0, 0, 0], + [0, 0, 0]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.zeros_like(x, dtype=float) + Size: 48B + array([[0., 0., 0.], + [0., 0., 0.]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + See Also + -------- + ones_like + full_like + + """ + return full_like( + other, + 0, + dtype, + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) + + +@overload +def ones_like( + other: DataArray, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... + + +@overload +def ones_like( + other: Dataset, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... + + +@overload +def ones_like( + other: Variable, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... + + +@overload +def ones_like( + other: Dataset | DataArray, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray: ... + + +@overload +def ones_like( + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: ... + + +def ones_like( + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: + """Return a new object of ones with the same shape and + type as a given dataarray or dataset. + + Parameters + ---------- + other : DataArray, Dataset, or Variable + The reference object. The output will have the same dimensions and coordinates as this object. + dtype : dtype, optional + dtype of the new array. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + + Returns + ------- + out : same as object + New object of ones with the same shape and type as other. + + Examples + -------- + >>> x = xr.DataArray( + ... np.arange(6).reshape(2, 3), + ... dims=["lat", "lon"], + ... coords={"lat": [1, 2], "lon": [0, 1, 2]}, + ... ) + >>> x + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + >>> xr.ones_like(x) + Size: 48B + array([[1, 1, 1], + [1, 1, 1]]) + Coordinates: + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 + + See Also + -------- + zeros_like + full_like + + """ + return full_like( + other, + 1, + dtype, + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) + + +def get_chunksizes( + variables: Iterable[Variable], +) -> Mapping[Any, tuple[int, ...]]: + chunks: dict[Any, tuple[int, ...]] = {} + for v in variables: + if hasattr(v._data, "chunks"): + for dim, c in v.chunksizes.items(): + if dim in chunks and c != chunks[dim]: + raise ValueError( + f"Object has inconsistent chunks along dimension {dim}. " + "This can be fixed by calling unify_chunks()." + ) + chunks[dim] = c + return Frozen(chunks) + + +def is_np_datetime_like(dtype: DTypeLike) -> bool: + """Check if a dtype is a subclass of the numpy datetime types""" + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def is_np_timedelta_like(dtype: DTypeLike) -> bool: + """Check whether dtype is of the timedelta64 dtype.""" + return np.issubdtype(dtype, np.timedelta64) + + +def _contains_cftime_datetimes(array: Any) -> bool: + """Check if a array inside a Variable contains cftime.datetime objects""" + if cftime is None: + return False + + if array.dtype == np.dtype("O") and array.size > 0: + first_idx = (0,) * array.ndim + if isinstance(array, ExplicitlyIndexed): + first_idx = BasicIndexer(first_idx) + sample = array[first_idx] + return isinstance(np.asarray(sample).item(), cftime.datetime) + + return False + + +def contains_cftime_datetimes(var: T_Variable) -> bool: + """Check if an xarray.Variable contains cftime.datetime objects""" + return _contains_cftime_datetimes(var._data) + + +def _contains_datetime_like_objects(var: T_Variable) -> bool: + """Check if a variable contains datetime like objects (either + np.datetime64, np.timedelta64, or cftime.datetime) + """ + return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/computation.py b/test/fixtures/whole_applications/xarray/xarray/core/computation.py new file mode 100644 index 0000000..f09b04b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/computation.py @@ -0,0 +1,2312 @@ +""" +Functions for applying functions that act on arrays to xarray's labeled data. +""" + +from __future__ import annotations + +import functools +import itertools +import operator +import warnings +from collections import Counter +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload + +import numpy as np + +from xarray.core import dtypes, duck_array_ops, utils +from xarray.core.alignment import align, deep_align +from xarray.core.common import zeros_like +from xarray.core.duck_array_ops import datetime_to_numeric +from xarray.core.formatting import limit_lines +from xarray.core.indexes import Index, filter_indexes_from_coords +from xarray.core.merge import merge_attrs, merge_coordinates_without_align +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import Dims, T_DataArray +from xarray.core.utils import is_dict_like, is_duck_dask_array, is_scalar, parse_dims +from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array +from xarray.util.deprecation_helpers import deprecate_dims + +if TYPE_CHECKING: + from xarray.core.coordinates import Coordinates + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import CombineAttrsOptions, JoinOptions + + MissingCoreDimOptions = Literal["raise", "copy", "drop"] + +_NO_FILL_VALUE = utils.ReprObject("") +_DEFAULT_NAME = utils.ReprObject("") +_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) + + +def _first_of_type(args, kind): + """Return either first object of type 'kind' or raise if not found.""" + for arg in args: + if isinstance(arg, kind): + return arg + + raise ValueError("This should be unreachable.") + + +def _all_of_type(args, kind): + """Return all objects of type 'kind'""" + return [arg for arg in args if isinstance(arg, kind)] + + +class _UFuncSignature: + """Core dimensions signature for a given function. + + Based on the signature provided by generalized ufuncs in NumPy. + + Attributes + ---------- + input_core_dims : tuple[tuple] + Core dimension names on each input variable. + output_core_dims : tuple[tuple] + Core dimension names on each output variable. + """ + + __slots__ = ( + "input_core_dims", + "output_core_dims", + "_all_input_core_dims", + "_all_output_core_dims", + "_all_core_dims", + ) + + def __init__(self, input_core_dims, output_core_dims=((),)): + self.input_core_dims = tuple(tuple(a) for a in input_core_dims) + self.output_core_dims = tuple(tuple(a) for a in output_core_dims) + self._all_input_core_dims = None + self._all_output_core_dims = None + self._all_core_dims = None + + @property + def all_input_core_dims(self): + if self._all_input_core_dims is None: + self._all_input_core_dims = frozenset( + dim for dims in self.input_core_dims for dim in dims + ) + return self._all_input_core_dims + + @property + def all_output_core_dims(self): + if self._all_output_core_dims is None: + self._all_output_core_dims = frozenset( + dim for dims in self.output_core_dims for dim in dims + ) + return self._all_output_core_dims + + @property + def all_core_dims(self): + if self._all_core_dims is None: + self._all_core_dims = self.all_input_core_dims | self.all_output_core_dims + return self._all_core_dims + + @property + def dims_map(self): + return { + core_dim: f"dim{n}" for n, core_dim in enumerate(sorted(self.all_core_dims)) + } + + @property + def num_inputs(self): + return len(self.input_core_dims) + + @property + def num_outputs(self): + return len(self.output_core_dims) + + def __eq__(self, other): + try: + return ( + self.input_core_dims == other.input_core_dims + and self.output_core_dims == other.output_core_dims + ) + except AttributeError: + return False + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return f"{type(self).__name__}({list(self.input_core_dims)!r}, {list(self.output_core_dims)!r})" + + def __str__(self): + lhs = ",".join("({})".format(",".join(dims)) for dims in self.input_core_dims) + rhs = ",".join("({})".format(",".join(dims)) for dims in self.output_core_dims) + return f"{lhs}->{rhs}" + + def to_gufunc_string(self, exclude_dims=frozenset()): + """Create an equivalent signature string for a NumPy gufunc. + + Unlike __str__, handles dimensions that don't map to Python + identifiers. + + Also creates unique names for input_core_dims contained in exclude_dims. + """ + input_core_dims = [ + [self.dims_map[dim] for dim in core_dims] + for core_dims in self.input_core_dims + ] + output_core_dims = [ + [self.dims_map[dim] for dim in core_dims] + for core_dims in self.output_core_dims + ] + + # enumerate input_core_dims contained in exclude_dims to make them unique + if exclude_dims: + exclude_dims = [self.dims_map[dim] for dim in exclude_dims] + + counter: Counter = Counter() + + def _enumerate(dim): + if dim in exclude_dims: + n = counter[dim] + counter.update([dim]) + dim = f"{dim}_{n}" + return dim + + input_core_dims = [ + [_enumerate(dim) for dim in arg] for arg in input_core_dims + ] + + alt_signature = type(self)(input_core_dims, output_core_dims) + return str(alt_signature) + + +def result_name(objects: Iterable[Any]) -> Any: + # use the same naming heuristics as pandas: + # https://github.com/blaze/blaze/issues/458#issuecomment-51936356 + names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects} + names.discard(_DEFAULT_NAME) + if len(names) == 1: + (name,) = names + else: + name = None + return name + + +def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]: + coords_list = [] + for arg in args: + try: + coords = arg.coords + except AttributeError: + pass # skip this argument + else: + coords_list.append(coords) + return coords_list + + +def build_output_coords_and_indexes( + args: Iterable[Any], + signature: _UFuncSignature, + exclude_dims: Set = frozenset(), + combine_attrs: CombineAttrsOptions = "override", +) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]: + """Build output coordinates and indexes for an operation. + + Parameters + ---------- + args : Iterable + List of raw operation arguments. Any valid types for xarray operations + are OK, e.g., scalars, Variable, DataArray, Dataset. + signature : _UfuncSignature + Core dimensions signature for the operation. + exclude_dims : set, optional + Dimensions excluded from the operation. Coordinates along these + dimensions are dropped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "drop" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + + Returns + ------- + Dictionaries of Variable and Index objects with merged coordinates. + """ + coords_list = _get_coords_list(args) + + if len(coords_list) == 1 and not exclude_dims: + # we can skip the expensive merge + (unpacked_coords,) = coords_list + merged_vars = dict(unpacked_coords.variables) + merged_indexes = dict(unpacked_coords.xindexes) + else: + merged_vars, merged_indexes = merge_coordinates_without_align( + coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs + ) + + output_coords = [] + output_indexes = [] + for output_dims in signature.output_core_dims: + dropped_dims = signature.all_input_core_dims - set(output_dims) + if dropped_dims: + filtered_coords = { + k: v for k, v in merged_vars.items() if dropped_dims.isdisjoint(v.dims) + } + filtered_indexes = filter_indexes_from_coords( + merged_indexes, set(filtered_coords) + ) + else: + filtered_coords = merged_vars + filtered_indexes = merged_indexes + output_coords.append(filtered_coords) + output_indexes.append(filtered_indexes) + + return output_coords, output_indexes + + +def apply_dataarray_vfunc( + func, + *args, + signature: _UFuncSignature, + join: JoinOptions = "inner", + exclude_dims=frozenset(), + keep_attrs="override", +) -> tuple[DataArray, ...] | DataArray: + """Apply a variable level function over DataArray, Variable and/or ndarray + objects. + """ + from xarray.core.dataarray import DataArray + + if len(args) > 1: + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) + ) + + objs = _all_of_type(args, DataArray) + + if keep_attrs == "drop": + name = result_name(args) + else: + first_obj = _first_of_type(args, DataArray) + name = first_obj.name + result_coords, result_indexes = build_output_coords_and_indexes( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) + + data_vars = [getattr(a, "variable", a) for a in args] + result_var = func(*data_vars) + + out: tuple[DataArray, ...] | DataArray + if signature.num_outputs > 1: + out = tuple( + DataArray( + variable, coords=coords, indexes=indexes, name=name, fastpath=True + ) + for variable, coords, indexes in zip( + result_var, result_coords, result_indexes + ) + ) + else: + (coords,) = result_coords + (indexes,) = result_indexes + out = DataArray( + result_var, coords=coords, indexes=indexes, name=name, fastpath=True + ) + + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for da in out: + da.attrs = attrs + else: + out.attrs = attrs + + return out + + +def ordered_set_union(all_keys: list[Iterable]) -> Iterable: + return {key: None for keys in all_keys for key in keys}.keys() + + +def ordered_set_intersection(all_keys: list[Iterable]) -> Iterable: + intersection = set(all_keys[0]) + for keys in all_keys[1:]: + intersection.intersection_update(keys) + return [key for key in all_keys[0] if key in intersection] + + +def assert_and_return_exact_match(all_keys): + first_keys = all_keys[0] + for keys in all_keys[1:]: + if keys != first_keys: + raise ValueError( + "exact match required for all data variable names, " + f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both." + ) + return first_keys + + +_JOINERS: dict[str, Callable] = { + "inner": ordered_set_intersection, + "outer": ordered_set_union, + "left": operator.itemgetter(0), + "right": operator.itemgetter(-1), + "exact": assert_and_return_exact_match, +} + + +def join_dict_keys(objects: Iterable[Mapping | Any], how: str = "inner") -> Iterable: + joiner = _JOINERS[how] + all_keys = [obj.keys() for obj in objects if hasattr(obj, "keys")] + return joiner(all_keys) + + +def collect_dict_values( + objects: Iterable[Mapping | Any], keys: Iterable, fill_value: object = None +) -> list[list]: + return [ + [obj.get(key, fill_value) if is_dict_like(obj) else obj for obj in objects] + for key in keys + ] + + +def _as_variables_or_variable(arg) -> Variable | tuple[Variable]: + try: + return arg.variables + except AttributeError: + try: + return arg.variable + except AttributeError: + return arg + + +def _unpack_dict_tuples( + result_vars: Mapping[Any, tuple[Variable, ...]], num_outputs: int +) -> tuple[dict[Hashable, Variable], ...]: + out: tuple[dict[Hashable, Variable], ...] = tuple({} for _ in range(num_outputs)) + for name, values in result_vars.items(): + for value, results_dict in zip(values, out): + results_dict[name] = value + return out + + +def _check_core_dims(signature, variable_args, name): + """ + Check if an arg has all the core dims required by the signature. + + Slightly awkward design, of returning the error message. But we want to + give a detailed error message, which requires inspecting the variable in + the inner loop. + """ + missing = [] + for i, (core_dims, variable_arg) in enumerate( + zip(signature.input_core_dims, variable_args) + ): + # Check whether all the dims are on the variable. Note that we need the + # `hasattr` to check for a dims property, to protect against the case where + # a numpy array is passed in. + if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims): + missing += [[i, variable_arg, core_dims]] + if missing: + message = "" + for i, variable_arg, core_dims in missing: + message += f"Missing core dims {set(core_dims) - set(variable_arg.dims)} from arg number {i + 1} on a variable named `{name}`:\n{variable_arg}\n\n" + message += "Either add the core dimension, or if passing a dataset alternatively pass `on_missing_core_dim` as `copy` or `drop`. " + return message + return True + + +def apply_dict_of_variables_vfunc( + func, + *args, + signature: _UFuncSignature, + join="inner", + fill_value=None, + on_missing_core_dim: MissingCoreDimOptions = "raise", +): + """Apply a variable level function over dicts of DataArray, DataArray, + Variable and ndarray objects. + """ + args = tuple(_as_variables_or_variable(arg) for arg in args) + names = join_dict_keys(args, how=join) + grouped_by_name = collect_dict_values(args, names, fill_value) + + result_vars = {} + for name, variable_args in zip(names, grouped_by_name): + core_dim_present = _check_core_dims(signature, variable_args, name) + if core_dim_present is True: + result_vars[name] = func(*variable_args) + else: + if on_missing_core_dim == "raise": + raise ValueError(core_dim_present) + elif on_missing_core_dim == "copy": + result_vars[name] = variable_args[0] + elif on_missing_core_dim == "drop": + pass + else: + raise ValueError( + f"Invalid value for `on_missing_core_dim`: {on_missing_core_dim!r}" + ) + + if signature.num_outputs > 1: + return _unpack_dict_tuples(result_vars, signature.num_outputs) + else: + return result_vars + + +def _fast_dataset( + variables: dict[Hashable, Variable], + coord_variables: Mapping[Hashable, Variable], + indexes: dict[Hashable, Index], +) -> Dataset: + """Create a dataset as quickly as possible. + + Beware: the `variables` dict is modified INPLACE. + """ + from xarray.core.dataset import Dataset + + variables.update(coord_variables) + coord_names = set(coord_variables) + return Dataset._construct_direct(variables, coord_names, indexes=indexes) + + +def apply_dataset_vfunc( + func, + *args, + signature: _UFuncSignature, + join="inner", + dataset_join="exact", + fill_value=_NO_FILL_VALUE, + exclude_dims=frozenset(), + keep_attrs="override", + on_missing_core_dim: MissingCoreDimOptions = "raise", +) -> Dataset | tuple[Dataset, ...]: + """Apply a variable level function over Dataset, dict of DataArray, + DataArray, Variable and/or ndarray objects. + """ + from xarray.core.dataset import Dataset + + if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE: + raise TypeError( + "to apply an operation to datasets with different " + "data variables with apply_ufunc, you must supply the " + "dataset_fill_value argument." + ) + + objs = _all_of_type(args, Dataset) + + if len(args) > 1: + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) + ) + + list_of_coords, list_of_indexes = build_output_coords_and_indexes( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) + args = tuple(getattr(arg, "data_vars", arg) for arg in args) + + result_vars = apply_dict_of_variables_vfunc( + func, + *args, + signature=signature, + join=dataset_join, + fill_value=fill_value, + on_missing_core_dim=on_missing_core_dim, + ) + + out: Dataset | tuple[Dataset, ...] + if signature.num_outputs > 1: + out = tuple( + _fast_dataset(*args) + for args in zip(result_vars, list_of_coords, list_of_indexes) + ) + else: + (coord_vars,) = list_of_coords + (indexes,) = list_of_indexes + out = _fast_dataset(result_vars, coord_vars, indexes=indexes) + + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for ds in out: + ds.attrs = attrs + else: + out.attrs = attrs + + return out + + +def _iter_over_selections(obj, dim, values): + """Iterate over selections of an xarray object in the provided order.""" + from xarray.core.groupby import _dummy_copy + + dummy = None + for value in values: + try: + obj_sel = obj.sel(**{dim: value}) + except (KeyError, IndexError): + if dummy is None: + dummy = _dummy_copy(obj) + obj_sel = dummy + yield obj_sel + + +def apply_groupby_func(func, *args): + """Apply a dataset or datarray level function over GroupBy, Dataset, + DataArray, Variable and/or ndarray objects. + """ + from xarray.core.groupby import GroupBy, peek_at + from xarray.core.variable import Variable + + groupbys = [arg for arg in args if isinstance(arg, GroupBy)] + assert groupbys, "must have at least one groupby to iterate over" + first_groupby = groupbys[0] + (grouper,) = first_groupby.groupers + if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr] + raise ValueError( + "apply_ufunc can only perform operations over " + "multiple GroupBy objects at once if they are all " + "grouped the same way" + ) + + grouped_dim = grouper.name + unique_values = grouper.unique_coord.values + + iterators = [] + for arg in args: + iterator: Iterator[Any] + if isinstance(arg, GroupBy): + iterator = (value for _, value in arg) + elif hasattr(arg, "dims") and grouped_dim in arg.dims: + if isinstance(arg, Variable): + raise ValueError( + "groupby operations cannot be performed with " + "xarray.Variable objects that share a dimension with " + "the grouped dimension" + ) + iterator = _iter_over_selections(arg, grouped_dim, unique_values) + else: + iterator = itertools.repeat(arg) + iterators.append(iterator) + + applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied_example, applied = peek_at(applied) + combine = first_groupby._combine # type: ignore[attr-defined] + if isinstance(applied_example, tuple): + combined = tuple(combine(output) for output in zip(*applied)) + else: + combined = combine(applied) + return combined + + +def unified_dim_sizes( + variables: Iterable[Variable], exclude_dims: Set = frozenset() +) -> dict[Hashable, int]: + dim_sizes: dict[Hashable, int] = {} + + for var in variables: + if len(set(var.dims)) < len(var.dims): + raise ValueError( + "broadcasting cannot handle duplicate " + f"dimensions on a variable: {list(var.dims)}" + ) + for dim, size in zip(var.dims, var.shape): + if dim not in exclude_dims: + if dim not in dim_sizes: + dim_sizes[dim] = size + elif dim_sizes[dim] != size: + raise ValueError( + "operands cannot be broadcast together " + "with mismatched lengths for dimension " + f"{dim}: {dim_sizes[dim]} vs {size}" + ) + return dim_sizes + + +SLICE_NONE = slice(None) + + +def broadcast_compat_data( + variable: Variable, + broadcast_dims: tuple[Hashable, ...], + core_dims: tuple[Hashable, ...], +) -> Any: + data = variable.data + + old_dims = variable.dims + new_dims = broadcast_dims + core_dims + + if new_dims == old_dims: + # optimize for the typical case + return data + + set_old_dims = set(old_dims) + set_new_dims = set(new_dims) + unexpected_dims = [d for d in old_dims if d not in set_new_dims] + + if unexpected_dims: + raise ValueError( + "operand to apply_ufunc encountered unexpected " + f"dimensions {unexpected_dims!r} on an input variable: these are core " + "dimensions on other input or output variables" + ) + + # for consistency with numpy, keep broadcast dimensions to the left + old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims) + reordered_dims = old_broadcast_dims + core_dims + if reordered_dims != old_dims: + order = tuple(old_dims.index(d) for d in reordered_dims) + data = duck_array_ops.transpose(data, order) + + if new_dims != reordered_dims: + key_parts: list[slice | None] = [] + for dim in new_dims: + if dim in set_old_dims: + key_parts.append(SLICE_NONE) + elif key_parts: + # no need to insert new axes at the beginning that are already + # handled by broadcasting + key_parts.append(np.newaxis) + data = data[tuple(key_parts)] + + return data + + +def _vectorize(func, signature, output_dtypes, exclude_dims): + if signature.all_core_dims: + func = np.vectorize( + func, + otypes=output_dtypes, + signature=signature.to_gufunc_string(exclude_dims), + ) + else: + func = np.vectorize(func, otypes=output_dtypes) + + return func + + +def apply_variable_ufunc( + func, + *args, + signature: _UFuncSignature, + exclude_dims=frozenset(), + dask="forbidden", + output_dtypes=None, + vectorize=False, + keep_attrs="override", + dask_gufunc_kwargs=None, +) -> Variable | tuple[Variable, ...]: + """Apply a ndarray level function over Variable and/or ndarray objects.""" + from xarray.core.formatting import short_array_repr + from xarray.core.variable import Variable, as_compatible_data + + dim_sizes = unified_dim_sizes( + (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims + ) + broadcast_dims = tuple( + dim for dim in dim_sizes if dim not in signature.all_core_dims + ) + output_dims = [broadcast_dims + out for out in signature.output_core_dims] + + input_data = [ + ( + broadcast_compat_data(arg, broadcast_dims, core_dims) + if isinstance(arg, Variable) + else arg + ) + for arg, core_dims in zip(args, signature.input_core_dims) + ] + + if any(is_chunked_array(array) for array in input_data): + if dask == "forbidden": + raise ValueError( + "apply_ufunc encountered a chunked array on an " + "argument, but handling for chunked arrays has not " + "been enabled. Either set the ``dask`` argument " + "or load your data into memory first with " + "``.load()`` or ``.compute()``" + ) + elif dask == "parallelized": + chunkmanager = get_chunked_array_type(*input_data) + + numpy_func = func + + if dask_gufunc_kwargs is None: + dask_gufunc_kwargs = {} + else: + dask_gufunc_kwargs = dask_gufunc_kwargs.copy() + + allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None) + if allow_rechunk is None: + for n, (data, core_dims) in enumerate( + zip(input_data, signature.input_core_dims) + ): + if is_chunked_array(data): + # core dimensions cannot span multiple chunks + for axis, dim in enumerate(core_dims, start=-len(core_dims)): + if len(data.chunks[axis]) != 1: + raise ValueError( + f"dimension {dim} on {n}th function argument to " + "apply_ufunc with dask='parallelized' consists of " + "multiple chunks, but is also a core dimension. To " + "fix, either rechunk into a single array chunk along " + f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or " + "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` " + "but beware that this may significantly increase memory usage." + ) + dask_gufunc_kwargs["allow_rechunk"] = True + + output_sizes = dask_gufunc_kwargs.pop("output_sizes", {}) + if output_sizes: + output_sizes_renamed = {} + for key, value in output_sizes.items(): + if key not in signature.all_output_core_dims: + raise ValueError( + f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims" + ) + output_sizes_renamed[signature.dims_map[key]] = value + dask_gufunc_kwargs["output_sizes"] = output_sizes_renamed + + for key in signature.all_output_core_dims: + if ( + key not in signature.all_input_core_dims or key in exclude_dims + ) and key not in output_sizes: + raise ValueError( + f"dimension '{key}' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'" + ) + + def func(*arrays): + res = chunkmanager.apply_gufunc( + numpy_func, + signature.to_gufunc_string(exclude_dims), + *arrays, + vectorize=vectorize, + output_dtypes=output_dtypes, + **dask_gufunc_kwargs, + ) + + return res + + elif dask == "allowed": + pass + else: + raise ValueError( + f"unknown setting for chunked array handling in apply_ufunc: {dask}" + ) + else: + if vectorize: + func = _vectorize( + func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims + ) + + result_data = func(*input_data) + + if signature.num_outputs == 1: + result_data = (result_data,) + elif ( + not isinstance(result_data, tuple) or len(result_data) != signature.num_outputs + ): + raise ValueError( + f"applied function does not have the number of " + f"outputs specified in the ufunc signature. " + f"Received a {type(result_data)} with {len(result_data)} elements. " + f"Expected a tuple of {signature.num_outputs} elements:\n\n" + f"{limit_lines(repr(result_data), limit=10)}" + ) + + objs = _all_of_type(args, Variable) + attrs = merge_attrs( + [obj.attrs for obj in objs], + combine_attrs=keep_attrs, + ) + + output: list[Variable] = [] + for dims, data in zip(output_dims, result_data): + data = as_compatible_data(data) + if data.ndim != len(dims): + raise ValueError( + "applied function returned data with an unexpected " + f"number of dimensions. Received {data.ndim} dimension(s) but " + f"expected {len(dims)} dimensions with names {dims!r}, from:\n\n" + f"{short_array_repr(data)}" + ) + + var = Variable(dims, data, fastpath=True) + for dim, new_size in var.sizes.items(): + if dim in dim_sizes and new_size != dim_sizes[dim]: + raise ValueError( + f"size of dimension '{dim}' on inputs was unexpectedly " + f"changed by applied function from {dim_sizes[dim]} to {new_size}. Only " + "dimensions specified in ``exclude_dims`` with " + "xarray.apply_ufunc are allowed to change size. " + "The data returned was:\n\n" + f"{short_array_repr(data)}" + ) + + var.attrs = attrs + output.append(var) + + if signature.num_outputs == 1: + return output[0] + else: + return tuple(output) + + +def apply_array_ufunc(func, *args, dask="forbidden"): + """Apply a ndarray level function over ndarray objects.""" + if any(is_chunked_array(arg) for arg in args): + if dask == "forbidden": + raise ValueError( + "apply_ufunc encountered a dask array on an " + "argument, but handling for dask arrays has not " + "been enabled. Either set the ``dask`` argument " + "or load your data into memory first with " + "``.load()`` or ``.compute()``" + ) + elif dask == "parallelized": + raise ValueError( + "cannot use dask='parallelized' for apply_ufunc " + "unless at least one input is an xarray object" + ) + elif dask == "allowed": + pass + else: + raise ValueError(f"unknown setting for dask array handling: {dask}") + return func(*args) + + +def apply_ufunc( + func: Callable, + *args: Any, + input_core_dims: Sequence[Sequence] | None = None, + output_core_dims: Sequence[Sequence] | None = ((),), + exclude_dims: Set = frozenset(), + vectorize: bool = False, + join: JoinOptions = "exact", + dataset_join: str = "exact", + dataset_fill_value: object = _NO_FILL_VALUE, + keep_attrs: bool | str | None = None, + kwargs: Mapping | None = None, + dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden", + output_dtypes: Sequence | None = None, + output_sizes: Mapping[Any, int] | None = None, + meta: Any = None, + dask_gufunc_kwargs: dict[str, Any] | None = None, + on_missing_core_dim: MissingCoreDimOptions = "raise", +) -> Any: + """Apply a vectorized function for unlabeled arrays on xarray objects. + + The function will be mapped over the data variable(s) of the input + arguments using xarray's standard rules for labeled computation, including + alignment, broadcasting, looping over GroupBy/Dataset variables, and + merging of coordinates. + + Parameters + ---------- + func : callable + Function to call like ``func(*args, **kwargs)`` on unlabeled arrays + (``.data``) that returns an array or tuple of arrays. If multiple + arguments with non-matching dimensions are supplied, this function is + expected to vectorize (broadcast) over axes of positional arguments in + the style of NumPy universal functions [1]_ (if this is not the case, + set ``vectorize=True``). If this function returns multiple outputs, you + must set ``output_core_dims`` as well. + *args : Dataset, DataArray, DataArrayGroupBy, DatasetGroupBy, Variable, \ + numpy.ndarray, dask.array.Array or scalar + Mix of labeled and/or unlabeled arrays to which to apply the function. + input_core_dims : sequence of sequence, optional + List of the same length as ``args`` giving the list of core dimensions + on each input argument that should not be broadcast. By default, we + assume there are no core dimensions on any input arguments. + + For example, ``input_core_dims=[[], ['time']]`` indicates that all + dimensions on the first argument and all dimensions other than 'time' + on the second argument should be broadcast. + + Core dimensions are automatically moved to the last axes of input + variables before applying ``func``, which facilitates using NumPy style + generalized ufuncs [2]_. + output_core_dims : list of tuple, optional + List of the same length as the number of output arguments from + ``func``, giving the list of core dimensions on each output that were + not broadcast on the inputs. By default, we assume that ``func`` + outputs exactly one array, with axes corresponding to each broadcast + dimension. + + Core dimensions are assumed to appear as the last dimensions of each + output in the provided order. + exclude_dims : set, optional + Core dimensions on the inputs to exclude from alignment and + broadcasting entirely. Any input coordinates along these dimensions + will be dropped. Each excluded dimension must also appear in + ``input_core_dims`` for at least one argument. Only dimensions listed + here are allowed to change size between input and output objects. + vectorize : bool, optional + If True, then assume ``func`` only takes arrays defined over core + dimensions as input and vectorize it automatically with + :py:func:`numpy.vectorize`. This option exists for convenience, but is + almost always slower than supplying a pre-vectorized function. + join : {"outer", "inner", "left", "right", "exact"}, default: "exact" + Method for joining the indexes of the passed objects along each + dimension, and the variables of Dataset objects with mismatched + data variables: + + - 'outer': use the union of object indexes + - 'inner': use the intersection of object indexes + - 'left': use indexes from the first object with each dimension + - 'right': use indexes from the last object with each dimension + - 'exact': raise `ValueError` instead of aligning when indexes to be + aligned are not equal + dataset_join : {"outer", "inner", "left", "right", "exact"}, default: "exact" + Method for joining variables of Dataset objects with mismatched + data variables. + + - 'outer': take variables from both Dataset objects + - 'inner': take only overlapped variables + - 'left': take only variables from the first object + - 'right': take only variables from the last object + - 'exact': data variables on all Dataset objects must match exactly + dataset_fill_value : optional + Value used in place of missing variables on Dataset inputs when the + datasets do not share the exact same ``data_vars``. Required if + ``dataset_join not in {'inner', 'exact'}``, otherwise ignored. + keep_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or bool, optional + - 'drop' or False: empty attrs on returned xarray object. + - 'identical': all attrs must be the same on every object. + - 'no_conflicts': attrs from all objects are combined, any that have the same name must also have the same value. + - 'drop_conflicts': attrs from all objects are combined, any that have the same name but different values are dropped. + - 'override' or True: skip comparing and copy attrs from the first object to the result. + kwargs : dict, optional + Optional keyword arguments passed directly on to call ``func``. + dask : {"forbidden", "allowed", "parallelized"}, default: "forbidden" + How to handle applying to objects containing lazy data in the form of + dask arrays: + + - 'forbidden' (default): raise an error if a dask array is encountered. + - 'allowed': pass dask arrays directly on to ``func``. Prefer this option if + ``func`` natively supports dask arrays. + - 'parallelized': automatically parallelize ``func`` if any of the + inputs are a dask array by using :py:func:`dask.array.apply_gufunc`. Multiple output + arguments are supported. Only use this option if ``func`` does not natively + support dask arrays (e.g. converts them to numpy arrays). + dask_gufunc_kwargs : dict, optional + Optional keyword arguments passed to :py:func:`dask.array.apply_gufunc` if + dask='parallelized'. Possible keywords are ``output_sizes``, ``allow_rechunk`` + and ``meta``. + output_dtypes : list of dtype, optional + Optional list of output dtypes. Only used if ``dask='parallelized'`` or + ``vectorize=True``. + output_sizes : dict, optional + Optional mapping from dimension names to sizes for outputs. Only used + if dask='parallelized' and new dimensions (not found on inputs) appear + on outputs. ``output_sizes`` should be given in the ``dask_gufunc_kwargs`` + parameter. It will be removed as direct parameter in a future version. + meta : optional + Size-0 object representing the type of array wrapped by dask array. Passed on to + :py:func:`dask.array.apply_gufunc`. ``meta`` should be given in the + ``dask_gufunc_kwargs`` parameter . It will be removed as direct parameter + a future version. + on_missing_core_dim : {"raise", "copy", "drop"}, default: "raise" + How to handle missing core dimensions on input variables. + + Returns + ------- + Single value or tuple of Dataset, DataArray, Variable, dask.array.Array or + numpy.ndarray, the first type on that list to appear on an input. + + Notes + ----- + This function is designed for the more common case where ``func`` can work on numpy + arrays. If ``func`` needs to manipulate a whole xarray object subset to each block + it is possible to use :py:func:`xarray.map_blocks`. + + Note that due to the overhead :py:func:`xarray.map_blocks` is considerably slower than ``apply_ufunc``. + + Examples + -------- + Calculate the vector magnitude of two arguments: + + >>> def magnitude(a, b): + ... func = lambda x, y: np.sqrt(x**2 + y**2) + ... return xr.apply_ufunc(func, a, b) + ... + + You can now apply ``magnitude()`` to :py:class:`DataArray` and :py:class:`Dataset` + objects, with automatically preserved dimensions and coordinates, e.g., + + >>> array = xr.DataArray([1, 2, 3], coords=[("x", [0.1, 0.2, 0.3])]) + >>> magnitude(array, -array) + Size: 24B + array([1.41421356, 2.82842712, 4.24264069]) + Coordinates: + * x (x) float64 24B 0.1 0.2 0.3 + + Plain scalars, numpy arrays and a mix of these with xarray objects is also + supported: + + >>> magnitude(3, 4) + 5.0 + >>> magnitude(3, np.array([0, 4])) + array([3., 5.]) + >>> magnitude(array, 0) + Size: 24B + array([1., 2., 3.]) + Coordinates: + * x (x) float64 24B 0.1 0.2 0.3 + + Other examples of how you could use ``apply_ufunc`` to write functions to + (very nearly) replicate existing xarray functionality: + + Compute the mean (``.mean``) over one dimension: + + >>> def mean(obj, dim): + ... # note: apply always moves core dimensions to the end + ... return apply_ufunc( + ... np.mean, obj, input_core_dims=[[dim]], kwargs={"axis": -1} + ... ) + ... + + Inner product over a specific dimension (like :py:func:`dot`): + + >>> def _inner(x, y): + ... result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis]) + ... return result[..., 0, 0] + ... + >>> def inner_product(a, b, dim): + ... return apply_ufunc(_inner, a, b, input_core_dims=[[dim], [dim]]) + ... + + Stack objects along a new dimension (like :py:func:`concat`): + + >>> def stack(objects, dim, new_coord): + ... # note: this version does not stack coordinates + ... func = lambda *x: np.stack(x, axis=-1) + ... result = apply_ufunc( + ... func, + ... *objects, + ... output_core_dims=[[dim]], + ... join="outer", + ... dataset_fill_value=np.nan + ... ) + ... result[dim] = new_coord + ... return result + ... + + If your function is not vectorized but can be applied only to core + dimensions, you can use ``vectorize=True`` to turn into a vectorized + function. This wraps :py:func:`numpy.vectorize`, so the operation isn't + terribly fast. Here we'll use it to calculate the distance between + empirical samples from two probability distributions, using a scipy + function that needs to be applied to vectors: + + >>> import scipy.stats + >>> def earth_mover_distance(first_samples, second_samples, dim="ensemble"): + ... return apply_ufunc( + ... scipy.stats.wasserstein_distance, + ... first_samples, + ... second_samples, + ... input_core_dims=[[dim], [dim]], + ... vectorize=True, + ... ) + ... + + Most of NumPy's builtin functions already broadcast their inputs + appropriately for use in ``apply_ufunc``. You may find helper functions such as + :py:func:`numpy.broadcast_arrays` helpful in writing your function. ``apply_ufunc`` also + works well with :py:func:`numba.vectorize` and :py:func:`numba.guvectorize`. + + See Also + -------- + numpy.broadcast_arrays + numba.vectorize + numba.guvectorize + dask.array.apply_gufunc + xarray.map_blocks + + :ref:`dask.automatic-parallelization` + User guide describing :py:func:`apply_ufunc` and :py:func:`map_blocks`. + + :doc:`xarray-tutorial:advanced/apply_ufunc/apply_ufunc` + Advanced Tutorial on applying numpy function using :py:func:`apply_ufunc` + + References + ---------- + .. [1] https://numpy.org/doc/stable/reference/ufuncs.html + .. [2] https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html + """ + from xarray.core.dataarray import DataArray + from xarray.core.groupby import GroupBy + from xarray.core.variable import Variable + + if input_core_dims is None: + input_core_dims = ((),) * (len(args)) + elif len(input_core_dims) != len(args): + raise ValueError( + f"input_core_dims must be None or a tuple with the length same to " + f"the number of arguments. " + f"Given {len(input_core_dims)} input_core_dims: {input_core_dims}, " + f" but number of args is {len(args)}." + ) + + if kwargs is None: + kwargs = {} + + signature = _UFuncSignature(input_core_dims, output_core_dims) + + if exclude_dims: + if not isinstance(exclude_dims, set): + raise TypeError( + f"Expected exclude_dims to be a 'set'. Received '{type(exclude_dims).__name__}' instead." + ) + if not exclude_dims <= signature.all_core_dims: + raise ValueError( + f"each dimension in `exclude_dims` must also be a " + f"core dimension in the function signature. " + f"Please make {(exclude_dims - signature.all_core_dims)} a core dimension" + ) + + # handle dask_gufunc_kwargs + if dask == "parallelized": + if dask_gufunc_kwargs is None: + dask_gufunc_kwargs = {} + else: + dask_gufunc_kwargs = dask_gufunc_kwargs.copy() + # todo: remove warnings after deprecation cycle + if meta is not None: + warnings.warn( + "``meta`` should be given in the ``dask_gufunc_kwargs`` parameter." + " It will be removed as direct parameter in a future version.", + FutureWarning, + stacklevel=2, + ) + dask_gufunc_kwargs.setdefault("meta", meta) + if output_sizes is not None: + warnings.warn( + "``output_sizes`` should be given in the ``dask_gufunc_kwargs`` " + "parameter. It will be removed as direct parameter in a future " + "version.", + FutureWarning, + stacklevel=2, + ) + dask_gufunc_kwargs.setdefault("output_sizes", output_sizes) + + if kwargs: + func = functools.partial(func, **kwargs) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + if isinstance(keep_attrs, bool): + keep_attrs = "override" if keep_attrs else "drop" + + variables_vfunc = functools.partial( + apply_variable_ufunc, + func, + signature=signature, + exclude_dims=exclude_dims, + keep_attrs=keep_attrs, + dask=dask, + vectorize=vectorize, + output_dtypes=output_dtypes, + dask_gufunc_kwargs=dask_gufunc_kwargs, + ) + + # feed groupby-apply_ufunc through apply_groupby_func + if any(isinstance(a, GroupBy) for a in args): + this_apply = functools.partial( + apply_ufunc, + func, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + exclude_dims=exclude_dims, + join=join, + dataset_join=dataset_join, + dataset_fill_value=dataset_fill_value, + keep_attrs=keep_attrs, + dask=dask, + vectorize=vectorize, + output_dtypes=output_dtypes, + dask_gufunc_kwargs=dask_gufunc_kwargs, + ) + return apply_groupby_func(this_apply, *args) + # feed datasets apply_variable_ufunc through apply_dataset_vfunc + elif any(is_dict_like(a) for a in args): + return apply_dataset_vfunc( + variables_vfunc, + *args, + signature=signature, + join=join, + exclude_dims=exclude_dims, + dataset_join=dataset_join, + fill_value=dataset_fill_value, + keep_attrs=keep_attrs, + on_missing_core_dim=on_missing_core_dim, + ) + # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc + elif any(isinstance(a, DataArray) for a in args): + return apply_dataarray_vfunc( + variables_vfunc, + *args, + signature=signature, + join=join, + exclude_dims=exclude_dims, + keep_attrs=keep_attrs, + ) + # feed Variables directly through apply_variable_ufunc + elif any(isinstance(a, Variable) for a in args): + return variables_vfunc(*args) + else: + # feed anything else through apply_array_ufunc + return apply_array_ufunc(func, *args, dask=dask) + + +def cov( + da_a: T_DataArray, + da_b: T_DataArray, + dim: Dims = None, + ddof: int = 1, + weights: T_DataArray | None = None, +) -> T_DataArray: + """ + Compute covariance between two DataArray objects along a shared dimension. + + Parameters + ---------- + da_a : DataArray + Array to compute. + da_b : DataArray + Array to compute. + dim : str, iterable of hashable, "..." or None, optional + The dimension along which the covariance will be computed + ddof : int, default: 1 + If ddof=1, covariance is normalized by N-1, giving an unbiased estimate, + else normalization is by N. + weights : DataArray, optional + Array of weights. + + Returns + ------- + covariance : DataArray + + See Also + -------- + pandas.Series.cov : corresponding pandas function + xarray.corr : respective function to calculate correlation + + Examples + -------- + >>> from xarray import DataArray + >>> da_a = DataArray( + ... np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_a + Size: 72B + array([[1. , 2. , 3. ], + [0.1, 0.2, 0.3], + [3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> da_b = DataArray( + ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_b + Size: 72B + array([[ 0.2, 0.4, 0.6], + [15. , 10. , 5. ], + [ 3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> xr.cov(da_a, da_b) + Size: 8B + array(-3.53055556) + >>> xr.cov(da_a, da_b, dim="time") + Size: 24B + array([ 0.2 , -0.5 , 1.69333333]) + Coordinates: + * space (space) >> weights = DataArray( + ... [4, 2, 1], + ... dims=("space"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ], + ... ) + >>> weights + Size: 24B + array([4, 2, 1]) + Coordinates: + * space (space) >> xr.cov(da_a, da_b, dim="space", weights=weights) + Size: 24B + array([-4.69346939, -4.49632653, -3.37959184]) + Coordinates: + * time (time) datetime64[ns] 24B 2000-01-01 2000-01-02 2000-01-03 + """ + from xarray.core.dataarray import DataArray + + if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]): + raise TypeError( + "Only xr.DataArray is supported." + f"Given {[type(arr) for arr in [da_a, da_b]]}." + ) + if weights is not None: + if not isinstance(weights, DataArray): + raise TypeError(f"Only xr.DataArray is supported. Given {type(weights)}.") + return _cov_corr(da_a, da_b, weights=weights, dim=dim, ddof=ddof, method="cov") + + +def corr( + da_a: T_DataArray, + da_b: T_DataArray, + dim: Dims = None, + weights: T_DataArray | None = None, +) -> T_DataArray: + """ + Compute the Pearson correlation coefficient between + two DataArray objects along a shared dimension. + + Parameters + ---------- + da_a : DataArray + Array to compute. + da_b : DataArray + Array to compute. + dim : str, iterable of hashable, "..." or None, optional + The dimension along which the correlation will be computed + weights : DataArray, optional + Array of weights. + + Returns + ------- + correlation: DataArray + + See Also + -------- + pandas.Series.corr : corresponding pandas function + xarray.cov : underlying covariance function + + Examples + -------- + >>> from xarray import DataArray + >>> da_a = DataArray( + ... np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_a + Size: 72B + array([[1. , 2. , 3. ], + [0.1, 0.2, 0.3], + [3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> da_b = DataArray( + ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_b + Size: 72B + array([[ 0.2, 0.4, 0.6], + [15. , 10. , 5. ], + [ 3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> xr.corr(da_a, da_b) + Size: 8B + array(-0.57087777) + >>> xr.corr(da_a, da_b, dim="time") + Size: 24B + array([ 1., -1., 1.]) + Coordinates: + * space (space) >> weights = DataArray( + ... [4, 2, 1], + ... dims=("space"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ], + ... ) + >>> weights + Size: 24B + array([4, 2, 1]) + Coordinates: + * space (space) >> xr.corr(da_a, da_b, dim="space", weights=weights) + Size: 24B + array([-0.50240504, -0.83215028, -0.99057446]) + Coordinates: + * time (time) datetime64[ns] 24B 2000-01-01 2000-01-02 2000-01-03 + """ + from xarray.core.dataarray import DataArray + + if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]): + raise TypeError( + "Only xr.DataArray is supported." + f"Given {[type(arr) for arr in [da_a, da_b]]}." + ) + if weights is not None: + if not isinstance(weights, DataArray): + raise TypeError(f"Only xr.DataArray is supported. Given {type(weights)}.") + return _cov_corr(da_a, da_b, weights=weights, dim=dim, method="corr") + + +def _cov_corr( + da_a: T_DataArray, + da_b: T_DataArray, + weights: T_DataArray | None = None, + dim: Dims = None, + ddof: int = 0, + method: Literal["cov", "corr", None] = None, +) -> T_DataArray: + """ + Internal method for xr.cov() and xr.corr() so only have to + sanitize the input arrays once and we don't repeat code. + """ + # 1. Broadcast the two arrays + da_a, da_b = align(da_a, da_b, join="inner", copy=False) + + # 2. Ignore the nans + valid_values = da_a.notnull() & da_b.notnull() + da_a = da_a.where(valid_values) + da_b = da_b.where(valid_values) + + # 3. Detrend along the given dim + if weights is not None: + demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim) + demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim) + else: + demeaned_da_a = da_a - da_a.mean(dim=dim) + demeaned_da_b = da_b - da_b.mean(dim=dim) + + # 4. Compute covariance along the given dim + # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g. + # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]) + if weights is not None: + cov = ( + (demeaned_da_a.conj() * demeaned_da_b) + .weighted(weights) + .mean(dim=dim, skipna=True) + ) + else: + cov = (demeaned_da_a.conj() * demeaned_da_b).mean(dim=dim, skipna=True) + + if method == "cov": + # Adjust covariance for degrees of freedom + valid_count = valid_values.sum(dim) + adjust = valid_count / (valid_count - ddof) + # I think the cast is required because of `T_DataArray` + `T_Xarray` (would be + # the same with `T_DatasetOrArray`) + # https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026 + return cast(T_DataArray, cov * adjust) + + else: + # Compute std and corr + if weights is not None: + da_a_std = da_a.weighted(weights).std(dim=dim) + da_b_std = da_b.weighted(weights).std(dim=dim) + else: + da_a_std = da_a.std(dim=dim) + da_b_std = da_b.std(dim=dim) + corr = cov / (da_a_std * da_b_std) + return cast(T_DataArray, corr) + + +def cross( + a: DataArray | Variable, b: DataArray | Variable, *, dim: Hashable +) -> DataArray | Variable: + """ + Compute the cross product of two (arrays of) vectors. + + The cross product of `a` and `b` in :math:`R^3` is a vector + perpendicular to both `a` and `b`. The vectors in `a` and `b` are + defined by the values along the dimension `dim` and can have sizes + 1, 2 or 3. Where the size of either `a` or `b` is + 1 or 2, the remaining components of the input vector is assumed to + be zero and the cross product calculated accordingly. In cases where + both input vectors have dimension 2, the z-component of the cross + product is returned. + + Parameters + ---------- + a, b : DataArray or Variable + Components of the first and second vector(s). + dim : hashable + The dimension along which the cross product will be computed. + Must be available in both vectors. + + Examples + -------- + Vector cross-product with 3 dimensions: + + >>> a = xr.DataArray([1, 2, 3]) + >>> b = xr.DataArray([4, 5, 6]) + >>> xr.cross(a, b, dim="dim_0") + Size: 24B + array([-3, 6, -3]) + Dimensions without coordinates: dim_0 + + Vector cross-product with 2 dimensions, returns in the perpendicular + direction: + + >>> a = xr.DataArray([1, 2]) + >>> b = xr.DataArray([4, 5]) + >>> xr.cross(a, b, dim="dim_0") + Size: 8B + array(-3) + + Vector cross-product with 3 dimensions but zeros at the last axis + yields the same results as with 2 dimensions: + + >>> a = xr.DataArray([1, 2, 0]) + >>> b = xr.DataArray([4, 5, 0]) + >>> xr.cross(a, b, dim="dim_0") + Size: 24B + array([ 0, 0, -3]) + Dimensions without coordinates: dim_0 + + One vector with dimension 2: + + >>> a = xr.DataArray( + ... [1, 2], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y"])), + ... ) + >>> b = xr.DataArray( + ... [4, 5, 6], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ... ) + >>> xr.cross(a, b, dim="cartesian") + Size: 24B + array([12, -6, -3]) + Coordinates: + * cartesian (cartesian) >> a = xr.DataArray( + ... [1, 2], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "z"])), + ... ) + >>> b = xr.DataArray( + ... [4, 5, 6], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ... ) + >>> xr.cross(a, b, dim="cartesian") + Size: 24B + array([-10, 2, 5]) + Coordinates: + * cartesian (cartesian) >> a = xr.DataArray( + ... [[1, 2, 3], [4, 5, 6]], + ... dims=("time", "cartesian"), + ... coords=dict( + ... time=(["time"], [0, 1]), + ... cartesian=(["cartesian"], ["x", "y", "z"]), + ... ), + ... ) + >>> b = xr.DataArray( + ... [[4, 5, 6], [1, 2, 3]], + ... dims=("time", "cartesian"), + ... coords=dict( + ... time=(["time"], [0, 1]), + ... cartesian=(["cartesian"], ["x", "y", "z"]), + ... ), + ... ) + >>> xr.cross(a, b, dim="cartesian") + Size: 48B + array([[-3, 6, -3], + [ 3, -6, 3]]) + Coordinates: + * time (time) int64 16B 0 1 + * cartesian (cartesian) >> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) + >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) + >>> c = xr.cross( + ... ds_a.to_dataarray("cartesian"), + ... ds_b.to_dataarray("cartesian"), + ... dim="cartesian", + ... ) + >>> c.to_dataset(dim="cartesian") + Size: 24B + Dimensions: (dim_0: 1) + Dimensions without coordinates: dim_0 + Data variables: + x (dim_0) int64 8B -3 + y (dim_0) int64 8B 6 + z (dim_0) int64 8B -3 + + See Also + -------- + numpy.cross : Corresponding numpy function + """ + + if dim not in a.dims: + raise ValueError(f"Dimension {dim!r} not on a") + elif dim not in b.dims: + raise ValueError(f"Dimension {dim!r} not on b") + + if not 1 <= a.sizes[dim] <= 3: + raise ValueError( + f"The size of {dim!r} on a must be 1, 2, or 3 to be " + f"compatible with a cross product but is {a.sizes[dim]}" + ) + elif not 1 <= b.sizes[dim] <= 3: + raise ValueError( + f"The size of {dim!r} on b must be 1, 2, or 3 to be " + f"compatible with a cross product but is {b.sizes[dim]}" + ) + + all_dims = list(dict.fromkeys(a.dims + b.dims)) + + if a.sizes[dim] != b.sizes[dim]: + # Arrays have different sizes. Append zeros where the smaller + # array is missing a value, zeros will not affect np.cross: + + if ( + not isinstance(a, Variable) # Only used to make mypy happy. + and dim in getattr(a, "coords", {}) + and not isinstance(b, Variable) # Only used to make mypy happy. + and dim in getattr(b, "coords", {}) + ): + # If the arrays have coords we know which indexes to fill + # with zeros: + a, b = align( + a, + b, + fill_value=0, + join="outer", + exclude=set(all_dims) - {dim}, + ) + elif min(a.sizes[dim], b.sizes[dim]) == 2: + # If the array doesn't have coords we can only infer + # that it has composite values if the size is at least 2. + # Once padded, rechunk the padded array because apply_ufunc + # requires core dimensions not to be chunked: + if a.sizes[dim] < b.sizes[dim]: + a = a.pad({dim: (0, 1)}, constant_values=0) + # TODO: Should pad or apply_ufunc handle correct chunking? + a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a + else: + b = b.pad({dim: (0, 1)}, constant_values=0) + # TODO: Should pad or apply_ufunc handle correct chunking? + b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b + else: + raise ValueError( + f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:" + " dimensions without coordinates must have have a length of 2 or 3" + ) + + c = apply_ufunc( + np.cross, + a, + b, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim] if a.sizes[dim] == 3 else []], + dask="parallelized", + output_dtypes=[np.result_type(a, b)], + ) + c = c.transpose(*all_dims, missing_dims="ignore") + + return c + + +@deprecate_dims +def dot( + *arrays, + dim: Dims = None, + **kwargs: Any, +): + """Generalized dot product for xarray objects. Like ``np.einsum``, but + provides a simpler interface based on array dimension names. + + Parameters + ---------- + *arrays : DataArray or Variable + Arrays to compute. + dim : str, iterable of hashable, "..." or None, optional + Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. + If not specified, then all the common dimensions are summed over. + **kwargs : dict + Additional keyword arguments passed to ``numpy.einsum`` or + ``dask.array.einsum`` + + Returns + ------- + DataArray + + See Also + -------- + numpy.einsum + dask.array.einsum + opt_einsum.contract + + Notes + ----- + We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``, + which is passed through to ``np.einsum``, and works for most array backends. + + Examples + -------- + >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) + >>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2), dims=["a", "b", "c"]) + >>> da_c = xr.DataArray(np.arange(2 * 3).reshape(2, 3), dims=["c", "d"]) + + >>> da_a + Size: 48B + array([[0, 1], + [2, 3], + [4, 5]]) + Dimensions without coordinates: a, b + + >>> da_b + Size: 96B + array([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]], + + [[ 8, 9], + [10, 11]]]) + Dimensions without coordinates: a, b, c + + >>> da_c + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Dimensions without coordinates: c, d + + >>> xr.dot(da_a, da_b, dim=["a", "b"]) + Size: 16B + array([110, 125]) + Dimensions without coordinates: c + + >>> xr.dot(da_a, da_b, dim=["a"]) + Size: 32B + array([[40, 46], + [70, 79]]) + Dimensions without coordinates: b, c + + >>> xr.dot(da_a, da_b, da_c, dim=["b", "c"]) + Size: 72B + array([[ 9, 14, 19], + [ 93, 150, 207], + [273, 446, 619]]) + Dimensions without coordinates: a, d + + >>> xr.dot(da_a, da_b) + Size: 16B + array([110, 125]) + Dimensions without coordinates: c + + >>> xr.dot(da_a, da_b, dim=...) + Size: 8B + array(235) + """ + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): + raise TypeError( + "Only xr.DataArray and xr.Variable are supported." + f"Given {[type(arr) for arr in arrays]}." + ) + + if len(arrays) == 0: + raise TypeError("At least one array should be given.") + + common_dims: set[Hashable] = set.intersection(*(set(arr.dims) for arr in arrays)) + all_dims = [] + for arr in arrays: + all_dims += [d for d in arr.dims if d not in all_dims] + + einsum_axes = "abcdefghijklmnopqrstuvwxyz" + dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + + if dim is None: + # find dimensions that occur more than once + dim_counts: Counter = Counter() + for arr in arrays: + dim_counts.update(arr.dims) + dim = tuple(d for d, c in dim_counts.items() if c > 1) + else: + dim = parse_dims(dim, all_dims=tuple(all_dims)) + + dot_dims: set[Hashable] = set(dim) + + # dimensions to be parallelized + broadcast_dims = common_dims - dot_dims + input_core_dims = [ + [d for d in arr.dims if d not in broadcast_dims] for arr in arrays + ] + output_core_dims = [ + [d for d in all_dims if d not in dot_dims and d not in broadcast_dims] + ] + + # construct einsum subscripts, such as '...abc,...ab->...c' + # Note: input_core_dims are always moved to the last position + subscripts_list = [ + "..." + "".join(dim_map[d] for d in ds) for ds in input_core_dims + ] + subscripts = ",".join(subscripts_list) + subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0]) + + join = OPTIONS["arithmetic_join"] + # using "inner" emulates `(a * b).sum()` for all joins (except "exact") + if join != "exact": + join = "inner" + + # subscripts should be passed to np.einsum as arg, not as kwargs. We need + # to construct a partial function for apply_ufunc to work. + func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) + result = apply_ufunc( + func, + *arrays, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + join=join, + dask="allowed", + ) + return result.transpose(*all_dims, missing_dims="ignore") + + +def where(cond, x, y, keep_attrs=None): + """Return elements from `x` or `y` depending on `cond`. + + Performs xarray-like broadcasting across input arguments. + + All dimension coordinates on `x` and `y` must be aligned with each + other and with `cond`. + + Parameters + ---------- + cond : scalar, array, Variable, DataArray or Dataset + When True, return values from `x`, otherwise returns values from `y`. + x : scalar, array, Variable, DataArray or Dataset + values to choose from where `cond` is True + y : scalar, array, Variable, DataArray or Dataset + values to choose from where `cond` is False + keep_attrs : bool or str or callable, optional + How to treat attrs. If True, keep the attrs of `x`. + + Returns + ------- + Dataset, DataArray, Variable or array + In priority order: Dataset, DataArray, Variable or array, whichever + type appears as an input argument. + + Examples + -------- + >>> x = xr.DataArray( + ... 0.1 * np.arange(10), + ... dims=["lat"], + ... coords={"lat": np.arange(10)}, + ... name="sst", + ... ) + >>> x + Size: 80B + array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) + Coordinates: + * lat (lat) int64 80B 0 1 2 3 4 5 6 7 8 9 + + >>> xr.where(x < 0.5, x, x * 100) + Size: 80B + array([ 0. , 0.1, 0.2, 0.3, 0.4, 50. , 60. , 70. , 80. , 90. ]) + Coordinates: + * lat (lat) int64 80B 0 1 2 3 4 5 6 7 8 9 + + >>> y = xr.DataArray( + ... 0.1 * np.arange(9).reshape(3, 3), + ... dims=["lat", "lon"], + ... coords={"lat": np.arange(3), "lon": 10 + np.arange(3)}, + ... name="sst", + ... ) + >>> y + Size: 72B + array([[0. , 0.1, 0.2], + [0.3, 0.4, 0.5], + [0.6, 0.7, 0.8]]) + Coordinates: + * lat (lat) int64 24B 0 1 2 + * lon (lon) int64 24B 10 11 12 + + >>> xr.where(y.lat < 1, y, -1) + Size: 72B + array([[ 0. , 0.1, 0.2], + [-1. , -1. , -1. ], + [-1. , -1. , -1. ]]) + Coordinates: + * lat (lat) int64 24B 0 1 2 + * lon (lon) int64 24B 10 11 12 + + >>> cond = xr.DataArray([True, False], dims=["x"]) + >>> x = xr.DataArray([1, 2], dims=["y"]) + >>> xr.where(cond, x, 0) + Size: 32B + array([[1, 2], + [0, 0]]) + Dimensions without coordinates: x, y + + See Also + -------- + numpy.where : corresponding numpy function + Dataset.where, DataArray.where : + equivalent methods + """ + from xarray.core.dataset import Dataset + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + # alignment for three arguments is complicated, so don't support it yet + result = apply_ufunc( + duck_array_ops.where, + cond, + x, + y, + join="exact", + dataset_join="exact", + dask="allowed", + keep_attrs=keep_attrs, + ) + + # keep the attributes of x, the second parameter, by default to + # be consistent with the `where` method of `DataArray` and `Dataset` + # rebuild the attrs from x at each level of the output, which could be + # Dataset, DataArray, or Variable, and also handle coords + if keep_attrs is True and hasattr(result, "attrs"): + if isinstance(y, Dataset) and not isinstance(x, Dataset): + # handle special case where x gets promoted to Dataset + result.attrs = {} + if getattr(x, "name", None) in result.data_vars: + result[x.name].attrs = getattr(x, "attrs", {}) + else: + # otherwise, fill in global attrs and variable attrs (if they exist) + result.attrs = getattr(x, "attrs", {}) + for v in getattr(result, "data_vars", []): + result[v].attrs = getattr(getattr(x, v, None), "attrs", {}) + for c in getattr(result, "coords", []): + # always fill coord attrs of x + result[c].attrs = getattr(getattr(x, c, None), "attrs", {}) + + return result + + +@overload +def polyval( + coord: DataArray, coeffs: DataArray, degree_dim: Hashable = "degree" +) -> DataArray: ... + + +@overload +def polyval( + coord: DataArray, coeffs: Dataset, degree_dim: Hashable = "degree" +) -> Dataset: ... + + +@overload +def polyval( + coord: Dataset, coeffs: DataArray, degree_dim: Hashable = "degree" +) -> Dataset: ... + + +@overload +def polyval( + coord: Dataset, coeffs: Dataset, degree_dim: Hashable = "degree" +) -> Dataset: ... + + +@overload +def polyval( + coord: Dataset | DataArray, + coeffs: Dataset | DataArray, + degree_dim: Hashable = "degree", +) -> Dataset | DataArray: ... + + +def polyval( + coord: Dataset | DataArray, + coeffs: Dataset | DataArray, + degree_dim: Hashable = "degree", +) -> Dataset | DataArray: + """Evaluate a polynomial at specific values + + Parameters + ---------- + coord : DataArray or Dataset + Values at which to evaluate the polynomial. + coeffs : DataArray or Dataset + Coefficients of the polynomial. + degree_dim : Hashable, default: "degree" + Name of the polynomial degree dimension in `coeffs`. + + Returns + ------- + DataArray or Dataset + Evaluated polynomial. + + See Also + -------- + xarray.DataArray.polyfit + numpy.polynomial.polynomial.polyval + """ + + if degree_dim not in coeffs._indexes: + raise ValueError( + f"Dimension `{degree_dim}` should be a coordinate variable with labels." + ) + if not np.issubdtype(coeffs[degree_dim].dtype, np.integer): + raise ValueError( + f"Dimension `{degree_dim}` should be of integer dtype. Received {coeffs[degree_dim].dtype} instead." + ) + max_deg = coeffs[degree_dim].max().item() + coeffs = coeffs.reindex( + {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False + ) + coord = _ensure_numeric(coord) + + # using Horner's method + # https://en.wikipedia.org/wiki/Horner%27s_method + res = zeros_like(coord) + coeffs.isel({degree_dim: max_deg}, drop=True) + for deg in range(max_deg - 1, -1, -1): + res *= coord + res += coeffs.isel({degree_dim: deg}, drop=True) + + return res + + +def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray: + """Converts all datetime64 variables to float64 + + Parameters + ---------- + data : DataArray or Dataset + Variables with possible datetime dtypes. + + Returns + ------- + DataArray or Dataset + Variables with datetime64 dtypes converted to float64. + """ + from xarray.core.dataset import Dataset + + def _cfoffset(x: DataArray) -> Any: + scalar = x.compute().data[0] + if not is_scalar(scalar): + # we do not get a scalar back on dask == 2021.04.1 + scalar = scalar.item() + return type(scalar)(1970, 1, 1) + + def to_floatable(x: DataArray) -> DataArray: + if x.dtype.kind in "MO": + # datetimes (CFIndexes are object type) + offset = ( + np.datetime64("1970-01-01") if x.dtype.kind == "M" else _cfoffset(x) + ) + return x.copy( + data=datetime_to_numeric(x.data, offset=offset, datetime_unit="ns"), + ) + elif x.dtype.kind == "m": + # timedeltas + return duck_array_ops.astype(x, dtype=float) + return x + + if isinstance(data, Dataset): + return data.map(to_floatable) + else: + return to_floatable(data) + + +def _calc_idxminmax( + *, + array, + func: Callable, + dim: Hashable | None = None, + skipna: bool | None = None, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, +): + """Apply common operations for idxmin and idxmax.""" + # This function doesn't make sense for scalars so don't try + if not array.ndim: + raise ValueError("This function does not apply for scalars") + + if dim is not None: + pass # Use the dim if available + elif array.ndim == 1: + # it is okay to guess the dim if there is only 1 + dim = array.dims[0] + else: + # The dim is not specified and ambiguous. Don't guess. + raise ValueError("Must supply 'dim' argument for multidimensional arrays") + + if dim not in array.dims: + raise KeyError( + f"Dimension {dim!r} not found in array dimensions {array.dims!r}" + ) + if dim not in array.coords: + raise KeyError( + f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}" + ) + + # These are dtypes with NaN values argmin and argmax can handle + na_dtypes = "cfO" + + if skipna or (skipna is None and array.dtype.kind in na_dtypes): + # Need to skip NaN values since argmin and argmax can't handle them + allna = array.isnull().all(dim) + array = array.where(~allna, 0) + + # This will run argmin or argmax. + indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) + + # Handle chunked arrays (e.g. dask). + if is_chunked_array(array.data): + chunkmanager = get_chunked_array_type(array.data) + chunks = dict(zip(array.dims, array.chunks)) + dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) + data = dask_coord[duck_array_ops.ravel(indx.data)] + res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) + # we need to attach back the dim name + res.name = dim + else: + res = array[dim][(indx,)] + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] + + if skipna or (skipna is None and array.dtype.kind in na_dtypes): + # Put the NaN values back in after removing them + res = res.where(~allna, fill_value) + + # Copy attributes from argmin/argmax, if any + res.attrs = indx.attrs + + return res + + +_T = TypeVar("_T", bound=Union["Dataset", "DataArray"]) +_U = TypeVar("_U", bound=Union["Dataset", "DataArray"]) +_V = TypeVar("_V", bound=Union["Dataset", "DataArray"]) + + +@overload +def unify_chunks(__obj: _T) -> tuple[_T]: ... + + +@overload +def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: ... + + +@overload +def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: ... + + +@overload +def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: ... + + +def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: + """ + Given any number of Dataset and/or DataArray objects, returns + new objects with unified chunk size along all chunked dimensions. + + Returns + ------- + unified (DataArray or Dataset) – Tuple of objects with the same type as + *objects with consistent chunk sizes for all dask-array variables + + See Also + -------- + dask.array.core.unify_chunks + """ + from xarray.core.dataarray import DataArray + + # Convert all objects to datasets + datasets = [ + obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy() + for obj in objects + ] + + # Get arguments to pass into dask.array.core.unify_chunks + unify_chunks_args = [] + sizes: dict[Hashable, int] = {} + for ds in datasets: + for v in ds._variables.values(): + if v.chunks is not None: + # Check that sizes match across different datasets + for dim, size in v.sizes.items(): + try: + if sizes[dim] != size: + raise ValueError( + f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}" + ) + except KeyError: + sizes[dim] = size + unify_chunks_args += [v._data, v._dims] + + # No dask arrays: Return inputs + if not unify_chunks_args: + return objects + + chunkmanager = get_chunked_array_type(*[arg for arg in unify_chunks_args]) + _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args) + chunked_data_iter = iter(chunked_data) + out: list[Dataset | DataArray] = [] + for obj, ds in zip(objects, datasets): + for k, v in ds._variables.items(): + if v.chunks is not None: + ds._variables[k] = v.copy(data=next(chunked_data_iter)) + out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds) + + return tuple(out) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/concat.py b/test/fixtures/whole_applications/xarray/xarray/core/concat.py new file mode 100644 index 0000000..b1cca58 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/concat.py @@ -0,0 +1,766 @@ +from __future__ import annotations + +from collections.abc import Hashable, Iterable +from typing import TYPE_CHECKING, Any, Union, overload + +import numpy as np +import pandas as pd + +from xarray.core import dtypes, utils +from xarray.core.alignment import align, reindex_variables +from xarray.core.coordinates import Coordinates +from xarray.core.duck_array_ops import lazy_array_equiv +from xarray.core.indexes import Index, PandasIndex +from xarray.core.merge import ( + _VALID_COMPAT, + collect_variables_and_indexes, + merge_attrs, + merge_collected, +) +from xarray.core.types import T_DataArray, T_Dataset, T_Variable +from xarray.core.variable import Variable +from xarray.core.variable import concat as concat_vars + +if TYPE_CHECKING: + from xarray.core.types import ( + CombineAttrsOptions, + CompatOptions, + ConcatOptions, + JoinOptions, + ) + + T_DataVars = Union[ConcatOptions, Iterable[Hashable]] + + +@overload +def concat( + objs: Iterable[T_Dataset], + dim: Hashable | T_Variable | T_DataArray | pd.Index, + data_vars: T_DataVars = "all", + coords: ConcatOptions | list[Hashable] = "different", + compat: CompatOptions = "equals", + positions: Iterable[Iterable[int]] | None = None, + fill_value: object = dtypes.NA, + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +) -> T_Dataset: ... + + +@overload +def concat( + objs: Iterable[T_DataArray], + dim: Hashable | T_Variable | T_DataArray | pd.Index, + data_vars: T_DataVars = "all", + coords: ConcatOptions | list[Hashable] = "different", + compat: CompatOptions = "equals", + positions: Iterable[Iterable[int]] | None = None, + fill_value: object = dtypes.NA, + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +) -> T_DataArray: ... + + +def concat( + objs, + dim, + data_vars: T_DataVars = "all", + coords="different", + compat: CompatOptions = "equals", + positions=None, + fill_value=dtypes.NA, + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +): + """Concatenate xarray objects along a new or existing dimension. + + Parameters + ---------- + objs : sequence of Dataset and DataArray + xarray objects to concatenate together. Each object is expected to + consist of variables and coordinates with matching shapes except for + along the concatenated dimension. + dim : Hashable or Variable or DataArray or pandas.Index + Name of the dimension to concatenate along. This can either be a new + dimension name, in which case it is added along axis=0, or an existing + dimension name, in which case the location of the dimension is + unchanged. If dimension is provided as a Variable, DataArray or Index, its name + is used as the dimension to concatenate along and the values are added + as a coordinate. + data_vars : {"minimal", "different", "all"} or list of Hashable, optional + These data variables will be concatenated together: + * "minimal": Only data variables in which the dimension already + appears are included. + * "different": Data variables which are not equal (ignoring + attributes) across all datasets are also concatenated (as well as + all for which dimension already appears). Beware: this option may + load the data payload of data variables into memory if they are not + already loaded. + * "all": All data variables will be concatenated. + * list of dims: The listed data variables will be concatenated, in + addition to the "minimal" data variables. + + If objects are DataArrays, data_vars must be "all". + coords : {"minimal", "different", "all"} or list of Hashable, optional + These coordinate variables will be concatenated together: + * "minimal": Only coordinates in which the dimension already appears + are included. + * "different": Coordinates which are not equal (ignoring attributes) + across all datasets are also concatenated (as well as all for which + dimension already appears). Beware: this option may load the data + payload of coordinate variables into memory if they are not already + loaded. + * "all": All coordinate variables will be concatenated, except + those corresponding to other dimensions. + * list of Hashable: The listed coordinate variables will be concatenated, + in addition to the "minimal" coordinates. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional + String indicating how to compare non-concatenated variables of the same name for + potential conflicts. This is passed down to merge. + + - "broadcast_equals": all values must be equal when variables are + broadcast against each other to ensure common dimensions. + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the + same. + - "no_conflicts": only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. + - "override": skip comparing and pick variable from first dataset + positions : None or list of integer arrays, optional + List of integer arrays which specifies the integer positions to which + to assign each dataset along the concatenated dimension. If not + supplied, objects are concatenated in the provided order. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. + join : {"outer", "inner", "left", "right", "exact"}, optional + String indicating how to combine differing indexes + (excluding dim) in objects + + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be + aligned are not equal + - "override": if indexes are of same size, rewrite indexes to be + those of the first object with that dimension. Indexes for the same + dimension must have the same size in all objects. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + create_index_for_new_dim : bool, default: True + Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``. + + Returns + ------- + concatenated : type of objs + + See also + -------- + merge + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(6).reshape(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] + ... ) + >>> da + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> xr.concat([da.isel(y=slice(0, 1)), da.isel(y=slice(1, None))], dim="y") + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "x") + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim") + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + x (new_dim) >> xr.concat([da.isel(x=0), da.isel(x=1)], pd.Index([-90, -100], name="new_dim")) + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + x (new_dim) >> ds = xr.Dataset(coords={"x": 0}) + >>> xr.concat([ds, ds], dim="x") + Size: 16B + Dimensions: (x: 2) + Coordinates: + * x (x) int64 16B 0 0 + Data variables: + *empty* + + >>> xr.concat([ds, ds], dim="x").indexes + Indexes: + x Index([0, 0], dtype='int64', name='x') + + >>> xr.concat([ds, ds], dim="x", create_index_for_new_dim=False).indexes + Indexes: + *empty* + """ + # TODO: add ignore_index arguments copied from pandas.concat + # TODO: support concatenating scalar coordinates even if the concatenated + # dimension already exists + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + try: + first_obj, objs = utils.peek_at(objs) + except StopIteration: + raise ValueError("must supply at least one object to concatenate") + + if compat not in _VALID_COMPAT: + raise ValueError( + f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" + ) + + if isinstance(first_obj, DataArray): + return _dataarray_concat( + objs, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + ) + elif isinstance(first_obj, Dataset): + return _dataset_concat( + objs, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + ) + else: + raise TypeError( + "can only concatenate xarray Dataset and DataArray " + f"objects, got {type(first_obj)}" + ) + + +def _calc_concat_dim_index( + dim_or_data: Hashable | Any, +) -> tuple[Hashable, PandasIndex | None]: + """Infer the dimension name and 1d index / coordinate variable (if appropriate) + for concatenating along the new dimension. + + """ + from xarray.core.dataarray import DataArray + + dim: Hashable | None + + if isinstance(dim_or_data, str): + dim = dim_or_data + index = None + else: + if not isinstance(dim_or_data, (DataArray, Variable)): + dim = getattr(dim_or_data, "name", None) + if dim is None: + dim = "concat_dim" + else: + (dim,) = dim_or_data.dims + coord_dtype = getattr(dim_or_data, "dtype", None) + index = PandasIndex(dim_or_data, dim, coord_dtype=coord_dtype) + + return dim, index + + +def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, compat): + """ + Determine which dataset variables need to be concatenated in the result, + """ + # Return values + concat_over = set() + equals = {} + + if dim in dim_names: + concat_over_existing_dim = True + concat_over.add(dim) + else: + concat_over_existing_dim = False + + concat_dim_lengths = [] + for ds in datasets: + if concat_over_existing_dim: + if dim not in ds.dims: + if dim in ds: + ds = ds.set_coords(dim) + concat_over.update(k for k, v in ds.variables.items() if dim in v.dims) + concat_dim_lengths.append(ds.sizes.get(dim, 1)) + + def process_subset_opt(opt, subset): + if isinstance(opt, str): + if opt == "different": + if compat == "override": + raise ValueError( + f"Cannot specify both {subset}='different' and compat='override'." + ) + # all nonindexes that are not the same in each dataset + for k in getattr(datasets[0], subset): + if k not in concat_over: + equals[k] = None + + variables = [ + ds.variables[k] for ds in datasets if k in ds.variables + ] + + if len(variables) == 1: + # coords="different" doesn't make sense when only one object + # contains a particular variable. + break + elif len(variables) != len(datasets) and opt == "different": + raise ValueError( + f"{k!r} not present in all datasets and coords='different'. " + f"Either add {k!r} to datasets where it is missing or " + "specify coords='minimal'." + ) + + # first check without comparing values i.e. no computes + for var in variables[1:]: + equals[k] = getattr(variables[0], compat)( + var, equiv=lazy_array_equiv + ) + if equals[k] is not True: + # exit early if we know these are not equal or that + # equality cannot be determined i.e. one or all of + # the variables wraps a numpy array + break + + if equals[k] is False: + concat_over.add(k) + + elif equals[k] is None: + # Compare the variable of all datasets vs. the one + # of the first dataset. Perform the minimum amount of + # loads in order to avoid multiple loads from disk + # while keeping the RAM footprint low. + v_lhs = datasets[0].variables[k].load() + # We'll need to know later on if variables are equal. + computed = [] + for ds_rhs in datasets[1:]: + v_rhs = ds_rhs.variables[k].compute() + computed.append(v_rhs) + if not getattr(v_lhs, compat)(v_rhs): + concat_over.add(k) + equals[k] = False + # computed variables are not to be re-computed + # again in the future + for ds, v in zip(datasets[1:], computed): + ds.variables[k].data = v.data + break + else: + equals[k] = True + + elif opt == "all": + concat_over.update( + set().union( + *list(set(getattr(d, subset)) - set(d.dims) for d in datasets) + ) + ) + elif opt == "minimal": + pass + else: + raise ValueError(f"unexpected value for {subset}: {opt}") + else: + valid_vars = tuple(getattr(datasets[0], subset)) + invalid_vars = [k for k in opt if k not in valid_vars] + if invalid_vars: + if subset == "coords": + raise ValueError( + f"the variables {invalid_vars} in coords are not " + f"found in the coordinates of the first dataset {valid_vars}" + ) + else: + # note: data_vars are not listed in the error message here, + # because there may be lots of them + raise ValueError( + f"the variables {invalid_vars} in data_vars are not " + f"found in the data variables of the first dataset" + ) + concat_over.update(opt) + + process_subset_opt(data_vars, "data_vars") + process_subset_opt(coords, "coords") + return concat_over, equals, concat_dim_lengths + + +# determine dimensional coordinate names and a dict mapping name to DataArray +def _parse_datasets( + datasets: list[T_Dataset], +) -> tuple[ + dict[Hashable, Variable], + dict[Hashable, int], + set[Hashable], + set[Hashable], + list[Hashable], +]: + dims: set[Hashable] = set() + all_coord_names: set[Hashable] = set() + data_vars: set[Hashable] = set() # list of data_vars + dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable + dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables + variables_order: dict[Hashable, Variable] = {} # variables in order of appearance + + for ds in datasets: + dims_sizes.update(ds.sizes) + all_coord_names.update(ds.coords) + data_vars.update(ds.data_vars) + variables_order.update(ds.variables) + + # preserves ordering of dimensions + for dim in ds.dims: + if dim in dims: + continue + + if dim in ds.coords and dim not in dim_coords: + dim_coords[dim] = ds.coords[dim].variable + dims = dims | set(ds.dims) + + return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order) + + +def _dataset_concat( + datasets: list[T_Dataset], + dim: str | T_Variable | T_DataArray | pd.Index, + data_vars: T_DataVars, + coords: str | list[str], + compat: CompatOptions, + positions: Iterable[Iterable[int]] | None, + fill_value: Any = dtypes.NA, + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +) -> T_Dataset: + """ + Concatenate a sequence of datasets along a new or existing dimension + """ + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + datasets = list(datasets) + + if not all(isinstance(dataset, Dataset) for dataset in datasets): + raise TypeError( + "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" + ) + + if isinstance(dim, DataArray): + dim_var = dim.variable + elif isinstance(dim, Variable): + dim_var = dim + else: + dim_var = None + + dim, index = _calc_concat_dim_index(dim) + + # Make sure we're working on a copy (we'll be loading variables) + datasets = [ds.copy() for ds in datasets] + datasets = list( + align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) + ) + + dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets( + datasets + ) + dim_names = set(dim_coords) + + both_data_and_coords = coord_names & data_names + if both_data_and_coords: + raise ValueError( + f"{both_data_and_coords!r} is a coordinate in some datasets but not others." + ) + # we don't want the concat dimension in the result dataset yet + dim_coords.pop(dim, None) + dims_sizes.pop(dim, None) + + # case where concat dimension is a coordinate or data_var but not a dimension + if (dim in coord_names or dim in data_names) and dim not in dim_names: + datasets = [ + ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim) + for ds in datasets + ] + + # determine which variables to concatenate + concat_over, equals, concat_dim_lengths = _calc_concat_over( + datasets, dim, dim_names, data_vars, coords, compat + ) + + # determine which variables to merge, and then merge them according to compat + variables_to_merge = (coord_names | data_names) - concat_over + + result_vars = {} + result_indexes = {} + + if variables_to_merge: + grouped = { + k: v + for k, v in collect_variables_and_indexes(datasets).items() + if k in variables_to_merge + } + merged_vars, merged_indexes = merge_collected( + grouped, compat=compat, equals=equals + ) + result_vars.update(merged_vars) + result_indexes.update(merged_indexes) + + result_vars.update(dim_coords) + + # assign attrs and encoding from first dataset + result_attrs = merge_attrs([ds.attrs for ds in datasets], combine_attrs) + result_encoding = datasets[0].encoding + + # check that global attributes are fixed across all datasets if necessary + if compat == "identical": + for ds in datasets[1:]: + if not utils.dict_equiv(ds.attrs, result_attrs): + raise ValueError("Dataset global attributes not equal.") + + # we've already verified everything is consistent; now, calculate + # shared dimension sizes so we can expand the necessary variables + def ensure_common_dims(vars, concat_dim_lengths): + # ensure each variable with the given name shares the same + # dimensions and the same shape for all of them except along the + # concat dimension + common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims)) + if dim not in common_dims: + common_dims = (dim,) + common_dims + for var, dim_len in zip(vars, concat_dim_lengths): + if var.dims != common_dims: + common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims) + var = var.set_dims(common_dims, common_shape) + yield var + + # get the indexes to concatenate together, create a PandasIndex + # for any scalar coordinate variable found with ``name`` matching ``dim``. + # TODO: depreciate concat a mix of scalar and dimensional indexed coordinates? + # TODO: (benbovy - explicit indexes): check index types and/or coordinates + # of all datasets? + def get_indexes(name): + for ds in datasets: + if name in ds._indexes: + yield ds._indexes[name] + elif name == dim: + var = ds._variables[name] + if not var.dims: + data = var.set_dims(dim).values + if create_index_for_new_dim: + yield PandasIndex(data, dim, coord_dtype=var.dtype) + + # create concatenation index, needed for later reindexing + file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths)) + concat_index = np.arange(file_start_indexes[-1]) + concat_index_size = concat_index.size + variable_index_mask = np.ones(concat_index_size, dtype=bool) + + # stack up each variable and/or index to fill-out the dataset (in order) + # n.b. this loop preserves variable order, needed for groupby. + ndatasets = len(datasets) + for name in vars_order: + if name in concat_over and name not in result_indexes: + variables = [] + # Initialize the mask to all True then set False if any name is missing in + # the datasets: + variable_index_mask.fill(True) + var_concat_dim_length = [] + for i, ds in enumerate(datasets): + if name in ds.variables: + variables.append(ds[name].variable) + var_concat_dim_length.append(concat_dim_lengths[i]) + else: + # raise if coordinate not in all datasets + if name in coord_names: + raise ValueError( + f"coordinate {name!r} not present in all datasets." + ) + + # Mask out the indexes without the name: + start = file_start_indexes[i] + end = file_start_indexes[i + 1] + variable_index_mask[slice(start, end)] = False + + variable_index = concat_index[variable_index_mask] + vars = ensure_common_dims(variables, var_concat_dim_length) + + # Try to concatenate the indexes, concatenate the variables when no index + # is found on all datasets. + indexes: list[Index] = list(get_indexes(name)) + if indexes: + if len(indexes) < ndatasets: + raise ValueError( + f"{name!r} must have either an index or no index in all datasets, " + f"found {len(indexes)}/{len(datasets)} datasets with an index." + ) + combined_idx = indexes[0].concat(indexes, dim, positions) + if name in datasets[0]._indexes: + idx_vars = datasets[0].xindexes.get_all_coords(name) + else: + # index created from a scalar coordinate + idx_vars = {name: datasets[0][name].variable} + result_indexes.update({k: combined_idx for k in idx_vars}) + combined_idx_vars = combined_idx.create_variables(idx_vars) + for k, v in combined_idx_vars.items(): + v.attrs = merge_attrs( + [ds.variables[k].attrs for ds in datasets], + combine_attrs=combine_attrs, + ) + result_vars[k] = v + else: + combined_var = concat_vars( + vars, dim, positions, combine_attrs=combine_attrs + ) + # reindex if variable is not present in all datasets + if len(variable_index) < concat_index_size: + combined_var = reindex_variables( + variables={name: combined_var}, + dim_pos_indexers={ + dim: pd.Index(variable_index).get_indexer(concat_index) + }, + fill_value=fill_value, + )[name] + result_vars[name] = combined_var + + elif name in result_vars: + # preserves original variable order + result_vars[name] = result_vars.pop(name) + + absent_coord_names = coord_names - set(result_vars) + if absent_coord_names: + raise ValueError( + f"Variables {absent_coord_names!r} are coordinates in some datasets but not others." + ) + + result_data_vars = {} + coord_vars = {} + for name, result_var in result_vars.items(): + if name in coord_names: + coord_vars[name] = result_var + else: + result_data_vars[name] = result_var + + if index is not None: + if dim_var is not None: + index_vars = index.create_variables({dim: dim_var}) + else: + index_vars = index.create_variables() + + coord_vars[dim] = index_vars[dim] + result_indexes[dim] = index + + coords_obj = Coordinates(coord_vars, indexes=result_indexes) + + result = type(datasets[0])(result_data_vars, coords=coords_obj, attrs=result_attrs) + result.encoding = result_encoding + + return result + + +def _dataarray_concat( + arrays: Iterable[T_DataArray], + dim: str | T_Variable | T_DataArray | pd.Index, + data_vars: T_DataVars, + coords: str | list[str], + compat: CompatOptions, + positions: Iterable[Iterable[int]] | None, + fill_value: object = dtypes.NA, + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +) -> T_DataArray: + from xarray.core.dataarray import DataArray + + arrays = list(arrays) + + if not all(isinstance(array, DataArray) for array in arrays): + raise TypeError( + "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" + ) + + if data_vars != "all": + raise ValueError( + "data_vars is not a valid argument when concatenating DataArray objects" + ) + + datasets = [] + for n, arr in enumerate(arrays): + if n == 0: + name = arr.name + elif name != arr.name: + if compat == "identical": + raise ValueError("array names not identical") + else: + arr = arr.rename(name) + datasets.append(arr._to_temp_dataset()) + + ds = _dataset_concat( + datasets, + dim, + data_vars, + coords, + compat, + positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + ) + + merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs) + + result = arrays[0]._from_temp_dataset(ds, name) + result.attrs = merged_attrs + + return result diff --git a/test/fixtures/whole_applications/xarray/xarray/core/coordinates.py b/test/fixtures/whole_applications/xarray/xarray/core/coordinates.py new file mode 100644 index 0000000..251edd1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/coordinates.py @@ -0,0 +1,1025 @@ +from __future__ import annotations + +from collections.abc import Hashable, Iterator, Mapping, Sequence +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Generic, + cast, +) + +import numpy as np +import pandas as pd + +from xarray.core import formatting +from xarray.core.alignment import Aligner +from xarray.core.indexes import ( + Index, + Indexes, + PandasIndex, + PandasMultiIndex, + assert_no_index_corrupted, + create_default_index_implicit, +) +from xarray.core.merge import merge_coordinates_without_align, merge_coords +from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray +from xarray.core.utils import ( + Frozen, + ReprObject, + either_dict_or_kwargs, + emit_user_level_warning, +) +from xarray.core.variable import Variable, as_variable, calculate_dimensions + +if TYPE_CHECKING: + from xarray.core.common import DataWithCoords + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + +# Used as the key corresponding to a DataArray's variable when converting +# arbitrary DataArray objects to datasets +_THIS_ARRAY = ReprObject("") + + +class AbstractCoordinates(Mapping[Hashable, "T_DataArray"]): + _data: DataWithCoords + __slots__ = ("_data",) + + def __getitem__(self, key: Hashable) -> T_DataArray: + raise NotImplementedError() + + @property + def _names(self) -> set[Hashable]: + raise NotImplementedError() + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: + raise NotImplementedError() + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + raise NotImplementedError() + + @property + def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Coordinates object has indexes that cannot + be coerced to pandas.Index objects. + + See Also + -------- + Coordinates.xindexes + """ + return self._data.indexes + + @property + def xindexes(self) -> Indexes[Index]: + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ + return self._data.xindexes + + @property + def variables(self): + raise NotImplementedError() + + def _update_coords(self, coords, indexes): + raise NotImplementedError() + + def _drop_coords(self, coord_names): + raise NotImplementedError() + + def __iter__(self) -> Iterator[Hashable]: + # needs to be in the same order as the dataset variables + for k in self.variables: + if k in self._names: + yield k + + def __len__(self) -> int: + return len(self._names) + + def __contains__(self, key: Hashable) -> bool: + return key in self._names + + def __repr__(self) -> str: + return formatting.coords_repr(self) + + def to_dataset(self) -> Dataset: + raise NotImplementedError() + + def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: + """Convert all index coordinates into a :py:class:`pandas.Index`. + + Parameters + ---------- + ordered_dims : sequence of hashable, optional + Possibly reordered version of this object's dimensions indicating + the order in which dimensions should appear on the result. + + Returns + ------- + pandas.Index + Index subclass corresponding to the outer-product of all dimension + coordinates. This will be a MultiIndex if this object is has more + than more dimension. + """ + if ordered_dims is None: + ordered_dims = list(self.dims) + elif set(ordered_dims) != set(self.dims): + raise ValueError( + "ordered_dims must match dims, but does not: " + f"{ordered_dims} vs {self.dims}" + ) + + if len(ordered_dims) == 0: + raise ValueError("no valid index for a 0-dimensional object") + elif len(ordered_dims) == 1: + (dim,) = ordered_dims + return self._data.get_index(dim) + else: + indexes = [self._data.get_index(k) for k in ordered_dims] + + # compute the sizes of the repeat and tile for the cartesian product + # (taken from pandas.core.reshape.util) + index_lengths = np.fromiter( + (len(index) for index in indexes), dtype=np.intp + ) + cumprod_lengths = np.cumprod(index_lengths) + + if cumprod_lengths[-1] == 0: + # if any factor is empty, the cartesian product is empty + repeat_counts = np.zeros_like(cumprod_lengths) + + else: + # sizes of the repeats + repeat_counts = cumprod_lengths[-1] / cumprod_lengths + # sizes of the tiles + tile_counts = np.roll(cumprod_lengths, 1) + tile_counts[0] = 1 + + # loop over the indexes + # for each MultiIndex or Index compute the cartesian product of the codes + + code_list = [] + level_list = [] + names = [] + + for i, index in enumerate(indexes): + if isinstance(index, pd.MultiIndex): + codes, levels = index.codes, index.levels + else: + code, level = pd.factorize(index) + codes = [code] + levels = [level] + + # compute the cartesian product + code_list += [ + np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i]) + for code in codes + ] + level_list += levels + names += index.names + + return pd.MultiIndex(level_list, code_list, names=names) + + +class Coordinates(AbstractCoordinates): + """Dictionary like container for Xarray coordinates (variables + indexes). + + This collection is a mapping of coordinate names to + :py:class:`~xarray.DataArray` objects. + + It can be passed directly to the :py:class:`~xarray.Dataset` and + :py:class:`~xarray.DataArray` constructors via their `coords` argument. This + will add both the coordinates variables and their index. + + Coordinates are either: + + - returned via the :py:attr:`Dataset.coords` and :py:attr:`DataArray.coords` + properties + - built from Pandas or other index objects + (e.g., :py:meth:`Coordinates.from_pandas_multiindex`) + - built directly from coordinate data and Xarray ``Index`` objects (beware that + no consistency check is done on those inputs) + + Parameters + ---------- + coords: dict-like, optional + Mapping where keys are coordinate names and values are objects that + can be converted into a :py:class:`~xarray.Variable` object + (see :py:func:`~xarray.as_variable`). If another + :py:class:`~xarray.Coordinates` object is passed, its indexes + will be added to the new created object. + indexes: dict-like, optional + Mapping where keys are coordinate names and values are + :py:class:`~xarray.indexes.Index` objects. If None (default), + pandas indexes will be created for each dimension coordinate. + Passing an empty dictionary will skip this default behavior. + + Examples + -------- + Create a dimension coordinate with a default (pandas) index: + + >>> xr.Coordinates({"x": [1, 2]}) + Coordinates: + * x (x) int64 16B 1 2 + + Create a dimension coordinate with no index: + + >>> xr.Coordinates(coords={"x": [1, 2]}, indexes={}) + Coordinates: + x (x) int64 16B 1 2 + + Create a new Coordinates object from existing dataset coordinates + (indexes are passed): + + >>> ds = xr.Dataset(coords={"x": [1, 2]}) + >>> xr.Coordinates(ds.coords) + Coordinates: + * x (x) int64 16B 1 2 + + Create indexed coordinates from a ``pandas.MultiIndex`` object: + + >>> midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]]) + >>> xr.Coordinates.from_pandas_multiindex(midx, "x") + Coordinates: + * x (x) object 32B MultiIndex + * x_level_0 (x) object 32B 'a' 'a' 'b' 'b' + * x_level_1 (x) int64 32B 0 1 0 1 + + Create a new Dataset object by passing a Coordinates object: + + >>> midx_coords = xr.Coordinates.from_pandas_multiindex(midx, "x") + >>> xr.Dataset(coords=midx_coords) + Size: 96B + Dimensions: (x: 4) + Coordinates: + * x (x) object 32B MultiIndex + * x_level_0 (x) object 32B 'a' 'a' 'b' 'b' + * x_level_1 (x) int64 32B 0 1 0 1 + Data variables: + *empty* + + """ + + _data: DataWithCoords + + __slots__ = ("_data",) + + def __init__( + self, + coords: Mapping[Any, Any] | None = None, + indexes: Mapping[Any, Index] | None = None, + ) -> None: + # When coordinates are constructed directly, an internal Dataset is + # created so that it is compatible with the DatasetCoordinates and + # DataArrayCoordinates classes serving as a proxy for the data. + # TODO: refactor DataArray / Dataset so that Coordinates store the data. + from xarray.core.dataset import Dataset + + if coords is None: + coords = {} + + variables: dict[Hashable, Variable] + default_indexes: dict[Hashable, PandasIndex] = {} + coords_obj_indexes: dict[Hashable, Index] = {} + + if isinstance(coords, Coordinates): + if indexes is not None: + raise ValueError( + "passing both a ``Coordinates`` object and a mapping of indexes " + "to ``Coordinates.__init__`` is not allowed " + "(this constructor does not support merging them)" + ) + variables = {k: v.copy() for k, v in coords.variables.items()} + coords_obj_indexes = dict(coords.xindexes) + else: + variables = {} + for name, data in coords.items(): + var = as_variable(data, name=name, auto_convert=False) + if var.dims == (name,) and indexes is None: + index, index_vars = create_default_index_implicit(var, list(coords)) + default_indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + else: + variables[name] = var + + if indexes is None: + indexes = {} + else: + indexes = dict(indexes) + + indexes.update(default_indexes) + indexes.update(coords_obj_indexes) + + no_coord_index = set(indexes) - set(variables) + if no_coord_index: + raise ValueError( + f"no coordinate variables found for these indexes: {no_coord_index}" + ) + + for k, idx in indexes.items(): + if not isinstance(idx, Index): + raise TypeError(f"'{k}' is not an `xarray.indexes.Index` object") + + # maybe convert to base variable + for k, v in variables.items(): + if k not in indexes: + variables[k] = v.to_base_variable() + + self._data = Dataset._construct_direct( + coord_names=set(variables), variables=variables, indexes=indexes + ) + + @classmethod + def _construct_direct( + cls, + coords: dict[Any, Variable], + indexes: dict[Any, Index], + dims: dict[Any, int] | None = None, + ) -> Self: + from xarray.core.dataset import Dataset + + obj = object.__new__(cls) + obj._data = Dataset._construct_direct( + coord_names=set(coords), + variables=coords, + indexes=indexes, + dims=dims, + ) + return obj + + @classmethod + def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: str) -> Self: + """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). + + The returned coordinates can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + midx : :py:class:`pandas.MultiIndex` + Pandas multi-index object. + dim : str + Dimension name. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the multi-index. + + """ + xr_idx = PandasMultiIndex(midx, dim) + + variables = xr_idx.create_variables() + indexes = {k: xr_idx for k in variables} + + return cls(coords=variables, indexes=indexes) + + @property + def _names(self) -> set[Hashable]: + return self._data._coord_names + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: + """Mapping from dimension names to lengths or tuple of dimension names.""" + return self._data.dims + + @property + def sizes(self) -> Frozen[Hashable, int]: + """Mapping from dimension names to lengths.""" + return self._data.sizes + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from coordinate names to dtypes. + + Cannot be modified directly. + + See Also + -------- + Dataset.dtypes + """ + return Frozen({n: v.dtype for n, v in self._data.variables.items()}) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to Coordinates contents as dict of Variable objects. + + This dictionary is frozen to prevent mutation. + """ + return self._data.variables + + def to_dataset(self) -> Dataset: + """Convert these coordinates into a new Dataset.""" + names = [name for name in self._data._variables if name in self._names] + return self._data._copy_listed(names) + + def __getitem__(self, key: Hashable) -> DataArray: + return self._data[key] + + def __delitem__(self, key: Hashable) -> None: + # redirect to DatasetCoordinates.__delitem__ + del self._data.coords[key] + + def equals(self, other: Self) -> bool: + """Two Coordinates objects are equal if they have matching variables, + all of which are equal. + + See Also + -------- + Coordinates.identical + """ + if not isinstance(other, Coordinates): + return False + return self.to_dataset().equals(other.to_dataset()) + + def identical(self, other: Self) -> bool: + """Like equals, but also checks all variable attributes. + + See Also + -------- + Coordinates.equals + """ + if not isinstance(other, Coordinates): + return False + return self.to_dataset().identical(other.to_dataset()) + + def _update_coords( + self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + ) -> None: + # redirect to DatasetCoordinates._update_coords + self._data.coords._update_coords(coords, indexes) + + def _drop_coords(self, coord_names): + # redirect to DatasetCoordinates._drop_coords + self._data.coords._drop_coords(coord_names) + + def _merge_raw(self, other, reflexive): + """For use with binary arithmetic.""" + if other is None: + variables = dict(self.variables) + indexes = dict(self.xindexes) + else: + coord_list = [self, other] if not reflexive else [other, self] + variables, indexes = merge_coordinates_without_align(coord_list) + return variables, indexes + + @contextmanager + def _merge_inplace(self, other): + """For use with in-place binary arithmetic.""" + if other is None: + yield + else: + # don't include indexes in prioritized, because we didn't align + # first and we want indexes to be checked + prioritized = { + k: (v, None) + for k, v in self.variables.items() + if k not in self.xindexes + } + variables, indexes = merge_coordinates_without_align( + [self, other], prioritized + ) + yield + self._update_coords(variables, indexes) + + def merge(self, other: Mapping[Any, Any] | None) -> Dataset: + """Merge two sets of coordinates to create a new Dataset + + The method implements the logic used for joining coordinates in the + result of a binary operation performed on xarray objects: + + - If two index coordinates conflict (are not equal), an exception is + raised. You must align your data before passing it to this method. + - If an index coordinate and a non-index coordinate conflict, the non- + index coordinate is dropped. + - If two non-index coordinates conflict, both are dropped. + + Parameters + ---------- + other : dict-like, optional + A :py:class:`Coordinates` object or any mapping that can be turned + into coordinates. + + Returns + ------- + merged : Dataset + A new Dataset with merged coordinates. + """ + from xarray.core.dataset import Dataset + + if other is None: + return self.to_dataset() + + if not isinstance(other, Coordinates): + other = Dataset(coords=other).coords + + coords, indexes = merge_coordinates_without_align([self, other]) + coord_names = set(coords) + return Dataset._construct_direct( + variables=coords, coord_names=coord_names, indexes=indexes + ) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.update({key: value}) + + def update(self, other: Mapping[Any, Any]) -> None: + """Update this Coordinates variables with other coordinate variables.""" + + if not len(other): + return + + other_coords: Coordinates + + if isinstance(other, Coordinates): + # Coordinates object: just pass it (default indexes won't be created) + other_coords = other + else: + other_coords = create_coords_with_default_indexes( + getattr(other, "variables", other) + ) + + # Discard original indexed coordinates prior to merge allows to: + # - fail early if the new coordinates don't preserve the integrity of existing + # multi-coordinate indexes + # - drop & replace coordinates without alignment (note: we must keep indexed + # coordinates extracted from the DataArray objects passed as values to + # `other` - if any - as those are still used for aligning the old/new coordinates) + coords_to_align = drop_indexed_coords(set(other_coords) & set(other), self) + + coords, indexes = merge_coords( + [coords_to_align, other_coords], + priority_arg=1, + indexes=coords_to_align.xindexes, + ) + + # special case for PandasMultiIndex: updating only its dimension coordinate + # is still allowed but depreciated. + # It is the only case where we need to actually drop coordinates here (multi-index levels) + # TODO: remove when removing PandasMultiIndex's dimension coordinate. + self._drop_coords(self._names - coords_to_align._names) + + self._update_coords(coords, indexes) + + def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self: + """Assign new coordinates (and indexes) to a Coordinates object, returning + a new object with all the original coordinates in addition to the new ones. + + Parameters + ---------- + coords : mapping of dim to coord, optional + A mapping whose keys are the names of the coordinates and values are the + coordinates to assign. The mapping will generally be a dict or + :class:`Coordinates`. + + * If a value is a standard data value — for example, a ``DataArray``, + scalar, or array — the data is simply assigned as a coordinate. + + * A coordinate can also be defined and attached to an existing dimension + using a tuple with the first element the dimension name and the second + element the values for this new coordinate. + + **coords_kwargs + The keyword arguments form of ``coords``. + One of ``coords`` or ``coords_kwargs`` must be provided. + + Returns + ------- + new_coords : Coordinates + A new Coordinates object with the new coordinates (and indexes) + in addition to all the existing coordinates. + + Examples + -------- + >>> coords = xr.Coordinates() + >>> coords + Coordinates: + *empty* + + >>> coords.assign(x=[1, 2]) + Coordinates: + * x (x) int64 16B 1 2 + + >>> midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]]) + >>> coords.assign(xr.Coordinates.from_pandas_multiindex(midx, "y")) + Coordinates: + * y (y) object 32B MultiIndex + * y_level_0 (y) object 32B 'a' 'a' 'b' 'b' + * y_level_1 (y) int64 32B 0 1 0 1 + + """ + # TODO: this doesn't support a callable, which is inconsistent with `DataArray.assign_coords` + coords = either_dict_or_kwargs(coords, coords_kwargs, "assign") + new_coords = self.copy() + new_coords.update(coords) + return new_coords + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + ) -> Self: + results = self.to_dataset()._overwrite_indexes(indexes, variables) + + # TODO: remove cast once we get rid of DatasetCoordinates + # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates) + return cast(Self, results.coords) + + def _reindex_callback( + self, + aligner: Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: + """Callback called from ``Aligner`` to create a new reindexed Coordinate.""" + aligned = self.to_dataset()._reindex_callback( + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, + ) + + # TODO: remove cast once we get rid of DatasetCoordinates + # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates) + return cast(Self, aligned.coords) + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + return self._data._ipython_key_completions_() + + def copy( + self, + deep: bool = False, + memo: dict[int, Any] | None = None, + ) -> Self: + """Return a copy of this Coordinates object.""" + # do not copy indexes (may corrupt multi-coordinate indexes) + # TODO: disable variables deepcopy? it may also be problematic when they + # encapsulate index objects like pd.Index + variables = { + k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items() + } + + # TODO: getting an error with `self._construct_direct`, possibly because of how + # a subclass implements `_construct_direct`. (This was originally the same + # runtime code, but we switched the type definitions in #8216, which + # necessitates the cast.) + return cast( + Self, + Coordinates._construct_direct( + coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + ), + ) + + +class DatasetCoordinates(Coordinates): + """Dictionary like container for Dataset coordinates (variables + indexes). + + This collection can be passed directly to the :py:class:`~xarray.Dataset` + and :py:class:`~xarray.DataArray` constructors via their `coords` argument. + This will add both the coordinates variables and their index. + """ + + _data: Dataset + + __slots__ = ("_data",) + + def __init__(self, dataset: Dataset): + self._data = dataset + + @property + def _names(self) -> set[Hashable]: + return self._data._coord_names + + @property + def dims(self) -> Frozen[Hashable, int]: + return self._data.dims + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from coordinate names to dtypes. + + Cannot be modified directly, but is updated when adding new variables. + + See Also + -------- + Dataset.dtypes + """ + return Frozen( + { + n: v.dtype + for n, v in self._data._variables.items() + if n in self._data._coord_names + } + ) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + return Frozen( + {k: v for k, v in self._data.variables.items() if k in self._names} + ) + + def __getitem__(self, key: Hashable) -> DataArray: + if key in self._data.data_vars: + raise KeyError(key) + return self._data[key] + + def to_dataset(self) -> Dataset: + """Convert these coordinates into a new Dataset""" + + names = [name for name in self._data._variables if name in self._names] + return self._data._copy_listed(names) + + def _update_coords( + self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + ) -> None: + variables = self._data._variables.copy() + variables.update(coords) + + # check for inconsistent state *before* modifying anything in-place + dims = calculate_dimensions(variables) + new_coord_names = set(coords) + for dim, size in dims.items(): + if dim in variables: + new_coord_names.add(dim) + + self._data._variables = variables + self._data._coord_names.update(new_coord_names) + self._data._dims = dims + + # TODO(shoyer): once ._indexes is always populated by a dict, modify + # it to update inplace instead. + original_indexes = dict(self._data.xindexes) + original_indexes.update(indexes) + self._data._indexes = original_indexes + + def _drop_coords(self, coord_names): + # should drop indexed coordinates only + for name in coord_names: + del self._data._variables[name] + del self._data._indexes[name] + self._data._coord_names.difference_update(coord_names) + + def _drop_indexed_coords(self, coords_to_drop: set[Hashable]) -> None: + assert self._data.xindexes is not None + new_coords = drop_indexed_coords(coords_to_drop, self) + for name in self._data._coord_names - new_coords._names: + del self._data._variables[name] + self._data._indexes = dict(new_coords.xindexes) + self._data._coord_names.intersection_update(new_coords._names) + + def __delitem__(self, key: Hashable) -> None: + if key in self: + del self._data[key] + else: + raise KeyError( + f"{key!r} is not in coordinate variables {tuple(self.keys())}" + ) + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + return [ + key + for key in self._data._ipython_key_completions_() + if key not in self._data.data_vars + ] + + +class DataArrayCoordinates(Coordinates, Generic[T_DataArray]): + """Dictionary like container for DataArray coordinates (variables + indexes). + + This collection can be passed directly to the :py:class:`~xarray.Dataset` + and :py:class:`~xarray.DataArray` constructors via their `coords` argument. + This will add both the coordinates variables and their index. + """ + + _data: T_DataArray + + __slots__ = ("_data",) + + def __init__(self, dataarray: T_DataArray) -> None: + self._data = dataarray + + @property + def dims(self) -> tuple[Hashable, ...]: + return self._data.dims + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from coordinate names to dtypes. + + Cannot be modified directly, but is updated when adding new variables. + + See Also + -------- + DataArray.dtype + """ + return Frozen({n: v.dtype for n, v in self._data._coords.items()}) + + @property + def _names(self) -> set[Hashable]: + return set(self._data._coords) + + def __getitem__(self, key: Hashable) -> T_DataArray: + return self._data._getitem_coord(key) + + def _update_coords( + self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + ) -> None: + coords_plus_data = coords.copy() + coords_plus_data[_THIS_ARRAY] = self._data.variable + dims = calculate_dimensions(coords_plus_data) + if not set(dims) <= set(self.dims): + raise ValueError( + "cannot add coordinates with new dimensions to a DataArray" + ) + self._data._coords = coords + + # TODO(shoyer): once ._indexes is always populated by a dict, modify + # it to update inplace instead. + original_indexes = dict(self._data.xindexes) + original_indexes.update(indexes) + self._data._indexes = original_indexes + + def _drop_coords(self, coord_names): + # should drop indexed coordinates only + for name in coord_names: + del self._data._coords[name] + del self._data._indexes[name] + + @property + def variables(self): + return Frozen(self._data._coords) + + def to_dataset(self) -> Dataset: + from xarray.core.dataset import Dataset + + coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()} + indexes = dict(self._data.xindexes) + return Dataset._construct_direct(coords, set(coords), indexes=indexes) + + def __delitem__(self, key: Hashable) -> None: + if key not in self: + raise KeyError( + f"{key!r} is not in coordinate variables {tuple(self.keys())}" + ) + assert_no_index_corrupted(self._data.xindexes, {key}) + + del self._data._coords[key] + if self._data._indexes is not None and key in self._data._indexes: + del self._data._indexes[key] + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + return self._data._ipython_key_completions_() + + +def drop_indexed_coords( + coords_to_drop: set[Hashable], coords: Coordinates +) -> Coordinates: + """Drop indexed coordinates associated with coordinates in coords_to_drop. + + This will raise an error in case it corrupts any passed index and its + coordinate variables. + + """ + new_variables = dict(coords.variables) + new_indexes = dict(coords.xindexes) + + for idx, idx_coords in coords.xindexes.group_by_index(): + idx_drop_coords = set(idx_coords) & coords_to_drop + + # special case for pandas multi-index: still allow but deprecate + # dropping only its dimension coordinate. + # TODO: remove when removing PandasMultiIndex's dimension coordinate. + if isinstance(idx, PandasMultiIndex) and idx_drop_coords == {idx.dim}: + idx_drop_coords.update(idx.index.names) + emit_user_level_warning( + f"updating coordinate {idx.dim!r} with a PandasMultiIndex would leave " + f"the multi-index level coordinates {list(idx.index.names)!r} in an inconsistent state. " + f"This will raise an error in the future. Use `.drop_vars({list(idx_coords)!r})` before " + "assigning new coordinate values.", + FutureWarning, + ) + + elif idx_drop_coords and len(idx_drop_coords) != len(idx_coords): + idx_drop_coords_str = ", ".join(f"{k!r}" for k in idx_drop_coords) + idx_coords_str = ", ".join(f"{k!r}" for k in idx_coords) + raise ValueError( + f"cannot drop or update coordinate(s) {idx_drop_coords_str}, which would corrupt " + f"the following index built from coordinates {idx_coords_str}:\n" + f"{idx}" + ) + + for k in idx_drop_coords: + del new_variables[k] + del new_indexes[k] + + return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes) + + +def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None: + """Make sure the dimension coordinate of obj is consistent with coords. + + obj: DataArray or Dataset + coords: Dict-like of variables + """ + for k in obj.dims: + # make sure there are no conflict in dimension coordinates + if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable): + raise IndexError( + f"dimension coordinate {k!r} conflicts between " + f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" + ) + + +def create_coords_with_default_indexes( + coords: Mapping[Any, Any], data_vars: DataVars | None = None +) -> Coordinates: + """Returns a Coordinates object from a mapping of coordinates (arbitrary objects). + + Create default (pandas) indexes for each of the input dimension coordinates. + Extract coordinates from each input DataArray. + + """ + # Note: data_vars is needed here only because a pd.MultiIndex object + # can be promoted as coordinates. + # TODO: It won't be relevant anymore when this behavior will be dropped + # in favor of the more explicit ``Coordinates.from_pandas_multiindex()``. + + from xarray.core.dataarray import DataArray + + all_variables = dict(coords) + if data_vars is not None: + all_variables.update(data_vars) + + indexes: dict[Hashable, Index] = {} + variables: dict[Hashable, Variable] = {} + + # promote any pandas multi-index in data_vars as coordinates + coords_promoted: dict[Hashable, Any] = {} + pd_mindex_keys: list[Hashable] = [] + + for k, v in all_variables.items(): + if isinstance(v, pd.MultiIndex): + coords_promoted[k] = v + pd_mindex_keys.append(k) + elif k in coords: + coords_promoted[k] = v + + if pd_mindex_keys: + pd_mindex_keys_fmt = ",".join([f"'{k}'" for k in pd_mindex_keys]) + emit_user_level_warning( + f"the `pandas.MultiIndex` object(s) passed as {pd_mindex_keys_fmt} coordinate(s) or " + "data variable(s) will no longer be implicitly promoted and wrapped into " + "multiple indexed coordinates in the future " + "(i.e., one coordinate for each multi-index level + one dimension coordinate). " + "If you want to keep this behavior, you need to first wrap it explicitly using " + "`mindex_coords = xarray.Coordinates.from_pandas_multiindex(mindex_obj, 'dim')` " + "and pass it as coordinates, e.g., `xarray.Dataset(coords=mindex_coords)`, " + "`dataset.assign_coords(mindex_coords)` or `dataarray.assign_coords(mindex_coords)`.", + FutureWarning, + ) + + dataarray_coords: list[DataArrayCoordinates] = [] + + for name, obj in coords_promoted.items(): + if isinstance(obj, DataArray): + dataarray_coords.append(obj.coords) + + variable = as_variable(obj, name=name, auto_convert=False) + + if variable.dims == (name,): + # still needed to convert to IndexVariable first due to some + # pandas multi-index edge cases. + variable = variable.to_index_variable() + idx, idx_vars = create_default_index_implicit(variable, all_variables) + indexes.update({k: idx for k in idx_vars}) + variables.update(idx_vars) + all_variables.update(idx_vars) + else: + variables[name] = variable + + new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) + + # extract and merge coordinates and indexes from input DataArrays + if dataarray_coords: + prioritized = {k: (v, indexes.get(k, None)) for k, v in variables.items()} + variables, indexes = merge_coordinates_without_align( + dataarray_coords + [new_coords], + prioritized=prioritized, + ) + new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) + + return new_coords diff --git a/test/fixtures/whole_applications/xarray/xarray/core/dask_array_ops.py b/test/fixtures/whole_applications/xarray/xarray/core/dask_array_ops.py new file mode 100644 index 0000000..98ff900 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/dask_array_ops.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from xarray.core import dtypes, nputils + + +def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): + """Wrapper to apply bottleneck moving window funcs on dask arrays""" + import dask.array as da + + dtype, fill_value = dtypes.maybe_promote(a.dtype) + a = a.astype(dtype) + # inputs for overlap + if axis < 0: + axis = a.ndim + axis + depth = {d: 0 for d in range(a.ndim)} + depth[axis] = (window + 1) // 2 + boundary = {d: fill_value for d in range(a.ndim)} + # Create overlap array. + ag = da.overlap.overlap(a, depth=depth, boundary=boundary) + # apply rolling func + out = da.map_blocks( + moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype + ) + # trim array + result = da.overlap.trim_internal(out, depth) + return result + + +def least_squares(lhs, rhs, rcond=None, skipna=False): + import dask.array as da + + lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1])) + if skipna: + added_dim = rhs.ndim == 1 + if added_dim: + rhs = rhs.reshape(rhs.shape[0], 1) + results = da.apply_along_axis( + nputils._nanpolyfit_1d, + 0, + rhs, + lhs_da, + dtype=float, + shape=(lhs.shape[1] + 1,), + rcond=rcond, + ) + coeffs = results[:-1, ...] + residuals = results[-1, ...] + if added_dim: + coeffs = coeffs.reshape(coeffs.shape[0]) + residuals = residuals.reshape(residuals.shape[0]) + else: + # Residuals here are (1, 1) but should be (K,) as rhs is (N, K) + # See issue dask/dask#6516 + coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs) + return coeffs, residuals + + +def push(array, n, axis): + """ + Dask-aware bottleneck.push + """ + import dask.array as da + import numpy as np + + from xarray.core.duck_array_ops import _push + + def _fill_with_last_one(a, b): + # cumreduction apply the push func over all the blocks first so, the only missing part is filling + # the missing values using the last data of the previous chunk + return np.where(~np.isnan(b), b, a) + + if n is not None and 0 < n < array.shape[axis] - 1: + arange = da.broadcast_to( + da.arange( + array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype + ).reshape( + tuple(size if i == axis else 1 for i, size in enumerate(array.shape)) + ), + array.shape, + array.chunks, + ) + valid_arange = da.where(da.notnull(array), arange, np.nan) + valid_limits = (arange - push(valid_arange, None, axis)) <= n + # omit the forward fill that violate the limit + return da.where(valid_limits, push(array, None, axis), np.nan) + + # The method parameter makes that the tests for python 3.7 fails. + return da.reductions.cumreduction( + func=_push, + binop=_fill_with_last_one, + ident=np.nan, + x=array, + axis=axis, + dtype=array.dtype, + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/dataarray.py b/test/fixtures/whole_applications/xarray/xarray/core/dataarray.py new file mode 100644 index 0000000..16b9330 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/dataarray.py @@ -0,0 +1,7458 @@ +from __future__ import annotations + +import datetime +import warnings +from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence +from functools import partial +from os import PathLike +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NoReturn, + TypeVar, + Union, + overload, +) + +import numpy as np +import pandas as pd + +from xarray.coding.calendar_ops import convert_calendar, interp_calendar +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.core import alignment, computation, dtypes, indexing, ops, utils +from xarray.core._aggregations import DataArrayAggregations +from xarray.core.accessor_dt import CombinedDatetimelikeAccessor +from xarray.core.accessor_str import StringAccessor +from xarray.core.alignment import ( + _broadcast_helper, + _get_broadcast_dims_map_common_coords, + align, +) +from xarray.core.arithmetic import DataArrayArithmetic +from xarray.core.common import AbstractArray, DataWithCoords, get_chunksizes +from xarray.core.computation import unify_chunks +from xarray.core.coordinates import ( + Coordinates, + DataArrayCoordinates, + assert_coordinate_consistent, + create_coords_with_default_indexes, +) +from xarray.core.dataset import Dataset +from xarray.core.formatting import format_item +from xarray.core.indexes import ( + Index, + Indexes, + PandasMultiIndex, + filter_indexes_from_coords, + isel_indexes, +) +from xarray.core.indexing import is_fancy_indexer, map_index_queries +from xarray.core.merge import PANDAS_TYPES, MergeError +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import ( + DaCompatible, + NetcdfWriteModes, + T_DataArray, + T_DataArrayOrSet, + ZarrWriteModes, +) +from xarray.core.utils import ( + Default, + HybridMappingProxy, + ReprObject, + _default, + either_dict_or_kwargs, + hashable, + infix_dims, +) +from xarray.core.variable import ( + IndexVariable, + Variable, + as_compatible_data, + as_variable, +) +from xarray.plot.accessor import DataArrayPlotAccessor +from xarray.plot.utils import _get_units_from_attrs +from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims + +if TYPE_CHECKING: + from dask.dataframe import DataFrame as DaskDataFrame + from dask.delayed import Delayed + from iris.cube import Cube as iris_Cube + from numpy.typing import ArrayLike + + from xarray.backends import ZarrStore + from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes + from xarray.core.groupby import DataArrayGroupBy + from xarray.core.resample import DataArrayResample + from xarray.core.rolling import DataArrayCoarsen, DataArrayRolling + from xarray.core.types import ( + CoarsenBoundaryOptions, + DatetimeLike, + DatetimeUnitOptions, + Dims, + ErrorOptions, + ErrorOptionsWithWarn, + InterpOptions, + PadModeOptions, + PadReflectOptions, + QuantileMethods, + QueryEngineOptions, + QueryParserOptions, + ReindexMethodOptions, + Self, + SideOptions, + T_Chunks, + T_Xarray, + ) + from xarray.core.weighted import DataArrayWeighted + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) + + +def _check_coords_dims(shape, coords, dim): + sizes = dict(zip(dim, shape)) + for k, v in coords.items(): + if any(d not in dim for d in v.dims): + raise ValueError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dim}" + ) + + for d, s in v.sizes.items(): + if s != sizes[d]: + raise ValueError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + + +def _infer_coords_and_dims( + shape: tuple[int, ...], + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + dims: str | Iterable[Hashable] | None, +) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]: + """All the logic for creating a new DataArray""" + + if ( + coords is not None + and not utils.is_dict_like(coords) + and len(coords) != len(shape) + ): + raise ValueError( + f"coords is not dict-like, but it has {len(coords)} items, " + f"which does not match the {len(shape)} dimensions of the " + "data" + ) + + if isinstance(dims, str): + dims = (dims,) + elif dims is None: + dims = [f"dim_{n}" for n in range(len(shape))] + if coords is not None and len(coords) == len(shape): + # try to infer dimensions from coords + if utils.is_dict_like(coords): + dims = list(coords.keys()) + else: + for n, (dim, coord) in enumerate(zip(dims, coords)): + coord = as_variable( + coord, name=dims[n], auto_convert=False + ).to_index_variable() + dims[n] = coord.name + dims_tuple = tuple(dims) + if len(dims_tuple) != len(shape): + raise ValueError( + "different number of dimensions on data " + f"and dims: {len(shape)} vs {len(dims_tuple)}" + ) + for d in dims_tuple: + if not hashable(d): + raise TypeError(f"Dimension {d} is not hashable") + + new_coords: Mapping[Hashable, Any] + + if isinstance(coords, Coordinates): + new_coords = coords + else: + new_coords = {} + if utils.is_dict_like(coords): + for k, v in coords.items(): + new_coords[k] = as_variable(v, name=k, auto_convert=False) + if new_coords[k].dims == (k,): + new_coords[k] = new_coords[k].to_index_variable() + elif coords is not None: + for dim, coord in zip(dims_tuple, coords): + var = as_variable(coord, name=dim, auto_convert=False) + var.dims = (dim,) + new_coords[dim] = var.to_index_variable() + + _check_coords_dims(shape, new_coords, dims_tuple) + + return new_coords, dims_tuple + + +def _check_data_shape( + data: Any, + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + dims: str | Iterable[Hashable] | None, +) -> Any: + if data is dtypes.NA: + data = np.nan + if coords is not None and utils.is_scalar(data, include_0d=False): + if utils.is_dict_like(coords): + if dims is None: + return data + else: + data_shape = tuple( + ( + as_variable(coords[k], k, auto_convert=False).size + if k in coords.keys() + else 1 + ) + for k in dims + ) + else: + data_shape = tuple( + as_variable(coord, "foo", auto_convert=False).size for coord in coords + ) + data = np.full(data_shape, data) + return data + + +class _LocIndexer(Generic[T_DataArray]): + __slots__ = ("data_array",) + + def __init__(self, data_array: T_DataArray): + self.data_array = data_array + + def __getitem__(self, key) -> T_DataArray: + if not utils.is_dict_like(key): + # expand the indexer so we can handle Ellipsis + labels = indexing.expanded_indexer(key, self.data_array.ndim) + key = dict(zip(self.data_array.dims, labels)) + return self.data_array.sel(key) + + def __setitem__(self, key, value) -> None: + if not utils.is_dict_like(key): + # expand the indexer so we can handle Ellipsis + labels = indexing.expanded_indexer(key, self.data_array.ndim) + key = dict(zip(self.data_array.dims, labels)) + + dim_indexers = map_index_queries(self.data_array, key).dim_indexers + self.data_array[dim_indexers] = value + + +# Used as the key corresponding to a DataArray's variable when converting +# arbitrary DataArray objects to datasets +_THIS_ARRAY = ReprObject("") + + +class DataArray( + AbstractArray, + DataWithCoords, + DataArrayArithmetic, + DataArrayAggregations, +): + """N-dimensional array with labeled coordinates and dimensions. + + DataArray provides a wrapper around numpy ndarrays that uses + labeled dimensions and coordinates to support metadata aware + operations. The API is similar to that for the pandas Series or + DataFrame, but DataArray objects can have any number of dimensions, + and their contents have fixed data types. + + Additional features over raw numpy arrays: + + - Apply operations over dimensions by name: ``x.sum('time')``. + - Select or assign values by integer location (like numpy): + ``x[:10]`` or by label (like pandas): ``x.loc['2014-01-01']`` or + ``x.sel(time='2014-01-01')``. + - Mathematical operations (e.g., ``x - y``) vectorize across + multiple dimensions (known in numpy as "broadcasting") based on + dimension names, regardless of their original order. + - Keep track of arbitrary metadata in the form of a Python + dictionary: ``x.attrs`` + - Convert to a pandas Series: ``x.to_series()``. + + Getting items from or doing mathematical operations with a + DataArray always returns another DataArray. + + Parameters + ---------- + data : array_like + Values for this array. Must be an ``numpy.ndarray``, ndarray + like, or castable to an ``ndarray``. If a self-described xarray + or pandas object, attempts are made to use this array's + metadata to fill in other unspecified arguments. A view of the + array's data is used instead of a copy if possible. + coords : sequence or dict of array_like or :py:class:`~xarray.Coordinates`, optional + Coordinates (tick labels) to use for indexing along each + dimension. The following notations are accepted: + + - mapping {dimension name: array-like} + - sequence of tuples that are valid arguments for + ``xarray.Variable()`` + - (dims, data) + - (dims, data, attrs) + - (dims, data, attrs, encoding) + + Additionally, it is possible to define a coord whose name + does not match the dimension name, or a coord based on multiple + dimensions, with one of the following notations: + + - mapping {coord name: DataArray} + - mapping {coord name: Variable} + - mapping {coord name: (dimension name, array-like)} + - mapping {coord name: (tuple of dimension names, array-like)} + + Alternatively, a :py:class:`~xarray.Coordinates` object may be used in + order to explicitly pass indexes (e.g., a multi-index or any custom + Xarray index) or to bypass the creation of a default index for any + :term:`Dimension coordinate` included in that object. + dims : Hashable or sequence of Hashable, optional + Name(s) of the data dimension(s). Must be either a Hashable + (only for 1D data) or a sequence of Hashables with length equal + to the number of dimensions. If this argument is omitted, + dimension names are taken from ``coords`` (if possible) and + otherwise default to ``['dim_0', ... 'dim_n']``. + name : str or None, optional + Name of this array. + attrs : dict_like or None, optional + Attributes to assign to the new instance. By default, an empty + attribute dictionary is initialized. + indexes : py:class:`~xarray.Indexes` or dict-like, optional + For internal use only. For passing indexes objects to the + new DataArray, use the ``coords`` argument instead with a + :py:class:`~xarray.Coordinate` object (both coordinate variables + and indexes will be extracted from the latter). + + Examples + -------- + Create data: + + >>> np.random.seed(0) + >>> temperature = 15 + 8 * np.random.randn(2, 2, 3) + >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] + >>> lat = [[42.25, 42.21], [42.63, 42.59]] + >>> time = pd.date_range("2014-09-06", periods=3) + >>> reference_time = pd.Timestamp("2014-09-05") + + Initialize a dataarray with multiple dimensions: + + >>> da = xr.DataArray( + ... data=temperature, + ... dims=["x", "y", "time"], + ... coords=dict( + ... lon=(["x", "y"], lon), + ... lat=(["x", "y"], lat), + ... time=time, + ... reference_time=reference_time, + ... ), + ... attrs=dict( + ... description="Ambient temperature.", + ... units="degC", + ... ), + ... ) + >>> da + Size: 96B + array([[[29.11241877, 18.20125767, 22.82990387], + [32.92714559, 29.94046392, 7.18177696]], + + [[22.60070734, 13.78914233, 14.17424919], + [18.28478802, 16.15234857, 26.63418806]]]) + Coordinates: + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 + Dimensions without coordinates: x, y + Attributes: + description: Ambient temperature. + units: degC + + Find out where the coldest temperature was: + + >>> da.isel(da.argmin(...)) + Size: 8B + array(7.18177696) + Coordinates: + lon float64 8B -99.32 + lat float64 8B 42.21 + time datetime64[ns] 8B 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 + Attributes: + description: Ambient temperature. + units: degC + """ + + _cache: dict[str, Any] + _coords: dict[Any, Variable] + _close: Callable[[], None] | None + _indexes: dict[Hashable, Index] + _name: Hashable | None + _variable: Variable + + __slots__ = ( + "_cache", + "_coords", + "_close", + "_indexes", + "_name", + "_variable", + "__weakref__", + ) + + dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"]) + + def __init__( + self, + data: Any = dtypes.NA, + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + dims: str | Iterable[Hashable] | None = None, + name: Hashable | None = None, + attrs: Mapping | None = None, + # internal parameters + indexes: Mapping[Any, Index] | None = None, + fastpath: bool = False, + ) -> None: + if fastpath: + variable = data + assert dims is None + assert attrs is None + assert indexes is not None + else: + if indexes is not None: + raise ValueError( + "Explicitly passing indexes via the `indexes` argument is not supported " + "when `fastpath=False`. Use the `coords` argument instead." + ) + + # try to fill in arguments from data if they weren't supplied + if coords is None: + if isinstance(data, DataArray): + coords = data.coords + elif isinstance(data, pd.Series): + coords = [data.index] + elif isinstance(data, pd.DataFrame): + coords = [data.index, data.columns] + elif isinstance(data, (pd.Index, IndexVariable)): + coords = [data] + + if dims is None: + dims = getattr(data, "dims", getattr(coords, "dims", None)) + if name is None: + name = getattr(data, "name", None) + if attrs is None and not isinstance(data, PANDAS_TYPES): + attrs = getattr(data, "attrs", None) + + data = _check_data_shape(data, coords, dims) + data = as_compatible_data(data) + coords, dims = _infer_coords_and_dims(data.shape, coords, dims) + variable = Variable(dims, data, attrs, fastpath=True) + + if not isinstance(coords, Coordinates): + coords = create_coords_with_default_indexes(coords) + indexes = dict(coords.xindexes) + coords = {k: v.copy() for k, v in coords.variables.items()} + + # These fully describe a DataArray + self._variable = variable + assert isinstance(coords, dict) + self._coords = coords + self._name = name + self._indexes = indexes # type: ignore[assignment] + + self._close = None + + @classmethod + def _construct_direct( + cls, + variable: Variable, + coords: dict[Any, Variable], + name: Hashable, + indexes: dict[Hashable, Index], + ) -> Self: + """Shortcut around __init__ for internal use when we want to skip + costly validation + """ + obj = object.__new__(cls) + obj._variable = variable + obj._coords = coords + obj._name = name + obj._indexes = indexes + obj._close = None + return obj + + def _replace( + self, + variable: Variable | None = None, + coords=None, + name: Hashable | None | Default = _default, + indexes=None, + ) -> Self: + if variable is None: + variable = self.variable + if coords is None: + coords = self._coords + if indexes is None: + indexes = self._indexes + if name is _default: + name = self.name + return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) + + def _replace_maybe_drop_dims( + self, + variable: Variable, + name: Hashable | None | Default = _default, + ) -> Self: + if variable.dims == self.dims and variable.shape == self.shape: + coords = self._coords.copy() + indexes = self._indexes + elif variable.dims == self.dims: + # Shape has changed (e.g. from reduce(..., keepdims=True) + new_sizes = dict(zip(self.dims, variable.shape)) + coords = { + k: v + for k, v in self._coords.items() + if v.shape == tuple(new_sizes[d] for d in v.dims) + } + indexes = filter_indexes_from_coords(self._indexes, set(coords)) + else: + allowed_dims = set(variable.dims) + coords = { + k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims + } + indexes = filter_indexes_from_coords(self._indexes, set(coords)) + return self._replace(variable, coords, name, indexes=indexes) + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + drop_coords: list[Hashable] | None = None, + rename_dims: Mapping[Any, Any] | None = None, + ) -> Self: + """Maybe replace indexes and their corresponding coordinates.""" + if not indexes: + return self + + if variables is None: + variables = {} + if drop_coords is None: + drop_coords = [] + + new_variable = self.variable.copy() + new_coords = self._coords.copy() + new_indexes = dict(self._indexes) + + for name in indexes: + new_coords[name] = variables[name] + new_indexes[name] = indexes[name] + + for name in drop_coords: + new_coords.pop(name) + new_indexes.pop(name) + + if rename_dims: + new_variable.dims = tuple(rename_dims.get(d, d) for d in new_variable.dims) + + return self._replace( + variable=new_variable, coords=new_coords, indexes=new_indexes + ) + + def _to_temp_dataset(self) -> Dataset: + return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) + + def _from_temp_dataset( + self, dataset: Dataset, name: Hashable | None | Default = _default + ) -> Self: + variable = dataset._variables.pop(_THIS_ARRAY) + coords = dataset._variables + indexes = dataset._indexes + return self._replace(variable, coords, name, indexes=indexes) + + def _to_dataset_split(self, dim: Hashable) -> Dataset: + """splits dataarray along dimension 'dim'""" + + def subset(dim, label): + array = self.loc[{dim: label}] + array.attrs = {} + return as_variable(array) + + variables_from_split = { + label: subset(dim, label) for label in self.get_index(dim) + } + coord_names = set(self._coords) - {dim} + + ambiguous_vars = set(variables_from_split) & coord_names + if ambiguous_vars: + rename_msg_fmt = ", ".join([f"{v}=..." for v in sorted(ambiguous_vars)]) + raise ValueError( + f"Splitting along the dimension {dim!r} would produce the variables " + f"{tuple(sorted(ambiguous_vars))} which are also existing coordinate " + f"variables. Use DataArray.rename({rename_msg_fmt}) or " + f"DataArray.assign_coords({dim}=...) to resolve this ambiguity." + ) + + variables = variables_from_split | { + k: v for k, v in self._coords.items() if k != dim + } + indexes = filter_indexes_from_coords(self._indexes, coord_names) + dataset = Dataset._construct_direct( + variables, coord_names, indexes=indexes, attrs=self.attrs + ) + return dataset + + def _to_dataset_whole( + self, name: Hashable = None, shallow_copy: bool = True + ) -> Dataset: + if name is None: + name = self.name + if name is None: + raise ValueError( + "unable to convert unnamed DataArray to a " + "Dataset without providing an explicit name" + ) + if name in self.coords: + raise ValueError( + "cannot create a Dataset from a DataArray with " + "the same name as one of its coordinates" + ) + # use private APIs for speed: this is called by _to_temp_dataset(), + # which is used in the guts of a lot of operations (e.g., reindex) + variables = self._coords.copy() + variables[name] = self.variable + if shallow_copy: + for k in variables: + variables[k] = variables[k].copy(deep=False) + indexes = self._indexes + + coord_names = set(self._coords) + return Dataset._construct_direct(variables, coord_names, indexes=indexes) + + def to_dataset( + self, + dim: Hashable = None, + *, + name: Hashable = None, + promote_attrs: bool = False, + ) -> Dataset: + """Convert a DataArray to a Dataset. + + Parameters + ---------- + dim : Hashable, optional + Name of the dimension on this array along which to split this array + into separate variables. If not provided, this array is converted + into a Dataset of one variable. + name : Hashable, optional + Name to substitute for this array's name. Only valid if ``dim`` is + not provided. + promote_attrs : bool, default: False + Set to True to shallow copy attrs of DataArray to returned Dataset. + + Returns + ------- + dataset : Dataset + """ + if dim is not None and dim not in self.dims: + raise TypeError( + f"{dim} is not a dim. If supplying a ``name``, pass as a kwarg." + ) + + if dim is not None: + if name is not None: + raise TypeError("cannot supply both dim and name arguments") + result = self._to_dataset_split(dim) + else: + result = self._to_dataset_whole(name) + + if promote_attrs: + result.attrs = dict(self.attrs) + + return result + + @property + def name(self) -> Hashable | None: + """The name of this array.""" + return self._name + + @name.setter + def name(self, value: Hashable | None) -> None: + self._name = value + + @property + def variable(self) -> Variable: + """Low level interface to the Variable object for this DataArray.""" + return self._variable + + @property + def dtype(self) -> np.dtype: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self.variable.dtype + + @property + def shape(self) -> tuple[int, ...]: + """ + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self.variable.shape + + @property + def size(self) -> int: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return self.variable.size + + @property + def nbytes(self) -> int: + """ + Total bytes consumed by the elements of this DataArray's data. + + If the underlying data array does not include ``nbytes``, estimates + the bytes consumed based on the ``size`` and ``dtype``. + """ + return self.variable.nbytes + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return self.variable.ndim + + def __len__(self) -> int: + return len(self.variable) + + @property + def data(self) -> Any: + """ + The DataArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + See Also + -------- + DataArray.to_numpy + DataArray.as_numpy + DataArray.values + """ + return self.variable.data + + @data.setter + def data(self, value: Any) -> None: + self.variable.data = value + + @property + def values(self) -> np.ndarray: + """ + The array's data converted to numpy.ndarray. + + This will attempt to convert the array naively using np.array(), + which will raise an error if the array type does not support + coercion like this (e.g. cupy). + + Note that this array is not copied; operations on it follow + numpy's rules of what generates a view vs. a copy, and changes + to this array may be reflected in the DataArray as well. + """ + return self.variable.values + + @values.setter + def values(self, value: Any) -> None: + self.variable.values = value + + def to_numpy(self) -> np.ndarray: + """ + Coerces wrapped data to numpy and returns a numpy.ndarray. + + See Also + -------- + DataArray.as_numpy : Same but returns the surrounding DataArray instead. + Dataset.as_numpy + DataArray.values + DataArray.data + """ + return self.variable.to_numpy() + + def as_numpy(self) -> Self: + """ + Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. + + See Also + -------- + DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object. + Dataset.as_numpy : Converts all variables in a Dataset. + DataArray.values + DataArray.data + """ + coords = {k: v.as_numpy() for k, v in self._coords.items()} + return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) + + @property + def _in_memory(self) -> bool: + return self.variable._in_memory + + def _to_index(self) -> pd.Index: + return self.variable._to_index() + + def to_index(self) -> pd.Index: + """Convert this variable to a pandas.Index. Only possible for 1D + arrays. + """ + return self.variable.to_index() + + @property + def dims(self) -> tuple[Hashable, ...]: + """Tuple of dimension names associated with this array. + + Note that the type of this property is inconsistent with + `Dataset.dims`. See `Dataset.sizes` and `DataArray.sizes` for + consistently named properties. + + See Also + -------- + DataArray.sizes + Dataset.dims + """ + return self.variable.dims + + @dims.setter + def dims(self, value: Any) -> NoReturn: + raise AttributeError( + "you cannot assign dims on a DataArray. Use " + ".rename() or .swap_dims() instead." + ) + + def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: + if utils.is_dict_like(key): + return key + key = indexing.expanded_indexer(key, self.ndim) + return dict(zip(self.dims, key)) + + def _getitem_coord(self, key: Any) -> Self: + from xarray.core.dataset import _get_virtual_variable + + try: + var = self._coords[key] + except KeyError: + dim_sizes = dict(zip(self.dims, self.shape)) + _, key, var = _get_virtual_variable(self._coords, key, dim_sizes) + + return self._replace_maybe_drop_dims(var, name=key) + + def __getitem__(self, key: Any) -> Self: + if isinstance(key, str): + return self._getitem_coord(key) + else: + # xarray-style array indexing + return self.isel(indexers=self._item_key_to_dict(key)) + + def __setitem__(self, key: Any, value: Any) -> None: + if isinstance(key, str): + self.coords[key] = value + else: + # Coordinates in key, value and self[key] should be consistent. + # TODO Coordinate consistency in key is checked here, but it + # causes unnecessary indexing. It should be optimized. + obj = self[key] + if isinstance(value, DataArray): + assert_coordinate_consistent(value, obj.coords.variables) + value = value.variable + # DataArray key -> Variable key + key = { + k: v.variable if isinstance(v, DataArray) else v + for k, v in self._item_key_to_dict(key).items() + } + self.variable[key] = value + + def __delitem__(self, key: Any) -> None: + del self.coords[key] + + @property + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from self._item_sources + yield self.attrs + + @property + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-completion""" + yield HybridMappingProxy(keys=self._coords, mapping=self.coords) + + # virtual coordinates + # uses empty dict -- everything here can already be found in self.coords. + yield HybridMappingProxy(keys=self.dims, mapping={}) + + def __contains__(self, key: Any) -> bool: + return key in self.data + + @property + def loc(self) -> _LocIndexer: + """Attribute for location based indexing like pandas.""" + return _LocIndexer(self) + + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary storing arbitrary metadata with this array.""" + return self.variable.attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self.variable.attrs = dict(value) + + @property + def encoding(self) -> dict[Any, Any]: + """Dictionary of format-specific settings for how this array should be + serialized.""" + return self.variable.encoding + + @encoding.setter + def encoding(self, value: Mapping[Any, Any]) -> None: + self.variable.encoding = dict(value) + + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: + """Return a new DataArray without encoding on the array or any attached + coords.""" + ds = self._to_temp_dataset().drop_encoding() + return self._from_temp_dataset(ds) + + @property + def indexes(self) -> Indexes: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Dataset has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + DataArray.xindexes + + """ + return self.xindexes.to_pandas_indexes() + + @property + def xindexes(self) -> Indexes: + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ + return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) + + @property + def coords(self) -> DataArrayCoordinates: + """Mapping of :py:class:`~xarray.DataArray` objects corresponding to + coordinate variables. + + See Also + -------- + Coordinates + """ + return DataArrayCoordinates(self) + + @overload + def reset_coords( + self, + names: Dims = None, + *, + drop: Literal[False] = False, + ) -> Dataset: ... + + @overload + def reset_coords( + self, + names: Dims = None, + *, + drop: Literal[True], + ) -> Self: ... + + @_deprecate_positional_args("v2023.10.0") + def reset_coords( + self, + names: Dims = None, + *, + drop: bool = False, + ) -> Self | Dataset: + """Given names of coordinates, reset them to become variables. + + Parameters + ---------- + names : str, Iterable of Hashable or None, optional + Name(s) of non-index coordinates in this dataset to reset into + variables. By default, all non-index coordinates are reset. + drop : bool, default: False + If True, remove coordinates instead of converting them into + variables. + + Returns + ------- + Dataset, or DataArray if ``drop == True`` + + Examples + -------- + >>> temperature = np.arange(25).reshape(5, 5) + >>> pressure = np.arange(50, 75).reshape(5, 5) + >>> da = xr.DataArray( + ... data=temperature, + ... dims=["x", "y"], + ... coords=dict( + ... lon=("x", np.arange(10, 15)), + ... lat=("y", np.arange(20, 25)), + ... Pressure=(["x", "y"], pressure), + ... ), + ... name="Temperature", + ... ) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + lon (x) int64 40B 10 11 12 13 14 + lat (y) int64 40B 20 21 22 23 24 + Pressure (x, y) int64 200B 50 51 52 53 54 55 56 57 ... 68 69 70 71 72 73 74 + Dimensions without coordinates: x, y + + Return Dataset with target coordinate as a data variable rather than a coordinate variable: + + >>> da.reset_coords(names="Pressure") + Size: 480B + Dimensions: (x: 5, y: 5) + Coordinates: + lon (x) int64 40B 10 11 12 13 14 + lat (y) int64 40B 20 21 22 23 24 + Dimensions without coordinates: x, y + Data variables: + Pressure (x, y) int64 200B 50 51 52 53 54 55 56 ... 68 69 70 71 72 73 74 + Temperature (x, y) int64 200B 0 1 2 3 4 5 6 7 8 ... 17 18 19 20 21 22 23 24 + + Return DataArray without targeted coordinate: + + >>> da.reset_coords(names="Pressure", drop=True) + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + lon (x) int64 40B 10 11 12 13 14 + lat (y) int64 40B 20 21 22 23 24 + Dimensions without coordinates: x, y + """ + if names is None: + names = set(self.coords) - set(self._indexes) + dataset = self.coords.to_dataset().reset_coords(names, drop) + if drop: + return self._replace(coords=dataset._variables) + if self.name is None: + raise ValueError( + "cannot reset_coords with drop=False on an unnamed DataArrray" + ) + dataset[self.name] = self.variable + return dataset + + def __dask_tokenize__(self) -> object: + from dask.base import normalize_token + + return normalize_token((type(self), self._variable, self._coords, self._name)) + + def __dask_graph__(self): + return self._to_temp_dataset().__dask_graph__() + + def __dask_keys__(self): + return self._to_temp_dataset().__dask_keys__() + + def __dask_layers__(self): + return self._to_temp_dataset().__dask_layers__() + + @property + def __dask_optimize__(self): + return self._to_temp_dataset().__dask_optimize__ + + @property + def __dask_scheduler__(self): + return self._to_temp_dataset().__dask_scheduler__ + + def __dask_postcompute__(self): + func, args = self._to_temp_dataset().__dask_postcompute__() + return self._dask_finalize, (self.name, func) + args + + def __dask_postpersist__(self): + func, args = self._to_temp_dataset().__dask_postpersist__() + return self._dask_finalize, (self.name, func) + args + + @classmethod + def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self: + ds = func(results, *args, **kwargs) + variable = ds._variables.pop(_THIS_ARRAY) + coords = ds._variables + indexes = ds._indexes + return cls(variable, coords, name=name, indexes=indexes, fastpath=True) + + def load(self, **kwargs) -> Self: + """Manually trigger loading of this array's data from disk or a + remote source into memory and return this array. + + Unlike compute, the original dataset is modified and returned. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + dask.compute + """ + ds = self._to_temp_dataset().load(**kwargs) + new = self._from_temp_dataset(ds) + self._variable = new._variable + self._coords = new._coords + return self + + def compute(self, **kwargs) -> Self: + """Manually trigger loading of this array's data from disk or a + remote source into memory and return a new array. + + Unlike load, the original is left unaltered. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + Returns + ------- + object : DataArray + New object with the data and all coordinates as in-memory arrays. + + See Also + -------- + dask.compute + """ + new = self.copy(deep=False) + return new.load(**kwargs) + + def persist(self, **kwargs) -> Self: + """Trigger computation in constituent dask arrays + + This keeps them as dask arrays but encourages them to keep data in + memory. This is particularly useful when on a distributed machine. + When on a single machine consider using ``.compute()`` instead. + Like compute (but unlike load), the original dataset is left unaltered. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.persist``. + + Returns + ------- + object : DataArray + New object with all dask-backed data and coordinates as persisted dask arrays. + + See Also + -------- + dask.persist + """ + ds = self._to_temp_dataset().persist(**kwargs) + return self._from_temp_dataset(ds) + + def copy(self, deep: bool = True, data: Any = None) -> Self: + """Returns a copy of this array. + + If `deep=True`, a deep copy is made of the data array. + Otherwise, a shallow copy is made, and the returned data array's + values are a new view of this data array's values. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether the data array and its coordinates are loaded into memory + and copied onto the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored for all data variables, + and only used for coords. + + Returns + ------- + copy : DataArray + New object with dimensions, attributes, coordinates, name, + encoding, and optionally data copied from original. + + Examples + -------- + Shallow versus deep copy + + >>> array = xr.DataArray([1, 2, 3], dims="x", coords={"x": ["a", "b", "c"]}) + >>> array.copy() + Size: 24B + array([1, 2, 3]) + Coordinates: + * x (x) >> array_0 = array.copy(deep=False) + >>> array_0[0] = 7 + >>> array_0 + Size: 24B + array([7, 2, 3]) + Coordinates: + * x (x) >> array + Size: 24B + array([7, 2, 3]) + Coordinates: + * x (x) >> array.copy(data=[0.1, 0.2, 0.3]) + Size: 24B + array([0.1, 0.2, 0.3]) + Coordinates: + * x (x) >> array + Size: 24B + array([7, 2, 3]) + Coordinates: + * x (x) Self: + variable = self.variable._copy(deep=deep, data=data, memo=memo) + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + coords = {} + for k, v in self._coords.items(): + if k in index_vars: + coords[k] = index_vars[k] + else: + coords[k] = v._copy(deep=deep, memo=memo) + + return self._replace(variable, coords, indexes=indexes) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + # mutable objects should not be Hashable + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore[assignment] + + @property + def chunks(self) -> tuple[tuple[int, ...], ...] | None: + """ + Tuple of block lengths for this dataarray's data, in order of dimensions, or None if + the underlying data is not a dask array. + + See Also + -------- + DataArray.chunk + DataArray.chunksizes + xarray.unify_chunks + """ + return self.variable.chunks + + @property + def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: + """ + Mapping from dimension names to block lengths for this dataarray's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Differs from DataArray.chunks because it returns a mapping of dimensions to chunk shapes + instead of a tuple of chunk shapes. + + See Also + -------- + DataArray.chunk + DataArray.chunks + xarray.unify_chunks + """ + all_variables = [self.variable] + [c.variable for c in self.coords.values()] + return get_chunksizes(all_variables) + + @_deprecate_positional_args("v2023.10.0") + def chunk( + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + *, + name_prefix: str = "xarray-", + token: str | None = None, + lock: bool = False, + inline_array: bool = False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + **chunks_kwargs: Any, + ) -> Self: + """Coerce this array's data into a dask arrays with the given chunks. + + If this variable is a non-dask array, it will be converted to dask + array. If it's a dask array, it will be rechunked to the given chunk + sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Parameters + ---------- + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + name_prefix : str, optional + Prefix for the name of the new dask array. + token : str, optional + Token uniquely identifying this array. + lock : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + inline_array: bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided. + + Returns + ------- + chunked : xarray.DataArray + + See Also + -------- + DataArray.chunks + DataArray.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + if chunks is None: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=FutureWarning, + ) + chunks = {} + + if isinstance(chunks, (float, str, int)): + # ignoring type; unclear why it won't accept a Literal into the value. + chunks = dict.fromkeys(self.dims, chunks) + elif isinstance(chunks, (tuple, list)): + utils.emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimension names as keys.", + category=DeprecationWarning, + ) + chunks = dict(zip(self.dims, chunks)) + else: + chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + + ds = self._to_temp_dataset().chunk( + chunks, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) + return self._from_temp_dataset(ds) + + def isel( + self, + indexers: Mapping[Any, Any] | None = None, + drop: bool = False, + missing_dims: ErrorOptionsWithWarn = "raise", + **indexers_kwargs: Any, + ) -> Self: + """Return a new DataArray whose data is given by selecting indexes + along the specified dimension(s). + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by integers, slice objects or arrays. + indexer can be a integer, slice, array-like or DataArray. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + drop : bool, default: False + If ``drop=True``, drop coordinates variables indexed by integers + instead of making them scalar. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + DataArray: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + + Returns + ------- + indexed : xarray.DataArray + + See Also + -------- + Dataset.isel + DataArray.sel + + :doc:`xarray-tutorial:intermediate/indexing/indexing` + Tutorial material on indexing with Xarray objects + + :doc:`xarray-tutorial:fundamentals/02.1_indexing_Basic` + Tutorial material on basics of indexing + + Examples + -------- + >>> da = xr.DataArray(np.arange(25).reshape(5, 5), dims=("x", "y")) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: x, y + + >>> tgt_x = xr.DataArray(np.arange(0, 5), dims="points") + >>> tgt_y = xr.DataArray(np.arange(0, 5), dims="points") + >>> da = da.isel(x=tgt_x, y=tgt_y) + >>> da + Size: 40B + array([ 0, 6, 12, 18, 24]) + Dimensions without coordinates: points + """ + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") + + if any(is_fancy_indexer(idx) for idx in indexers.values()): + ds = self._to_temp_dataset()._isel_fancy( + indexers, drop=drop, missing_dims=missing_dims + ) + return self._from_temp_dataset(ds) + + # Much faster algorithm for when all indexers are ints, slices, one-dimensional + # lists, or zero or one-dimensional np.ndarray's + + variable = self._variable.isel(indexers, missing_dims=missing_dims) + indexes, index_variables = isel_indexes(self.xindexes, indexers) + + coords = {} + for coord_name, coord_value in self._coords.items(): + if coord_name in index_variables: + coord_value = index_variables[coord_name] + else: + coord_indexers = { + k: v for k, v in indexers.items() if k in coord_value.dims + } + if coord_indexers: + coord_value = coord_value.isel(coord_indexers) + if drop and coord_value.ndim == 0: + continue + coords[coord_name] = coord_value + + return self._replace(variable=variable, coords=coords, indexes=indexes) + + def sel( + self, + indexers: Mapping[Any, Any] | None = None, + method: str | None = None, + tolerance=None, + drop: bool = False, + **indexers_kwargs: Any, + ) -> Self: + """Return a new DataArray whose data is given by selecting index + labels along the specified dimension(s). + + In contrast to `DataArray.isel`, indexers for this method should use + labels instead of integers. + + Under the hood, this method is powered by using pandas's powerful Index + objects. This makes label based indexing essentially just as fast as + using integer indexing. + + It also means this method uses pandas's (well documented) logic for + indexing. This means you can use string shortcuts for datetime indexes + (e.g., '2000-01' to select all values in January 2000). It also means + that slices are treated as inclusive of both the start and stop values, + unlike normal Python indexing. + + .. warning:: + + Do not try to assign values when using any of the indexing methods + ``isel`` or ``sel``:: + + da = xr.DataArray([0, 1, 2, 3], dims=['x']) + # DO NOT do this + da.isel(x=[0, 1, 2])[1] = -1 + + Assigning values with the chained indexing using ``.sel`` or + ``.isel`` fails silently. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by scalars, slices or arrays of tick labels. For dimensions with + multi-index, the indexer may also be a dict-like object with keys + matching index level names. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method to use for inexact matches: + + - None (default): only exact matches + - pad / ffill: propagate last valid index value forward + - backfill / bfill: propagate next valid index value backward + - nearest: use nearest valid index value + + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + drop : bool, optional + If ``drop=True``, drop coordinates variables in `indexers` instead + of making them scalar. + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : DataArray + A new DataArray with the same contents as this DataArray, except the + data and each dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this DataArray, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + + See Also + -------- + Dataset.sel + DataArray.isel + + :doc:`xarray-tutorial:intermediate/indexing/indexing` + Tutorial material on indexing with Xarray objects + + :doc:`xarray-tutorial:fundamentals/02.1_indexing_Basic` + Tutorial material on basics of indexing + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(25).reshape(5, 5), + ... coords={"x": np.arange(5), "y": np.arange(5)}, + ... dims=("x", "y"), + ... ) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + * y (y) int64 40B 0 1 2 3 4 + + >>> tgt_x = xr.DataArray(np.linspace(0, 4, num=5), dims="points") + >>> tgt_y = xr.DataArray(np.linspace(0, 4, num=5), dims="points") + >>> da = da.sel(x=tgt_x, y=tgt_y, method="nearest") + >>> da + Size: 40B + array([ 0, 6, 12, 18, 24]) + Coordinates: + x (points) int64 40B 0 1 2 3 4 + y (points) int64 40B 0 1 2 3 4 + Dimensions without coordinates: points + """ + ds = self._to_temp_dataset().sel( + indexers=indexers, + drop=drop, + method=method, + tolerance=tolerance, + **indexers_kwargs, + ) + return self._from_temp_dataset(ds) + + def head( + self, + indexers: Mapping[Any, int] | int | None = None, + **indexers_kwargs: Any, + ) -> Self: + """Return a new DataArray whose data is given by the the first `n` + values along the specified dimension(s). Default `n` = 5 + + See Also + -------- + Dataset.head + DataArray.tail + DataArray.thin + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(25).reshape(5, 5), + ... dims=("x", "y"), + ... ) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: x, y + + >>> da.head(x=1) + Size: 40B + array([[0, 1, 2, 3, 4]]) + Dimensions without coordinates: x, y + + >>> da.head({"x": 2, "y": 2}) + Size: 32B + array([[0, 1], + [5, 6]]) + Dimensions without coordinates: x, y + """ + ds = self._to_temp_dataset().head(indexers, **indexers_kwargs) + return self._from_temp_dataset(ds) + + def tail( + self, + indexers: Mapping[Any, int] | int | None = None, + **indexers_kwargs: Any, + ) -> Self: + """Return a new DataArray whose data is given by the the last `n` + values along the specified dimension(s). Default `n` = 5 + + See Also + -------- + Dataset.tail + DataArray.head + DataArray.thin + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(25).reshape(5, 5), + ... dims=("x", "y"), + ... ) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: x, y + + >>> da.tail(y=1) + Size: 40B + array([[ 4], + [ 9], + [14], + [19], + [24]]) + Dimensions without coordinates: x, y + + >>> da.tail({"x": 2, "y": 2}) + Size: 32B + array([[18, 19], + [23, 24]]) + Dimensions without coordinates: x, y + """ + ds = self._to_temp_dataset().tail(indexers, **indexers_kwargs) + return self._from_temp_dataset(ds) + + def thin( + self, + indexers: Mapping[Any, int] | int | None = None, + **indexers_kwargs: Any, + ) -> Self: + """Return a new DataArray whose data is given by each `n` value + along the specified dimension(s). + + Examples + -------- + >>> x_arr = np.arange(0, 26) + >>> x_arr + array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25]) + >>> x = xr.DataArray( + ... np.reshape(x_arr, (2, 13)), + ... dims=("x", "y"), + ... coords={"x": [0, 1], "y": np.arange(0, 13)}, + ... ) + >>> x + Size: 208B + array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]]) + Coordinates: + * x (x) int64 16B 0 1 + * y (y) int64 104B 0 1 2 3 4 5 6 7 8 9 10 11 12 + + >>> + >>> x.thin(3) + Size: 40B + array([[ 0, 3, 6, 9, 12]]) + Coordinates: + * x (x) int64 8B 0 + * y (y) int64 40B 0 3 6 9 12 + >>> x.thin({"x": 2, "y": 5}) + Size: 24B + array([[ 0, 5, 10]]) + Coordinates: + * x (x) int64 8B 0 + * y (y) int64 24B 0 5 10 + + See Also + -------- + Dataset.thin + DataArray.head + DataArray.tail + """ + ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs) + return self._from_temp_dataset(ds) + + @_deprecate_positional_args("v2023.10.0") + def broadcast_like( + self, + other: T_DataArrayOrSet, + *, + exclude: Iterable[Hashable] | None = None, + ) -> Self: + """Broadcast this DataArray against another Dataset or DataArray. + + This is equivalent to xr.broadcast(other, self)[1] + + xarray objects are broadcast against each other in arithmetic + operations, so this method is not be necessary for most uses. + + If no change is needed, the input data is returned to the output + without being copied. + + If new coords are added by the broadcast, their values are + NaN filled. + + Parameters + ---------- + other : Dataset or DataArray + Object against which to broadcast this array. + exclude : iterable of Hashable, optional + Dimensions that must not be broadcasted + + Returns + ------- + new_da : DataArray + The caller broadcasted against ``other``. + + Examples + -------- + >>> arr1 = xr.DataArray( + ... np.random.randn(2, 3), + ... dims=("x", "y"), + ... coords={"x": ["a", "b"], "y": ["a", "b", "c"]}, + ... ) + >>> arr2 = xr.DataArray( + ... np.random.randn(3, 2), + ... dims=("x", "y"), + ... coords={"x": ["a", "b", "c"], "y": ["a", "b"]}, + ... ) + >>> arr1 + Size: 48B + array([[ 1.76405235, 0.40015721, 0.97873798], + [ 2.2408932 , 1.86755799, -0.97727788]]) + Coordinates: + * x (x) >> arr2 + Size: 48B + array([[ 0.95008842, -0.15135721], + [-0.10321885, 0.4105985 ], + [ 0.14404357, 1.45427351]]) + Coordinates: + * x (x) >> arr1.broadcast_like(arr2) + Size: 72B + array([[ 1.76405235, 0.40015721, 0.97873798], + [ 2.2408932 , 1.86755799, -0.97727788], + [ nan, nan, nan]]) + Coordinates: + * x (x) Self: + """Callback called from ``Aligner`` to create a new reindexed DataArray.""" + + if isinstance(fill_value, dict): + fill_value = fill_value.copy() + sentinel = object() + value = fill_value.pop(self.name, sentinel) + if value is not sentinel: + fill_value[_THIS_ARRAY] = value + + ds = self._to_temp_dataset() + reindexed = ds._reindex_callback( + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, + ) + + da = self._from_temp_dataset(reindexed) + da.encoding = self.encoding + + return da + + @_deprecate_positional_args("v2023.10.0") + def reindex_like( + self, + other: T_DataArrayOrSet, + *, + method: ReindexMethodOptions = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value=dtypes.NA, + ) -> Self: + """ + Conform this object onto the indexes of another object, for indexes which the + objects share. Missing values are filled with ``fill_value``. The default fill + value is NaN. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to pandas.Index objects, which provides coordinates upon + which to index the variables in this dataset. The indexes on this + other object need not be the same as the indexes on this + dataset. Any mis-matched index values will be filled in with + NaN, and any mis-matched dimension names will simply be ignored. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method to use for filling index values from other not found on this + data array: + + - None (default): don't fill gaps + - pad / ffill: propagate last valid index value forward + - backfill / bfill: propagate next valid index value backward + - nearest: use nearest valid index value + + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like must be the same size as the index and its dtype + must exactly match the index’s type. + copy : bool, default: True + If ``copy=True``, data in the return value is always copied. If + ``copy=False`` and reindexing is unnecessary, or can be performed + with only slice operations, then the output may share memory with + the input. In either case, a new xarray object is always returned. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names (including coordinates) to fill values. Use this + data array's name to refer to the data array's values. + + Returns + ------- + reindexed : DataArray + Another dataset array, with this array's data but coordinates from + the other object. + + Examples + -------- + >>> data = np.arange(12).reshape(4, 3) + >>> da1 = xr.DataArray( + ... data=data, + ... dims=["x", "y"], + ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, + ... ) + >>> da1 + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 + >>> da2 = xr.DataArray( + ... data=data, + ... dims=["x", "y"], + ... coords={"x": [40, 30, 20, 10], "y": [90, 80, 70]}, + ... ) + >>> da2 + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) int64 32B 40 30 20 10 + * y (y) int64 24B 90 80 70 + + Reindexing with both DataArrays having the same coordinates set, but in different order: + + >>> da1.reindex_like(da2) + Size: 96B + array([[11, 10, 9], + [ 8, 7, 6], + [ 5, 4, 3], + [ 2, 1, 0]]) + Coordinates: + * x (x) int64 32B 40 30 20 10 + * y (y) int64 24B 90 80 70 + + Reindexing with the other array having additional coordinates: + + >>> da3 = xr.DataArray( + ... data=data, + ... dims=["x", "y"], + ... coords={"x": [20, 10, 29, 39], "y": [70, 80, 90]}, + ... ) + >>> da1.reindex_like(da3) + Size: 96B + array([[ 3., 4., 5.], + [ 0., 1., 2.], + [nan, nan, nan], + [nan, nan, nan]]) + Coordinates: + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 + + Filling missing values with the previous valid index with respect to the coordinates' value: + + >>> da1.reindex_like(da3, method="ffill") + Size: 96B + array([[3, 4, 5], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + Coordinates: + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 + + Filling missing values while tolerating specified error for inexact matches: + + >>> da1.reindex_like(da3, method="ffill", tolerance=5) + Size: 96B + array([[ 3., 4., 5.], + [ 0., 1., 2.], + [nan, nan, nan], + [nan, nan, nan]]) + Coordinates: + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 + + Filling missing values with manually specified values: + + >>> da1.reindex_like(da3, fill_value=19) + Size: 96B + array([[ 3, 4, 5], + [ 0, 1, 2], + [19, 19, 19], + [19, 19, 19]]) + Coordinates: + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 + + Note that unlike ``broadcast_like``, ``reindex_like`` doesn't create new dimensions: + + >>> da1.sel(x=20) + Size: 24B + array([3, 4, 5]) + Coordinates: + x int64 8B 20 + * y (y) int64 24B 70 80 90 + + ...so ``b`` in not added here: + + >>> da1.sel(x=20).reindex_like(da1) + Size: 24B + array([3, 4, 5]) + Coordinates: + x int64 8B 20 + * y (y) int64 24B 70 80 90 + + See Also + -------- + DataArray.reindex + DataArray.broadcast_like + align + """ + return alignment.reindex_like( + self, + other=other, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) + + @_deprecate_positional_args("v2023.10.0") + def reindex( + self, + indexers: Mapping[Any, Any] | None = None, + *, + method: ReindexMethodOptions = None, + tolerance: float | Iterable[float] | None = None, + copy: bool = True, + fill_value=dtypes.NA, + **indexers_kwargs: Any, + ) -> Self: + """Conform this object onto the indexes of another object, filling in + missing values with ``fill_value``. The default fill value is NaN. + + Parameters + ---------- + indexers : dict, optional + Dictionary with keys given by dimension names and values given by + arrays of coordinates tick labels. Any mis-matched coordinate + values will be filled in with NaN, and any mis-matched dimension + names will simply be ignored. + One of indexers or indexers_kwargs must be provided. + copy : bool, optional + If ``copy=True``, data in the return value is always copied. If + ``copy=False`` and reindexing is unnecessary, or can be performed + with only slice operations, then the output may share memory with + the input. In either case, a new xarray object is always returned. + method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional + Method to use for filling index values in ``indexers`` not found on + this data array: + + - None (default): don't fill gaps + - pad / ffill: propagate last valid index value forward + - backfill / bfill: propagate next valid index value backward + - nearest: use nearest valid index value + + tolerance : float | Iterable[float] | None, default: None + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like must be the same size as the index and its dtype + must exactly match the index’s type. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names (including coordinates) to fill values. Use this + data array's name to refer to the data array's values. + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + reindexed : DataArray + Another dataset array, with this array's data but replaced + coordinates. + + Examples + -------- + Reverse latitude: + + >>> da = xr.DataArray( + ... np.arange(4), + ... coords=[np.array([90, 89, 88, 87])], + ... dims="lat", + ... ) + >>> da + Size: 32B + array([0, 1, 2, 3]) + Coordinates: + * lat (lat) int64 32B 90 89 88 87 + >>> da.reindex(lat=da.lat[::-1]) + Size: 32B + array([3, 2, 1, 0]) + Coordinates: + * lat (lat) int64 32B 87 88 89 90 + + See Also + -------- + DataArray.reindex_like + align + """ + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) + + def interp( + self, + coords: Mapping[Any, Any] | None = None, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] | None = None, + **coords_kwargs: Any, + ) -> Self: + """Interpolate a DataArray onto new coordinates + + Performs univariate or multivariate interpolation of a DataArray onto + new coordinates using scipy's interpolation routines. If interpolating + along an existing dimension, :py:class:`scipy.interpolate.interp1d` is + called. When interpolating along multiple existing dimensions, an + attempt is made to decompose the interpolation into multiple + 1-dimensional interpolations. If this is possible, + :py:class:`scipy.interpolate.interp1d` is called. Otherwise, + :py:func:`scipy.interpolate.interpn` is called. + + Parameters + ---------- + coords : dict, optional + Mapping from dimension names to the new coordinates. + New coordinate can be a scalar, array-like or DataArray. + If DataArrays are passed as new coordinates, their dimensions are + used for the broadcasting. Missing values are skipped. + method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"}, default: "linear" + The method used to interpolate. The method should be supported by + the scipy interpolator: + + - ``interp1d``: {"linear", "nearest", "zero", "slinear", + "quadratic", "cubic", "polynomial"} + - ``interpn``: {"linear", "nearest"} + + If ``"polynomial"`` is passed, the ``order`` keyword argument must + also be provided. + assume_sorted : bool, default: False + If False, values of x can be in any order and they are sorted + first. If True, x has to be an array of monotonically increasing + values. + kwargs : dict-like or None, default: None + Additional keyword arguments passed to scipy's interpolator. Valid + options and their behavior depend whether ``interp1d`` or + ``interpn`` is used. + **coords_kwargs : {dim: coordinate, ...}, optional + The keyword arguments form of ``coords``. + One of coords or coords_kwargs must be provided. + + Returns + ------- + interpolated : DataArray + New dataarray on the new coordinates. + + Notes + ----- + scipy is required. + + See Also + -------- + scipy.interpolate.interp1d + scipy.interpolate.interpn + + :doc:`xarray-tutorial:fundamentals/02.2_manipulating_dimensions` + Tutorial material on manipulating data resolution using :py:func:`~xarray.DataArray.interp` + + Examples + -------- + >>> da = xr.DataArray( + ... data=[[1, 4, 2, 9], [2, 7, 6, np.nan], [6, np.nan, 5, 8]], + ... dims=("x", "y"), + ... coords={"x": [0, 1, 2], "y": [10, 12, 14, 16]}, + ... ) + >>> da + Size: 96B + array([[ 1., 4., 2., 9.], + [ 2., 7., 6., nan], + [ 6., nan, 5., 8.]]) + Coordinates: + * x (x) int64 24B 0 1 2 + * y (y) int64 32B 10 12 14 16 + + 1D linear interpolation (the default): + + >>> da.interp(x=[0, 0.75, 1.25, 1.75]) + Size: 128B + array([[1. , 4. , 2. , nan], + [1.75, 6.25, 5. , nan], + [3. , nan, 5.75, nan], + [5. , nan, 5.25, nan]]) + Coordinates: + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 + + 1D nearest interpolation: + + >>> da.interp(x=[0, 0.75, 1.25, 1.75], method="nearest") + Size: 128B + array([[ 1., 4., 2., 9.], + [ 2., 7., 6., nan], + [ 2., 7., 6., nan], + [ 6., nan, 5., 8.]]) + Coordinates: + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 + + 1D linear extrapolation: + + >>> da.interp( + ... x=[1, 1.5, 2.5, 3.5], + ... method="linear", + ... kwargs={"fill_value": "extrapolate"}, + ... ) + Size: 128B + array([[ 2. , 7. , 6. , nan], + [ 4. , nan, 5.5, nan], + [ 8. , nan, 4.5, nan], + [12. , nan, 3.5, nan]]) + Coordinates: + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 1.0 1.5 2.5 3.5 + + 2D linear interpolation: + + >>> da.interp(x=[0, 0.75, 1.25, 1.75], y=[11, 13, 15], method="linear") + Size: 96B + array([[2.5 , 3. , nan], + [4. , 5.625, nan], + [ nan, nan, nan], + [ nan, nan, nan]]) + Coordinates: + * x (x) float64 32B 0.0 0.75 1.25 1.75 + * y (y) int64 24B 11 13 15 + """ + if self.dtype.kind not in "uifc": + raise TypeError( + f"interp only works for a numeric type array. Given {self.dtype}." + ) + ds = self._to_temp_dataset().interp( + coords, + method=method, + kwargs=kwargs, + assume_sorted=assume_sorted, + **coords_kwargs, + ) + return self._from_temp_dataset(ds) + + def interp_like( + self, + other: T_Xarray, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] | None = None, + ) -> Self: + """Interpolate this object onto the coordinates of another object, + filling out of range values with NaN. + + If interpolating along a single existing dimension, + :py:class:`scipy.interpolate.interp1d` is called. When interpolating + along multiple existing dimensions, an attempt is made to decompose the + interpolation into multiple 1-dimensional interpolations. If this is + possible, :py:class:`scipy.interpolate.interp1d` is called. Otherwise, + :py:func:`scipy.interpolate.interpn` is called. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to an 1d array-like, which provides coordinates upon + which to index the variables in this dataset. Missing values are skipped. + method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"}, default: "linear" + The method used to interpolate. The method should be supported by + the scipy interpolator: + + - {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", + "polynomial"} when ``interp1d`` is called. + - {"linear", "nearest"} when ``interpn`` is called. + + If ``"polynomial"`` is passed, the ``order`` keyword argument must + also be provided. + assume_sorted : bool, default: False + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs : dict, optional + Additional keyword passed to scipy's interpolator. + + Returns + ------- + interpolated : DataArray + Another dataarray by interpolating this dataarray's data along the + coordinates of the other object. + + Examples + -------- + >>> data = np.arange(12).reshape(4, 3) + >>> da1 = xr.DataArray( + ... data=data, + ... dims=["x", "y"], + ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, + ... ) + >>> da1 + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 + >>> da2 = xr.DataArray( + ... data=data, + ... dims=["x", "y"], + ... coords={"x": [10, 20, 29, 39], "y": [70, 80, 90]}, + ... ) + >>> da2 + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) int64 32B 10 20 29 39 + * y (y) int64 24B 70 80 90 + + Interpolate the values in the coordinates of the other DataArray with respect to the source's values: + + >>> da2.interp_like(da1) + Size: 96B + array([[0. , 1. , 2. ], + [3. , 4. , 5. ], + [6.3, 7.3, 8.3], + [nan, nan, nan]]) + Coordinates: + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 + + Could also extrapolate missing values: + + >>> da2.interp_like(da1, kwargs={"fill_value": "extrapolate"}) + Size: 96B + array([[ 0. , 1. , 2. ], + [ 3. , 4. , 5. ], + [ 6.3, 7.3, 8.3], + [ 9.3, 10.3, 11.3]]) + Coordinates: + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 + + Notes + ----- + scipy is required. + If the dataarray has object-type coordinates, reindex is used for these + coordinates instead of the interpolation. + + See Also + -------- + DataArray.interp + DataArray.reindex_like + """ + if self.dtype.kind not in "uifc": + raise TypeError( + f"interp only works for a numeric type array. Given {self.dtype}." + ) + ds = self._to_temp_dataset().interp_like( + other, method=method, kwargs=kwargs, assume_sorted=assume_sorted + ) + return self._from_temp_dataset(ds) + + def rename( + self, + new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None, + **names: Hashable, + ) -> Self: + """Returns a new DataArray with renamed coordinates, dimensions or a new name. + + Parameters + ---------- + new_name_or_name_dict : str or dict-like, optional + If the argument is dict-like, it used as a mapping from old + names to new names for coordinates or dimensions. Otherwise, + use the argument as the new name for this array. + **names : Hashable, optional + The keyword arguments form of a mapping from old names to + new names for coordinates or dimensions. + One of new_name_or_name_dict or names must be provided. + + Returns + ------- + renamed : DataArray + Renamed array or array with renamed coordinates. + + See Also + -------- + Dataset.rename + DataArray.swap_dims + """ + if new_name_or_name_dict is None and not names: + # change name to None? + return self._replace(name=None) + if utils.is_dict_like(new_name_or_name_dict) or new_name_or_name_dict is None: + # change dims/coords + name_dict = either_dict_or_kwargs(new_name_or_name_dict, names, "rename") + dataset = self._to_temp_dataset()._rename(name_dict) + return self._from_temp_dataset(dataset) + if utils.hashable(new_name_or_name_dict) and names: + # change name + dims/coords + dataset = self._to_temp_dataset()._rename(names) + dataarray = self._from_temp_dataset(dataset) + return dataarray._replace(name=new_name_or_name_dict) + # only change name + return self._replace(name=new_name_or_name_dict) + + def swap_dims( + self, + dims_dict: Mapping[Any, Hashable] | None = None, + **dims_kwargs, + ) -> Self: + """Returns a new DataArray with swapped dimensions. + + Parameters + ---------- + dims_dict : dict-like + Dictionary whose keys are current dimension names and whose values + are new names. + **dims_kwargs : {existing_dim: new_dim, ...}, optional + The keyword arguments form of ``dims_dict``. + One of dims_dict or dims_kwargs must be provided. + + Returns + ------- + swapped : DataArray + DataArray with swapped dimensions. + + Examples + -------- + >>> arr = xr.DataArray( + ... data=[0, 1], + ... dims="x", + ... coords={"x": ["a", "b"], "y": ("x", [0, 1])}, + ... ) + >>> arr + Size: 16B + array([0, 1]) + Coordinates: + * x (x) >> arr.swap_dims({"x": "y"}) + Size: 16B + array([0, 1]) + Coordinates: + x (y) >> arr.swap_dims({"x": "z"}) + Size: 16B + array([0, 1]) + Coordinates: + x (z) Self: + """Return a new object with an additional axis (or axes) inserted at + the corresponding position in the array shape. The new object is a + view into the underlying array, not a copy. + + If dim is already a scalar coordinate, it will be promoted to a 1D + coordinate consisting of a single value. + + The automatic creation of indexes to back new 1D coordinate variables + controlled by the create_index_for_new_dim kwarg. + + Parameters + ---------- + dim : Hashable, sequence of Hashable, dict, or None, optional + Dimensions to include on the new variable. + If provided as str or sequence of str, then dimensions are inserted + with length 1. If provided as a dict, then the keys are the new + dimensions and the values are either integers (giving the length of + the new dimensions) or sequence/ndarray (giving the coordinates of + the new dimensions). + axis : int, sequence of int, or None, default: None + Axis position(s) where new axis is to be inserted (position(s) on + the result array). If a sequence of integers is passed, + multiple axes are inserted. In this case, dim arguments should be + same length list. If axis=None is passed, all the axes will be + inserted to the start of the result array. + create_index_for_new_dim : bool, default: True + Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``. + **dim_kwargs : int or sequence or ndarray + The keywords are arbitrary dimensions being inserted and the values + are either the lengths of the new dims (if int is given), or their + coordinates. Note, this is an alternative to passing a dict to the + dim kwarg and will only be used if dim is None. + + Returns + ------- + expanded : DataArray + This object, but with additional dimension(s). + + See Also + -------- + Dataset.expand_dims + + Examples + -------- + >>> da = xr.DataArray(np.arange(5), dims=("x")) + >>> da + Size: 40B + array([0, 1, 2, 3, 4]) + Dimensions without coordinates: x + + Add new dimension of length 2: + + >>> da.expand_dims(dim={"y": 2}) + Size: 80B + array([[0, 1, 2, 3, 4], + [0, 1, 2, 3, 4]]) + Dimensions without coordinates: y, x + + >>> da.expand_dims(dim={"y": 2}, axis=1) + Size: 80B + array([[0, 0], + [1, 1], + [2, 2], + [3, 3], + [4, 4]]) + Dimensions without coordinates: x, y + + Add a new dimension with coordinates from array: + + >>> da.expand_dims(dim={"y": np.arange(5)}, axis=0) + Size: 200B + array([[0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4]]) + Coordinates: + * y (y) int64 40B 0 1 2 3 4 + Dimensions without coordinates: x + """ + if isinstance(dim, int): + raise TypeError("dim should be Hashable or sequence/mapping of Hashables") + elif isinstance(dim, Sequence) and not isinstance(dim, str): + if len(dim) != len(set(dim)): + raise ValueError("dims should not contain duplicate values.") + dim = dict.fromkeys(dim, 1) + elif dim is not None and not isinstance(dim, Mapping): + dim = {dim: 1} + + dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") + ds = self._to_temp_dataset().expand_dims( + dim, axis, create_index_for_new_dim=create_index_for_new_dim + ) + return self._from_temp_dataset(ds) + + def set_index( + self, + indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, + append: bool = False, + **indexes_kwargs: Hashable | Sequence[Hashable], + ) -> Self: + """Set DataArray (multi-)indexes using one or more existing + coordinates. + + This legacy method is limited to pandas (multi-)indexes and + 1-dimensional "dimension" coordinates. See + :py:meth:`~DataArray.set_xindex` for setting a pandas or a custom + Xarray-compatible index from one or more arbitrary coordinates. + + Parameters + ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. + append : bool, default: False + If True, append the supplied index(es) to the existing index(es). + Otherwise replace the existing index(es). + **indexes_kwargs : optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. + + Returns + ------- + obj : DataArray + Another DataArray, with this data but replaced coordinates. + + Examples + -------- + >>> arr = xr.DataArray( + ... data=np.ones((2, 3)), + ... dims=["x", "y"], + ... coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ... ) + >>> arr + Size: 48B + array([[1., 1., 1.], + [1., 1., 1.]]) + Coordinates: + * x (x) int64 16B 0 1 + * y (y) int64 24B 0 1 2 + a (x) int64 16B 3 4 + >>> arr.set_index(x="a") + Size: 48B + array([[1., 1., 1.], + [1., 1., 1.]]) + Coordinates: + * x (x) int64 16B 3 4 + * y (y) int64 24B 0 1 2 + + See Also + -------- + DataArray.reset_index + DataArray.set_xindex + """ + ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs) + return self._from_temp_dataset(ds) + + def reset_index( + self, + dims_or_levels: Hashable | Sequence[Hashable], + drop: bool = False, + ) -> Self: + """Reset the specified index(es) or multi-index level(s). + + This legacy method is specific to pandas (multi-)indexes and + 1-dimensional "dimension" coordinates. See the more generic + :py:meth:`~DataArray.drop_indexes` and :py:meth:`~DataArray.set_xindex` + method to respectively drop and set pandas or custom indexes for + arbitrary coordinates. + + Parameters + ---------- + dims_or_levels : Hashable or sequence of Hashable + Name(s) of the dimension(s) and/or multi-index level(s) that will + be reset. + drop : bool, default: False + If True, remove the specified indexes and/or multi-index levels + instead of extracting them as new coordinates (default: False). + + Returns + ------- + obj : DataArray + Another dataarray, with this dataarray's data but replaced + coordinates. + + See Also + -------- + DataArray.set_index + DataArray.set_xindex + DataArray.drop_indexes + """ + ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop) + return self._from_temp_dataset(ds) + + def set_xindex( + self, + coord_names: str | Sequence[Hashable], + index_cls: type[Index] | None = None, + **options, + ) -> Self: + """Set a new, Xarray-compatible index from one or more existing + coordinate(s). + + Parameters + ---------- + coord_names : str or list + Name(s) of the coordinate(s) used to build the index. + If several names are given, their order matters. + index_cls : subclass of :class:`~xarray.indexes.Index` + The type of index to create. By default, try setting + a pandas (multi-)index from the supplied coordinates. + **options + Options passed to the index constructor. + + Returns + ------- + obj : DataArray + Another dataarray, with this dataarray's data and with a new index. + + """ + ds = self._to_temp_dataset().set_xindex(coord_names, index_cls, **options) + return self._from_temp_dataset(ds) + + def reorder_levels( + self, + dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, + **dim_order_kwargs: Sequence[int | Hashable], + ) -> Self: + """Rearrange index levels using input order. + + Parameters + ---------- + dim_order dict-like of Hashable to int or Hashable: optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. + **dim_order_kwargs : optional + The keyword arguments form of ``dim_order``. + One of dim_order or dim_order_kwargs must be provided. + + Returns + ------- + obj : DataArray + Another dataarray, with this dataarray's data but replaced + coordinates. + """ + ds = self._to_temp_dataset().reorder_levels(dim_order, **dim_order_kwargs) + return self._from_temp_dataset(ds) + + @partial(deprecate_dims, old_name="dimensions") + def stack( + self, + dim: Mapping[Any, Sequence[Hashable]] | None = None, + create_index: bool | None = True, + index_cls: type[Index] = PandasMultiIndex, + **dim_kwargs: Sequence[Hashable], + ) -> Self: + """ + Stack any number of existing dimensions into a single new dimension. + + New dimensions will be added at the end, and the corresponding + coordinate variables will be combined into a MultiIndex. + + Parameters + ---------- + dim : mapping of Hashable to sequence of Hashable + Mapping of the form `new_name=(dim1, dim2, ...)`. + Names of new dimensions, and the existing dimensions that they + replace. An ellipsis (`...`) will be replaced by all unlisted dimensions. + Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over + all dimensions. + create_index : bool or None, default: True + If True, create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type. Must be an Xarray index that + implements `.stack()`. By default, a pandas multi-index wrapper is used. + **dim_kwargs + The keyword arguments form of ``dim``. + One of dim or dim_kwargs must be provided. + + Returns + ------- + stacked : DataArray + DataArray with stacked data. + + Examples + -------- + >>> arr = xr.DataArray( + ... np.arange(6).reshape(2, 3), + ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], + ... ) + >>> arr + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> stacked = arr.stack(z=("x", "y")) + >>> stacked.indexes["z"] + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), + ('b', 0), + ('b', 1), + ('b', 2)], + name='z') + + See Also + -------- + DataArray.unstack + """ + ds = self._to_temp_dataset().stack( + dim, + create_index=create_index, + index_cls=index_cls, + **dim_kwargs, + ) + return self._from_temp_dataset(ds) + + @_deprecate_positional_args("v2023.10.0") + def unstack( + self, + dim: Dims = None, + *, + fill_value: Any = dtypes.NA, + sparse: bool = False, + ) -> Self: + """ + Unstack existing dimensions corresponding to MultiIndexes into + multiple new dimensions. + + New dimensions will be added at the end. + + Parameters + ---------- + dim : str, Iterable of Hashable or None, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. + fill_value : scalar or dict-like, default: nan + Value to be filled. If a dict-like, maps variable names to + fill values. Use the data array's name to refer to its + name. If not provided or if the dict-like does not contain + all variables, the dtype's NA value will be used. + sparse : bool, default: False + Use sparse-array if True + + Returns + ------- + unstacked : DataArray + Array with unstacked data. + + Examples + -------- + >>> arr = xr.DataArray( + ... np.arange(6).reshape(2, 3), + ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], + ... ) + >>> arr + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> stacked = arr.stack(z=("x", "y")) + >>> stacked.indexes["z"] + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), + ('b', 0), + ('b', 1), + ('b', 2)], + name='z') + >>> roundtripped = stacked.unstack() + >>> arr.identical(roundtripped) + True + + See Also + -------- + DataArray.stack + """ + ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse) + return self._from_temp_dataset(ds) + + def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset: + """Unstack DataArray expanding to Dataset along a given level of a + stacked coordinate. + + This is the inverse operation of Dataset.to_stacked_array. + + Parameters + ---------- + dim : Hashable + Name of existing dimension to unstack + level : int or Hashable, default: 0 + The MultiIndex level to expand to a dataset along. Can either be + the integer index of the level or its name. + + Returns + ------- + unstacked: Dataset + + Examples + -------- + >>> arr = xr.DataArray( + ... np.arange(6).reshape(2, 3), + ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], + ... ) + >>> data = xr.Dataset({"a": arr, "b": arr.isel(y=0)}) + >>> data + Size: 96B + Dimensions: (x: 2, y: 3) + Coordinates: + * x (x) >> stacked = data.to_stacked_array("z", ["x"]) + >>> stacked.indexes["z"] + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), + ('b', nan)], + name='z') + >>> roundtripped = stacked.to_unstacked_dataset(dim="z") + >>> data.identical(roundtripped) + True + + See Also + -------- + Dataset.to_stacked_array + """ + idx = self._indexes[dim].to_pandas_index() + if not isinstance(idx, pd.MultiIndex): + raise ValueError(f"'{dim}' is not a stacked coordinate") + + level_number = idx._get_level_number(level) + variables = idx.levels[level_number] + variable_dim = idx.names[level_number] + + # pull variables out of datarray + data_dict = {} + for k in variables: + data_dict[k] = self.sel({variable_dim: k}, drop=True).squeeze(drop=True) + + # unstacked dataset + return Dataset(data_dict) + + @deprecate_dims + def transpose( + self, + *dim: Hashable, + transpose_coords: bool = True, + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> Self: + """Return a new DataArray object with transposed dimensions. + + Parameters + ---------- + *dim : Hashable, optional + By default, reverse the dimensions. Otherwise, reorder the + dimensions to this order. + transpose_coords : bool, default: True + If True, also transpose the coordinates of this DataArray. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + DataArray: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + transposed : DataArray + The returned DataArray's array is transposed. + + Notes + ----- + This operation returns a view of this array's data. It is + lazy for dask-backed DataArrays but not for numpy-backed DataArrays + -- the data will be fully loaded. + + See Also + -------- + numpy.transpose + Dataset.transpose + """ + if dim: + dim = tuple(infix_dims(dim, self.dims, missing_dims)) + variable = self.variable.transpose(*dim) + if transpose_coords: + coords: dict[Hashable, Variable] = {} + for name, coord in self.coords.items(): + coord_dims = tuple(d for d in dim if d in coord.dims) + coords[name] = coord.variable.transpose(*coord_dims) + return self._replace(variable, coords) + else: + return self._replace(variable) + + @property + def T(self) -> Self: + return self.transpose() + + def drop_vars( + self, + names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]], + *, + errors: ErrorOptions = "raise", + ) -> Self: + """Returns an array with dropped variables. + + Parameters + ---------- + names : Hashable or iterable of Hashable or Callable + Name(s) of variables to drop. If a Callable, this object is passed as its + only argument and its result is used. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the variable + passed are not in the dataset. If 'ignore', any given names that are in the + DataArray are dropped and no error is raised. + + Returns + ------- + dropped : Dataset + New Dataset copied from `self` with variables removed. + + Examples + ------- + >>> data = np.arange(12).reshape(4, 3) + >>> da = xr.DataArray( + ... data=data, + ... dims=["x", "y"], + ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, + ... ) + >>> da + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 + + Removing a single variable: + + >>> da.drop_vars("x") + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * y (y) int64 24B 70 80 90 + Dimensions without coordinates: x + + Removing a list of variables: + + >>> da.drop_vars(["x", "y"]) + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Dimensions without coordinates: x, y + + >>> da.drop_vars(lambda x: x.coords) + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Dimensions without coordinates: x, y + """ + if callable(names): + names = names(self) + ds = self._to_temp_dataset().drop_vars(names, errors=errors) + return self._from_temp_dataset(ds) + + def drop_indexes( + self, + coord_names: Hashable | Iterable[Hashable], + *, + errors: ErrorOptions = "raise", + ) -> Self: + """Drop the indexes assigned to the given coordinates. + + Parameters + ---------- + coord_names : hashable or iterable of hashable + Name(s) of the coordinate(s) for which to drop the index. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the coordinates + passed have no index or are not in the dataset. + If 'ignore', no error is raised. + + Returns + ------- + dropped : DataArray + A new dataarray with dropped indexes. + """ + ds = self._to_temp_dataset().drop_indexes(coord_names, errors=errors) + return self._from_temp_dataset(ds) + + def drop( + self, + labels: Mapping[Any, Any] | None = None, + dim: Hashable | None = None, + *, + errors: ErrorOptions = "raise", + **labels_kwargs, + ) -> Self: + """Backward compatible method based on `drop_vars` and `drop_sel` + + Using either `drop_vars` or `drop_sel` is encouraged + + See Also + -------- + DataArray.drop_vars + DataArray.drop_sel + """ + ds = self._to_temp_dataset().drop(labels, dim, errors=errors, **labels_kwargs) + return self._from_temp_dataset(ds) + + def drop_sel( + self, + labels: Mapping[Any, Any] | None = None, + *, + errors: ErrorOptions = "raise", + **labels_kwargs, + ) -> Self: + """Drop index labels from this DataArray. + + Parameters + ---------- + labels : mapping of Hashable to Any + Index labels to drop + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if + any of the index labels passed are not + in the dataset. If 'ignore', any given labels that are in the + dataset are dropped and no error is raised. + **labels_kwargs : {dim: label, ...}, optional + The keyword arguments form of ``dim`` and ``labels`` + + Returns + ------- + dropped : DataArray + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(25).reshape(5, 5), + ... coords={"x": np.arange(0, 9, 2), "y": np.arange(0, 13, 3)}, + ... dims=("x", "y"), + ... ) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + * x (x) int64 40B 0 2 4 6 8 + * y (y) int64 40B 0 3 6 9 12 + + >>> da.drop_sel(x=[0, 2], y=9) + Size: 96B + array([[10, 11, 12, 14], + [15, 16, 17, 19], + [20, 21, 22, 24]]) + Coordinates: + * x (x) int64 24B 4 6 8 + * y (y) int64 32B 0 3 6 12 + + >>> da.drop_sel({"x": 6, "y": [0, 3]}) + Size: 96B + array([[ 2, 3, 4], + [ 7, 8, 9], + [12, 13, 14], + [22, 23, 24]]) + Coordinates: + * x (x) int64 32B 0 2 4 8 + * y (y) int64 24B 6 9 12 + """ + if labels_kwargs or isinstance(labels, dict): + labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") + + ds = self._to_temp_dataset().drop_sel(labels, errors=errors) + return self._from_temp_dataset(ds) + + def drop_isel( + self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs + ) -> Self: + """Drop index positions from this DataArray. + + Parameters + ---------- + indexers : mapping of Hashable to Any or None, default: None + Index locations to drop + **indexers_kwargs : {dim: position, ...}, optional + The keyword arguments form of ``dim`` and ``positions`` + + Returns + ------- + dropped : DataArray + + Raises + ------ + IndexError + + Examples + -------- + >>> da = xr.DataArray(np.arange(25).reshape(5, 5), dims=("X", "Y")) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: X, Y + + >>> da.drop_isel(X=[0, 4], Y=2) + Size: 96B + array([[ 5, 6, 8, 9], + [10, 11, 13, 14], + [15, 16, 18, 19]]) + Dimensions without coordinates: X, Y + + >>> da.drop_isel({"X": 3, "Y": 3}) + Size: 128B + array([[ 0, 1, 2, 4], + [ 5, 6, 7, 9], + [10, 11, 12, 14], + [20, 21, 22, 24]]) + Dimensions without coordinates: X, Y + """ + dataset = self._to_temp_dataset() + dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) + return self._from_temp_dataset(dataset) + + @_deprecate_positional_args("v2023.10.0") + def dropna( + self, + dim: Hashable, + *, + how: Literal["any", "all"] = "any", + thresh: int | None = None, + ) -> Self: + """Returns a new array with dropped labels for missing values along + the provided dimension. + + Parameters + ---------- + dim : Hashable + Dimension along which to drop missing values. Dropping along + multiple dimensions simultaneously is not yet supported. + how : {"any", "all"}, default: "any" + - any : if any NA values are present, drop that label + - all : if all values are NA, drop that label + + thresh : int or None, default: None + If supplied, require this many non-NA values. + + Returns + ------- + dropped : DataArray + + Examples + -------- + >>> temperature = [ + ... [0, 4, 2, 9], + ... [np.nan, np.nan, np.nan, np.nan], + ... [np.nan, 4, 2, 0], + ... [3, 1, 0, 0], + ... ] + >>> da = xr.DataArray( + ... data=temperature, + ... dims=["Y", "X"], + ... coords=dict( + ... lat=("Y", np.array([-20.0, -20.25, -20.50, -20.75])), + ... lon=("X", np.array([10.0, 10.25, 10.5, 10.75])), + ... ), + ... ) + >>> da + Size: 128B + array([[ 0., 4., 2., 9.], + [nan, nan, nan, nan], + [nan, 4., 2., 0.], + [ 3., 1., 0., 0.]]) + Coordinates: + lat (Y) float64 32B -20.0 -20.25 -20.5 -20.75 + lon (X) float64 32B 10.0 10.25 10.5 10.75 + Dimensions without coordinates: Y, X + + >>> da.dropna(dim="Y", how="any") + Size: 64B + array([[0., 4., 2., 9.], + [3., 1., 0., 0.]]) + Coordinates: + lat (Y) float64 16B -20.0 -20.75 + lon (X) float64 32B 10.0 10.25 10.5 10.75 + Dimensions without coordinates: Y, X + + Drop values only if all values along the dimension are NaN: + + >>> da.dropna(dim="Y", how="all") + Size: 96B + array([[ 0., 4., 2., 9.], + [nan, 4., 2., 0.], + [ 3., 1., 0., 0.]]) + Coordinates: + lat (Y) float64 24B -20.0 -20.5 -20.75 + lon (X) float64 32B 10.0 10.25 10.5 10.75 + Dimensions without coordinates: Y, X + """ + ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh) + return self._from_temp_dataset(ds) + + def fillna(self, value: Any) -> Self: + """Fill missing values in this object. + + This operation follows the normal broadcasting and alignment rules that + xarray uses for binary arithmetic, except the result is aligned to this + object (``join='left'``) instead of aligned to the intersection of + index coordinates (``join='inner'``). + + Parameters + ---------- + value : scalar, ndarray or DataArray + Used to fill all matching missing values in this array. If the + argument is a DataArray, it is first aligned with (reindexed to) + this array. + + Returns + ------- + filled : DataArray + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 4, np.nan, 0, 3, np.nan]), + ... dims="Z", + ... coords=dict( + ... Z=("Z", np.arange(6)), + ... height=("Z", np.array([0, 10, 20, 30, 40, 50])), + ... ), + ... ) + >>> da + Size: 48B + array([ 1., 4., nan, 0., 3., nan]) + Coordinates: + * Z (Z) int64 48B 0 1 2 3 4 5 + height (Z) int64 48B 0 10 20 30 40 50 + + Fill all NaN values with 0: + + >>> da.fillna(0) + Size: 48B + array([1., 4., 0., 0., 3., 0.]) + Coordinates: + * Z (Z) int64 48B 0 1 2 3 4 5 + height (Z) int64 48B 0 10 20 30 40 50 + + Fill NaN values with corresponding values in array: + + >>> da.fillna(np.array([2, 9, 4, 2, 8, 9])) + Size: 48B + array([1., 4., 4., 0., 3., 9.]) + Coordinates: + * Z (Z) int64 48B 0 1 2 3 4 5 + height (Z) int64 48B 0 10 20 30 40 50 + """ + if utils.is_dict_like(value): + raise TypeError( + "cannot provide fill value as a dictionary with " + "fillna on a DataArray" + ) + out = ops.fillna(self, value) + return out + + def interpolate_na( + self, + dim: Hashable | None = None, + method: InterpOptions = "linear", + limit: int | None = None, + use_coordinate: bool | str = True, + max_gap: ( + None + | int + | float + | str + | pd.Timedelta + | np.timedelta64 + | datetime.timedelta + ) = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """Fill in NaNs by interpolating according to different methods. + + Parameters + ---------- + dim : Hashable or None, optional + Specifies the dimension along which to interpolate. + method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" + String indicating which method to use for interpolation: + + - 'linear': linear interpolation. Additional keyword + arguments are passed to :py:func:`numpy.interp` + - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial': + are passed to :py:func:`scipy.interpolate.interp1d`. If + ``method='polynomial'``, the ``order`` keyword argument must also be + provided. + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their + respective :py:class:`scipy.interpolate` classes. + + use_coordinate : bool or str, default: True + Specifies which index to use as the x values in the interpolation + formulated as `y = f(x)`. If False, values are treated as if + equally-spaced along ``dim``. If True, the IndexVariable `dim` is + used. If ``use_coordinate`` is a string, it specifies the name of a + coordinate variable to use as the index. + limit : int or None, default: None + Maximum number of consecutive NaNs to fill. Must be greater than 0 + or None for no limit. This filling is done regardless of the size of + the gap in the data. To only interpolate over gaps less than a given length, + see ``max_gap``. + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None + Maximum size of gap, a continuous sequence of NaNs, that will be filled. + Use None for no limit. When interpolating along a datetime64 dimension + and ``use_coordinate=True``, ``max_gap`` can be one of the following: + + - a string that is valid input for pandas.to_timedelta + - a :py:class:`numpy.timedelta64` object + - a :py:class:`pandas.Timedelta` object + - a :py:class:`datetime.timedelta` object + + Otherwise, ``max_gap`` must be an int or a float. Use of ``max_gap`` with unlabeled + dimensions has not been implemented yet. Gap length is defined as the difference + between coordinate values at the first data point after a gap and the last value + before a gap. For gaps at the beginning (end), gap length is defined as the difference + between coordinate values at the first (last) valid data point and the first (last) NaN. + For example, consider:: + + + array([nan, nan, nan, 1., nan, nan, 4., nan, nan]) + Coordinates: + * x (x) int64 0 1 2 3 4 5 6 7 8 + + The gap lengths are 3-0 = 3; 6-3 = 3; and 8-6 = 2 respectively + keep_attrs : bool or None, default: None + If True, the dataarray's attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + **kwargs : dict, optional + parameters passed verbatim to the underlying interpolation function + + Returns + ------- + interpolated: DataArray + Filled in DataArray. + + See Also + -------- + numpy.interp + scipy.interpolate + + Examples + -------- + >>> da = xr.DataArray( + ... [np.nan, 2, 3, np.nan, 0], dims="x", coords={"x": [0, 1, 2, 3, 4]} + ... ) + >>> da + Size: 40B + array([nan, 2., 3., nan, 0.]) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + + >>> da.interpolate_na(dim="x", method="linear") + Size: 40B + array([nan, 2. , 3. , 1.5, 0. ]) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + + >>> da.interpolate_na(dim="x", method="linear", fill_value="extrapolate") + Size: 40B + array([1. , 2. , 3. , 1.5, 0. ]) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + """ + from xarray.core.missing import interp_na + + return interp_na( + self, + dim=dim, + method=method, + limit=limit, + use_coordinate=use_coordinate, + max_gap=max_gap, + keep_attrs=keep_attrs, + **kwargs, + ) + + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: + """Fill NaN values by propagating values forward + + *Requires bottleneck.* + + Parameters + ---------- + dim : Hashable + Specifies the dimension along which to propagate values when + filling. + limit : int or None, default: None + The maximum number of consecutive NaN values to forward fill. In + other words, if there is a gap with more than this number of + consecutive NaNs, it will only be partially filled. Must be greater + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). + + Returns + ------- + filled : DataArray + + Examples + -------- + >>> temperature = np.array( + ... [ + ... [np.nan, 1, 3], + ... [0, np.nan, 5], + ... [5, np.nan, np.nan], + ... [3, np.nan, np.nan], + ... [0, 2, 0], + ... ] + ... ) + >>> da = xr.DataArray( + ... data=temperature, + ... dims=["Y", "X"], + ... coords=dict( + ... lat=("Y", np.array([-20.0, -20.25, -20.50, -20.75, -21.0])), + ... lon=("X", np.array([10.0, 10.25, 10.5])), + ... ), + ... ) + >>> da + Size: 120B + array([[nan, 1., 3.], + [ 0., nan, 5.], + [ 5., nan, nan], + [ 3., nan, nan], + [ 0., 2., 0.]]) + Coordinates: + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 + Dimensions without coordinates: Y, X + + Fill all NaN values: + + >>> da.ffill(dim="Y", limit=None) + Size: 120B + array([[nan, 1., 3.], + [ 0., 1., 5.], + [ 5., 1., 5.], + [ 3., 1., 5.], + [ 0., 2., 0.]]) + Coordinates: + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 + Dimensions without coordinates: Y, X + + Fill only the first of consecutive NaN values: + + >>> da.ffill(dim="Y", limit=1) + Size: 120B + array([[nan, 1., 3.], + [ 0., 1., 5.], + [ 5., nan, 5.], + [ 3., nan, nan], + [ 0., 2., 0.]]) + Coordinates: + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 + Dimensions without coordinates: Y, X + """ + from xarray.core.missing import ffill + + return ffill(self, dim, limit=limit) + + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: + """Fill NaN values by propagating values backward + + *Requires bottleneck.* + + Parameters + ---------- + dim : str + Specifies the dimension along which to propagate values when + filling. + limit : int or None, default: None + The maximum number of consecutive NaN values to backward fill. In + other words, if there is a gap with more than this number of + consecutive NaNs, it will only be partially filled. Must be greater + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). + + Returns + ------- + filled : DataArray + + Examples + -------- + >>> temperature = np.array( + ... [ + ... [0, 1, 3], + ... [0, np.nan, 5], + ... [5, np.nan, np.nan], + ... [3, np.nan, np.nan], + ... [np.nan, 2, 0], + ... ] + ... ) + >>> da = xr.DataArray( + ... data=temperature, + ... dims=["Y", "X"], + ... coords=dict( + ... lat=("Y", np.array([-20.0, -20.25, -20.50, -20.75, -21.0])), + ... lon=("X", np.array([10.0, 10.25, 10.5])), + ... ), + ... ) + >>> da + Size: 120B + array([[ 0., 1., 3.], + [ 0., nan, 5.], + [ 5., nan, nan], + [ 3., nan, nan], + [nan, 2., 0.]]) + Coordinates: + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 + Dimensions without coordinates: Y, X + + Fill all NaN values: + + >>> da.bfill(dim="Y", limit=None) + Size: 120B + array([[ 0., 1., 3.], + [ 0., 2., 5.], + [ 5., 2., 0.], + [ 3., 2., 0.], + [nan, 2., 0.]]) + Coordinates: + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 + Dimensions without coordinates: Y, X + + Fill only the first of consecutive NaN values: + + >>> da.bfill(dim="Y", limit=1) + Size: 120B + array([[ 0., 1., 3.], + [ 0., nan, 5.], + [ 5., nan, nan], + [ 3., 2., 0.], + [nan, 2., 0.]]) + Coordinates: + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 + Dimensions without coordinates: Y, X + """ + from xarray.core.missing import bfill + + return bfill(self, dim, limit=limit) + + def combine_first(self, other: Self) -> Self: + """Combine two DataArray objects, with union of coordinates. + + This operation follows the normal broadcasting and alignment rules of + ``join='outer'``. Default to non-null values of array calling the + method. Use np.nan to fill in vacant cells after alignment. + + Parameters + ---------- + other : DataArray + Used to fill all matching missing values in this array. + + Returns + ------- + DataArray + """ + return ops.fillna(self, other, join="outer") + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + """Reduce this array by applying `func` along some dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `f(x, axis=axis, **kwargs)` to return the result of reducing an + np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default `func` is + applied over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to repeatedly apply `func`. Only one of the + 'dim' and 'axis' arguments can be supplied. If neither are + supplied, then the reduction is calculated over the flattened array + (by calling `f(x)` without an axis argument). + keep_attrs : bool or None, optional + If True, the variable's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + keepdims : bool, default: False + If True, the dimensions which are reduced are left in the result + as dimensions of size one. Coordinates that use these dimensions + are removed. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + DataArray with this object's array replaced with an array with + summarized data and the indicated dimension(s) removed. + """ + + var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) + return self._replace_maybe_drop_dims(var) + + def to_pandas(self) -> Self | pd.Series | pd.DataFrame: + """Convert this array into a pandas object with the same shape. + + The type of the returned object depends on the number of DataArray + dimensions: + + * 0D -> `xarray.DataArray` + * 1D -> `pandas.Series` + * 2D -> `pandas.DataFrame` + + Only works for arrays with 2 or fewer dimensions. + + The DataArray constructor performs the inverse transformation. + + Returns + ------- + result : DataArray | Series | DataFrame + DataArray, pandas Series or pandas DataFrame. + """ + # TODO: consolidate the info about pandas constructors and the + # attributes that correspond to their indexes into a separate module? + constructors = {0: lambda x: x, 1: pd.Series, 2: pd.DataFrame} + try: + constructor = constructors[self.ndim] + except KeyError: + raise ValueError( + f"Cannot convert arrays with {self.ndim} dimensions into " + "pandas objects. Requires 2 or fewer dimensions." + ) + indexes = [self.get_index(dim) for dim in self.dims] + return constructor(self.values, *indexes) + + def to_dataframe( + self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None + ) -> pd.DataFrame: + """Convert this array and its coordinates into a tidy pandas.DataFrame. + + The DataFrame is indexed by the Cartesian product of index coordinates + (in the form of a :py:class:`pandas.MultiIndex`). Other coordinates are + included as columns in the DataFrame. + + For 1D and 2D DataArrays, see also :py:func:`DataArray.to_pandas` which + doesn't rely on a MultiIndex to build the DataFrame. + + Parameters + ---------- + name: Hashable or None, optional + Name to give to this array (required if unnamed). + dim_order: Sequence of Hashable or None, optional + Hierarchical dimension order for the resulting dataframe. + Array content is transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major + influence on which operations are efficient on the resulting + dataframe. + + If provided, must include all dimensions of this DataArray. By default, + dimensions are sorted according to the DataArray dimensions order. + + Returns + ------- + result: DataFrame + DataArray as a pandas DataFrame. + + See also + -------- + DataArray.to_pandas + DataArray.to_series + """ + if name is None: + name = self.name + if name is None: + raise ValueError( + "cannot convert an unnamed DataArray to a " + "DataFrame: use the ``name`` parameter" + ) + if self.ndim == 0: + raise ValueError("cannot convert a scalar to a DataFrame") + + # By using a unique name, we can convert a DataArray into a DataFrame + # even if it shares a name with one of its coordinates. + # I would normally use unique_name = object() but that results in a + # dataframe with columns in the wrong order, for reasons I have not + # been able to debug (possibly a pandas bug?). + unique_name = "__unique_name_identifier_z98xfz98xugfg73ho__" + ds = self._to_dataset_whole(name=unique_name) + + if dim_order is None: + ordered_dims = dict(zip(self.dims, self.shape)) + else: + ordered_dims = ds._normalize_dim_order(dim_order=dim_order) + + df = ds._to_dataframe(ordered_dims) + df.columns = [name if c == unique_name else c for c in df.columns] + return df + + def to_series(self) -> pd.Series: + """Convert this array into a pandas.Series. + + The Series is indexed by the Cartesian product of index coordinates + (in the form of a :py:class:`pandas.MultiIndex`). + + Returns + ------- + result : Series + DataArray as a pandas Series. + + See also + -------- + DataArray.to_pandas + DataArray.to_dataframe + """ + index = self.coords.to_index() + return pd.Series(self.values.reshape(-1), index=index, name=self.name) + + def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: + """Convert this array into a numpy.ma.MaskedArray + + Parameters + ---------- + copy : bool, default: True + If True make a copy of the array in the result. If False, + a MaskedArray view of DataArray.values is returned. + + Returns + ------- + result : MaskedArray + Masked where invalid values (nan or inf) occur. + """ + values = self.to_numpy() # only compute lazy arrays once + isnull = pd.isnull(values) + return np.ma.MaskedArray(data=values, mask=isnull, copy=copy) + + # path=None writes to bytes + @overload + def to_netcdf( + self, + path: None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> bytes: ... + + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: ... + + # default return None + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: Literal[True] = True, + invalid_netcdf: bool = False, + ) -> None: ... + + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> Delayed | None: ... + + def to_netcdf( + self, + path: str | PathLike | None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> bytes | Delayed | None: + """Write DataArray contents to a netCDF file. + + Parameters + ---------- + path : str, path-like or None, optional + Path to which to save this dataset. File-like objects are only + supported by the scipy engine. If no path is provided, this + function returns the resulting netCDF file as bytes; in this case, + we need to use scipy, which does not support netCDF version 4 (the + default format becomes NETCDF3_64BIT). + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. If mode='a', existing variables + will be overwritten. + format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_CLASSIC"}, optional + File format for the resulting netCDF file: + + * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API + features. + * NETCDF4_CLASSIC: Data is stored in an HDF5 file, using only + netCDF 3 compatible API features. + * NETCDF3_64BIT: 64-bit offset version of the netCDF 3 file format, + which fully supports 2+ GB files, but is only compatible with + clients linked against netCDF version 3.6.0 or later. + * NETCDF3_CLASSIC: The classic netCDF 3 file format. It does not + handle 2+ GB files very well. + + All formats are supported by the netCDF4-python library. + scipy.io.netcdf only supports the last two formats. + + The default format is NETCDF4 if you are saving a file to disk and + have the netCDF4-python library available. Otherwise, xarray falls + back to using scipy to write netCDF files and defaults to the + NETCDF3_64BIT format (scipy does not support netCDF4). + group : str, optional + Path to the netCDF4 group in the given file to open (only works for + format='NETCDF4'). The group(s) will be created if necessary. + engine : {"netcdf4", "scipy", "h5netcdf"}, optional + Engine to use when writing netCDF files. If not provided, the + default engine is chosen based on available dependencies, with a + preference for 'netcdf4' if writing to a file on disk. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}`` + + The `h5netcdf` engine supports both the NetCDF4-style compression + encoding parameters ``{"zlib": True, "complevel": 9}`` and the h5py + ones ``{"compression": "gzip", "compression_opts": 9}``. + This allows using any compression plugin installed in the HDF5 + library, e.g. LZF. + + unlimited_dims : iterable of Hashable, optional + Dimension(s) that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. + compute: bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. + invalid_netcdf: bool, default: False + Only valid along with ``engine="h5netcdf"``. If True, allow writing + hdf5 files which are invalid netcdf as described in + https://github.com/h5netcdf/h5netcdf. + + Returns + ------- + store: bytes or Delayed or None + * ``bytes`` if path is None + * ``dask.delayed.Delayed`` if compute is False + * None otherwise + + Notes + ----- + Only xarray.Dataset objects can be written to netCDF files, so + the xarray.DataArray is converted to a xarray.Dataset object + containing a single variable. If the DataArray has no name, or if the + name is the same as a coordinate name, then it is given the name + ``"__xarray_dataarray_variable__"``. + + [netCDF4 backend only] netCDF4 enums are decoded into the + dataarray dtype metadata. + + See Also + -------- + Dataset.to_netcdf + """ + from xarray.backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE, to_netcdf + + if self.name is None: + # If no name is set then use a generic xarray name + dataset = self.to_dataset(name=DATAARRAY_VARIABLE) + elif self.name in self.coords or self.name in self.dims: + # The name is the same as one of the coords names, which netCDF + # doesn't support, so rename it but keep track of the old name + dataset = self.to_dataset(name=DATAARRAY_VARIABLE) + dataset.attrs[DATAARRAY_NAME] = self.name + else: + # No problems with the name - so we're fine! + dataset = self.to_dataset() + + return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( + dataset, + path, + mode=mode, + format=format, + group=group, + engine=engine, + encoding=encoding, + unlimited_dims=unlimited_dims, + compute=compute, + multifile=False, + invalid_netcdf=invalid_netcdf, + ) + + # compute=True (default) returns ZarrStore + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + *, + encoding: Mapping | None = None, + compute: Literal[True] = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + ) -> ZarrStore: ... + + # compute=False returns dask.Delayed + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[False], + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + ) -> Delayed: ... + + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: bool = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + ) -> ZarrStore | Delayed: + """Write DataArray contents to a Zarr store + + Zarr chunks are determined in the following way: + + - From the ``chunks`` attribute in each variable's ``encoding`` + (can be set via `DataArray.chunk`). + - If the variable is a Dask array, from the dask chunks + - If neither Dask chunks nor encoding chunks are present, chunks will + be determined automatically by Zarr + - If both Dask chunks and encoding chunks are present, encoding chunks + will be used, provided that there is a many-to-one relationship between + encoding chunks and dask chunks (i.e. Dask chunks are bigger than and + evenly divide encoding chunks); otherwise raise a ``ValueError``. + This restriction ensures that no synchronization / locks are required + when writing. To disable this restriction, use ``safe_chunks=False``. + + Parameters + ---------- + store : MutableMapping, str or path-like, optional + Store or path to directory in local or remote file system. + chunk_store : MutableMapping, str or path-like, optional + Store or path to directory in local or remote file system only for Zarr + array chunks. Requires zarr-python v2.4.0 or later. + mode : {"w", "w-", "a", "a-", r+", None}, optional + Persistence mode: "w" means create (overwrite if exists); + "w-" means create (fail if exists); + "a" means override all existing variables including dimension coordinates (create if does not exist); + "a-" means only append those variables that have ``append_dim``. + "r+" means modify existing array *values* only (raise an error if + any metadata or shapes would change). + The default mode is "a" if ``append_dim`` is set. Otherwise, it is + "r+" if ``region`` is set and ``w-`` otherwise. + synchronizer : object, optional + Zarr array synchronizer. + group : str, optional + Group path. (a.k.a. `path` in zarr terminology.) + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}`` + compute : bool, default: True + If True write array data immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed to write + array data later. Metadata is always updated eagerly. + consolidated : bool, optional + If True, apply zarr's `consolidate_metadata` function to the store + after writing metadata and read existing stores with consolidated + metadata; if False, do not. The default (`consolidated=None`) means + write consolidated metadata and attempt to read consolidated + metadata for existing stores (falling back to non-consolidated). + + When the experimental ``zarr_version=3``, ``consolidated`` must be + either be ``None`` or ``False``. + append_dim : hashable, optional + If set, the dimension along which the data will be appended. All + other dimensions on overridden variables must remain the same size. + region : dict, optional + Optional mapping from dimension names to integer slices along + dataarray dimensions to indicate the region of existing zarr array(s) + in which to write this datarray's data. For example, + ``{'x': slice(0, 1000), 'y': slice(10000, 11000)}`` would indicate + that values should be written to the region ``0:1000`` along ``x`` + and ``10000:11000`` along ``y``. + + Two restrictions apply to the use of ``region``: + + - If ``region`` is set, _all_ variables in a dataarray must have at + least one dimension in common with the region. Other variables + should be written in a separate call to ``to_zarr()``. + - Dimensions cannot be included in both ``region`` and + ``append_dim`` at the same time. To create empty arrays to fill + in with ``region``, use a separate call to ``to_zarr()`` with + ``compute=False``. See "Appending to existing Zarr stores" in + the reference documentation for full details. + + Users are expected to ensure that the specified region aligns with + Zarr chunk boundaries, and that dask chunks are also aligned. + Xarray makes limited checks that these multiple chunk boundaries line up. + It is possible to write incomplete chunks and corrupt the data with this + option if you are not careful. + safe_chunks : bool, default: True + If True, only allow writes to when there is a many-to-one relationship + between Zarr chunks (specified in encoding) and Dask chunks. + Set False to override this restriction; however, data may become corrupted + if Zarr arrays are written in parallel. This option may be useful in combination + with ``compute=False`` to initialize a Zarr store from an existing + DataArray with arbitrary chunk structure. + storage_options : dict, optional + Any additional parameters for the storage backend (ignored for local + paths). + zarr_version : int or None, optional + The desired zarr spec version to target (currently 2 or 3). The + default of None will attempt to determine the zarr version from + ``store`` when possible, otherwise defaulting to 2. + + Returns + ------- + * ``dask.delayed.Delayed`` if compute is False + * ZarrStore otherwise + + References + ---------- + https://zarr.readthedocs.io/ + + Notes + ----- + Zarr chunking behavior: + If chunks are found in the encoding argument or attribute + corresponding to any DataArray, those chunks are used. + If a DataArray is a dask array, it is written with those chunks. + If not other chunks are found, Zarr uses its own heuristics to + choose automatic chunk sizes. + + encoding: + The encoding attribute (if exists) of the DataArray(s) will be + used. Override any existing encodings by providing the ``encoding`` kwarg. + + See Also + -------- + Dataset.to_zarr + :ref:`io.zarr` + The I/O user guide, with more details and examples. + """ + from xarray.backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE, to_zarr + + if self.name is None: + # If no name is set then use a generic xarray name + dataset = self.to_dataset(name=DATAARRAY_VARIABLE) + elif self.name in self.coords or self.name in self.dims: + # The name is the same as one of the coords names, which the netCDF data model + # does not support, so rename it but keep track of the old name + dataset = self.to_dataset(name=DATAARRAY_VARIABLE) + dataset.attrs[DATAARRAY_NAME] = self.name + else: + # No problems with the name - so we're fine! + dataset = self.to_dataset() + + return to_zarr( # type: ignore[call-overload,misc] + dataset, + store=store, + chunk_store=chunk_store, + mode=mode, + synchronizer=synchronizer, + group=group, + encoding=encoding, + compute=compute, + consolidated=consolidated, + append_dim=append_dim, + region=region, + safe_chunks=safe_chunks, + storage_options=storage_options, + zarr_version=zarr_version, + ) + + def to_dict( + self, data: bool | Literal["list", "array"] = "list", encoding: bool = False + ) -> dict[str, Any]: + """ + Convert this xarray.DataArray into a dictionary following xarray + naming conventions. + + Converts all variables and attributes to native Python objects. + Useful for converting to json. To avoid datetime incompatibility + use decode_times=False kwarg in xarray.open_dataset. + + Parameters + ---------- + data : bool or {"list", "array"}, default: "list" + Whether to include the actual data in the dictionary. When set to + False, returns just the schema. If set to "array", returns data as + underlying array type. If set to "list" (or True for backwards + compatibility), returns data in lists of Python data types. Note + that for obtaining the "list" output efficiently, use + `da.compute().to_dict(data="list")`. + + encoding : bool, default: False + Whether to include the Dataset's encoding in the dictionary. + + Returns + ------- + dict: dict + + See Also + -------- + DataArray.from_dict + Dataset.to_dict + """ + d = self.variable.to_dict(data=data) + d.update({"coords": {}, "name": self.name}) + for k, coord in self.coords.items(): + d["coords"][k] = coord.variable.to_dict(data=data) + if encoding: + d["encoding"] = dict(self.encoding) + return d + + @classmethod + def from_dict(cls, d: Mapping[str, Any]) -> Self: + """Convert a dictionary into an xarray.DataArray + + Parameters + ---------- + d : dict + Mapping with a minimum structure of {"dims": [...], "data": [...]} + + Returns + ------- + obj : xarray.DataArray + + See Also + -------- + DataArray.to_dict + Dataset.from_dict + + Examples + -------- + >>> d = {"dims": "t", "data": [1, 2, 3]} + >>> da = xr.DataArray.from_dict(d) + >>> da + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: t + + >>> d = { + ... "coords": { + ... "t": {"dims": "t", "data": [0, 1, 2], "attrs": {"units": "s"}} + ... }, + ... "attrs": {"title": "air temperature"}, + ... "dims": "t", + ... "data": [10, 20, 30], + ... "name": "a", + ... } + >>> da = xr.DataArray.from_dict(d) + >>> da + Size: 24B + array([10, 20, 30]) + Coordinates: + * t (t) int64 24B 0 1 2 + Attributes: + title: air temperature + """ + coords = None + if "coords" in d: + try: + coords = { + k: (v["dims"], v["data"], v.get("attrs")) + for k, v in d["coords"].items() + } + except KeyError as e: + raise ValueError( + "cannot convert dict when coords are missing the key " + f"'{str(e.args[0])}'" + ) + try: + data = d["data"] + except KeyError: + raise ValueError("cannot convert dict without the key 'data''") + else: + obj = cls(data, coords, d.get("dims"), d.get("name"), d.get("attrs")) + + obj.encoding.update(d.get("encoding", {})) + + return obj + + @classmethod + def from_series(cls, series: pd.Series, sparse: bool = False) -> DataArray: + """Convert a pandas.Series into an xarray.DataArray. + + If the series's index is a MultiIndex, it will be expanded into a + tensor product of one-dimensional coordinates (filling in missing + values with NaN). Thus this operation should be the inverse of the + `to_series` method. + + Parameters + ---------- + series : Series + Pandas Series object to convert. + sparse : bool, default: False + If sparse=True, creates a sparse array instead of a dense NumPy array. + Requires the pydata/sparse package. + + See Also + -------- + DataArray.to_series + Dataset.from_dataframe + """ + temp_name = "__temporary_name" + df = pd.DataFrame({temp_name: series}) + ds = Dataset.from_dataframe(df, sparse=sparse) + result = ds[temp_name] + result.name = series.name + return result + + def to_iris(self) -> iris_Cube: + """Convert this array into a iris.cube.Cube""" + from xarray.convert import to_iris + + return to_iris(self) + + @classmethod + def from_iris(cls, cube: iris_Cube) -> Self: + """Convert a iris.cube.Cube into an xarray.DataArray""" + from xarray.convert import from_iris + + return from_iris(cube) + + def _all_compat(self, other: Self, compat_str: str) -> bool: + """Helper function for equals, broadcast_equals, and identical""" + + def compat(x, y): + return getattr(x.variable, compat_str)(y.variable) + + return utils.dict_equiv(self.coords, other.coords, compat=compat) and compat( + self, other + ) + + def broadcast_equals(self, other: Self) -> bool: + """Two DataArrays are broadcast equal if they are equal after + broadcasting them against each other such that they have the same + dimensions. + + Parameters + ---------- + other : DataArray + DataArray to compare to. + + Returns + ---------- + equal : bool + True if the two DataArrays are broadcast equal. + + See Also + -------- + DataArray.equals + DataArray.identical + + Examples + -------- + >>> a = xr.DataArray([1, 2], dims="X") + >>> b = xr.DataArray([[1, 1], [2, 2]], dims=["X", "Y"]) + >>> a + Size: 16B + array([1, 2]) + Dimensions without coordinates: X + >>> b + Size: 32B + array([[1, 1], + [2, 2]]) + Dimensions without coordinates: X, Y + + .equals returns True if two DataArrays have the same values, dimensions, and coordinates. .broadcast_equals returns True if the results of broadcasting two DataArrays against each other have the same values, dimensions, and coordinates. + + >>> a.equals(b) + False + >>> a2, b2 = xr.broadcast(a, b) + >>> a2.equals(b2) + True + >>> a.broadcast_equals(b) + True + """ + try: + return self._all_compat(other, "broadcast_equals") + except (TypeError, AttributeError): + return False + + def equals(self, other: Self) -> bool: + """True if two DataArrays have the same dimensions, coordinates and + values; otherwise False. + + DataArrays can still be equal (like pandas objects) if they have NaN + values in the same locations. + + This method is necessary because `v1 == v2` for ``DataArray`` + does element-wise comparisons (like numpy.ndarrays). + + Parameters + ---------- + other : DataArray + DataArray to compare to. + + Returns + ---------- + equal : bool + True if the two DataArrays are equal. + + See Also + -------- + DataArray.broadcast_equals + DataArray.identical + + Examples + -------- + >>> a = xr.DataArray([1, 2, 3], dims="X") + >>> b = xr.DataArray([1, 2, 3], dims="X", attrs=dict(units="m")) + >>> c = xr.DataArray([1, 2, 3], dims="Y") + >>> d = xr.DataArray([3, 2, 1], dims="X") + >>> a + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: X + >>> b + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: X + Attributes: + units: m + >>> c + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: Y + >>> d + Size: 24B + array([3, 2, 1]) + Dimensions without coordinates: X + + >>> a.equals(b) + True + >>> a.equals(c) + False + >>> a.equals(d) + False + """ + try: + return self._all_compat(other, "equals") + except (TypeError, AttributeError): + return False + + def identical(self, other: Self) -> bool: + """Like equals, but also checks the array name and attributes, and + attributes on all coordinates. + + Parameters + ---------- + other : DataArray + DataArray to compare to. + + Returns + ---------- + equal : bool + True if the two DataArrays are identical. + + See Also + -------- + DataArray.broadcast_equals + DataArray.equals + + Examples + -------- + >>> a = xr.DataArray([1, 2, 3], dims="X", attrs=dict(units="m"), name="Width") + >>> b = xr.DataArray([1, 2, 3], dims="X", attrs=dict(units="m"), name="Width") + >>> c = xr.DataArray([1, 2, 3], dims="X", attrs=dict(units="ft"), name="Width") + >>> a + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: X + Attributes: + units: m + >>> b + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: X + Attributes: + units: m + >>> c + Size: 24B + array([1, 2, 3]) + Dimensions without coordinates: X + Attributes: + units: ft + + >>> a.equals(b) + True + >>> a.identical(b) + True + + >>> a.equals(c) + True + >>> a.identical(c) + False + """ + try: + return self.name == other.name and self._all_compat(other, "identical") + except (TypeError, AttributeError): + return False + + def _result_name(self, other: Any = None) -> Hashable | None: + # use the same naming heuristics as pandas: + # https://github.com/ContinuumIO/blaze/issues/458#issuecomment-51936356 + other_name = getattr(other, "name", _default) + if other_name is _default or other_name == self.name: + return self.name + else: + return None + + def __array_wrap__(self, obj, context=None) -> Self: + new_var = self.variable.__array_wrap__(obj, context) + return self._replace(new_var) + + def __matmul__(self, obj: T_Xarray) -> T_Xarray: + return self.dot(obj) + + def __rmatmul__(self, other: T_Xarray) -> T_Xarray: + # currently somewhat duplicative, as only other DataArrays are + # compatible with matmul + return computation.dot(other, self) + + def _unary_op(self, f: Callable, *args, **kwargs) -> Self: + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + with np.errstate(all="ignore"): + da = self.__array_wrap__(f(self.variable.data, *args, **kwargs)) + if keep_attrs: + da.attrs = self.attrs + return da + + def _binary_op( + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: + from xarray.core.groupby import GroupBy + + if isinstance(other, (Dataset, GroupBy)): + return NotImplemented + if isinstance(other, DataArray): + align_type = OPTIONS["arithmetic_join"] + self, other = align(self, other, join=align_type, copy=False) + other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) + other_coords = getattr(other, "coords", None) + + variable = ( + f(self.variable, other_variable_or_arraylike) + if not reflexive + else f(other_variable_or_arraylike, self.variable) + ) + coords, indexes = self.coords._merge_raw(other_coords, reflexive) + name = self._result_name(other) + + return self._replace(variable, coords, name, indexes=indexes) + + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: + from xarray.core.groupby import GroupBy + + if isinstance(other, GroupBy): + raise TypeError( + "in-place operations between a DataArray and " + "a grouped object are not permitted" + ) + # n.b. we can't align other to self (with other.reindex_like(self)) + # because `other` may be converted into floats, which would cause + # in-place arithmetic to fail unpredictably. Instead, we simply + # don't support automatic alignment with in-place arithmetic. + other_coords = getattr(other, "coords", None) + other_variable = getattr(other, "variable", other) + try: + with self.coords._merge_inplace(other_coords): + f(self.variable, other_variable) + except MergeError as exc: + raise MergeError( + "Automatic alignment is not supported for in-place operations.\n" + "Consider aligning the indices manually or using a not-in-place operation.\n" + "See https://github.com/pydata/xarray/issues/3910 for more explanations." + ) from exc + return self + + def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None: + self.attrs = other.attrs + + plot = utils.UncachedAccessor(DataArrayPlotAccessor) + + def _title_for_slice(self, truncate: int = 50) -> str: + """ + If the dataarray has 1 dimensional coordinates or comes from a slice + we can show that info in the title + + Parameters + ---------- + truncate : int, default: 50 + maximum number of characters for title + + Returns + ------- + title : string + Can be used for plot titles + + """ + one_dims = [] + for dim, coord in self.coords.items(): + if coord.size == 1: + one_dims.append( + f"{dim} = {format_item(coord.values)}{_get_units_from_attrs(coord)}" + ) + + title = ", ".join(one_dims) + if len(title) > truncate: + title = title[: (truncate - 3)] + "..." + + return title + + @_deprecate_positional_args("v2023.10.0") + def diff( + self, + dim: Hashable, + n: int = 1, + *, + label: Literal["upper", "lower"] = "upper", + ) -> Self: + """Calculate the n-th order discrete difference along given axis. + + Parameters + ---------- + dim : Hashable + Dimension over which to calculate the finite difference. + n : int, default: 1 + The number of times values are differenced. + label : {"upper", "lower"}, default: "upper" + The new coordinate in dimension ``dim`` will have the + values of either the minuend's or subtrahend's coordinate + for values 'upper' and 'lower', respectively. + + Returns + ------- + difference : DataArray + The n-th order finite difference of this object. + + Notes + ----- + `n` matches numpy's behavior and is different from pandas' first argument named + `periods`. + + Examples + -------- + >>> arr = xr.DataArray([5, 5, 6, 6], [[1, 2, 3, 4]], ["x"]) + >>> arr.diff("x") + Size: 24B + array([0, 1, 0]) + Coordinates: + * x (x) int64 24B 2 3 4 + >>> arr.diff("x", 2) + Size: 16B + array([ 1, -1]) + Coordinates: + * x (x) int64 16B 3 4 + + See Also + -------- + DataArray.differentiate + """ + ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) + return self._from_temp_dataset(ds) + + def shift( + self, + shifts: Mapping[Any, int] | None = None, + fill_value: Any = dtypes.NA, + **shifts_kwargs: int, + ) -> Self: + """Shift this DataArray by an offset along one or more dimensions. + + Only the data is moved; coordinates stay in place. This is consistent + with the behavior of ``shift`` in pandas. + + Values shifted from beyond array bounds will appear at one end of + each dimension, which are filled according to `fill_value`. For periodic + offsets instead see `roll`. + + Parameters + ---------- + shifts : mapping of Hashable to int or None, optional + Integer offset to shift along each of the given dimensions. + Positive offsets shift to the right; negative offsets shift to the + left. + fill_value : scalar, optional + Value to use for newly missing values + **shifts_kwargs + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. + + Returns + ------- + shifted : DataArray + DataArray with the same coordinates and attributes but shifted + data. + + See Also + -------- + roll + + Examples + -------- + >>> arr = xr.DataArray([5, 6, 7], dims="x") + >>> arr.shift(x=1) + Size: 24B + array([nan, 5., 6.]) + Dimensions without coordinates: x + """ + variable = self.variable.shift( + shifts=shifts, fill_value=fill_value, **shifts_kwargs + ) + return self._replace(variable=variable) + + def roll( + self, + shifts: Mapping[Hashable, int] | None = None, + roll_coords: bool = False, + **shifts_kwargs: int, + ) -> Self: + """Roll this array by an offset along one or more dimensions. + + Unlike shift, roll treats the given dimensions as periodic, so will not + create any missing values to be filled. + + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. + + Parameters + ---------- + shifts : mapping of Hashable to int, optional + Integer offset to rotate each of the given dimensions. + Positive offsets roll to the right; negative offsets roll to the + left. + roll_coords : bool, default: False + Indicates whether to roll the coordinates by the offset too. + **shifts_kwargs : {dim: offset, ...}, optional + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. + + Returns + ------- + rolled : DataArray + DataArray with the same attributes but rolled data and coordinates. + + See Also + -------- + shift + + Examples + -------- + >>> arr = xr.DataArray([5, 6, 7], dims="x") + >>> arr.roll(x=1) + Size: 24B + array([7, 5, 6]) + Dimensions without coordinates: x + """ + ds = self._to_temp_dataset().roll( + shifts=shifts, roll_coords=roll_coords, **shifts_kwargs + ) + return self._from_temp_dataset(ds) + + @property + def real(self) -> Self: + """ + The real part of the array. + + See Also + -------- + numpy.ndarray.real + """ + return self._replace(self.variable.real) + + @property + def imag(self) -> Self: + """ + The imaginary part of the array. + + See Also + -------- + numpy.ndarray.imag + """ + return self._replace(self.variable.imag) + + @deprecate_dims + def dot( + self, + other: T_Xarray, + dim: Dims = None, + ) -> T_Xarray: + """Perform dot product of two DataArrays along their shared dims. + + Equivalent to taking taking tensordot over all shared dims. + + Parameters + ---------- + other : DataArray + The other array with which the dot product is performed. + dim : ..., str, Iterable of Hashable or None, optional + Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions. + If not specified, then all the common dimensions are summed over. + + Returns + ------- + result : DataArray + Array resulting from the dot product over all shared dimensions. + + See Also + -------- + dot + numpy.tensordot + + Examples + -------- + >>> da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + >>> da = xr.DataArray(da_vals, dims=["x", "y", "z"]) + >>> dm_vals = np.arange(4) + >>> dm = xr.DataArray(dm_vals, dims=["z"]) + + >>> dm.dims + ('z',) + + >>> da.dims + ('x', 'y', 'z') + + >>> dot_result = da.dot(dm) + >>> dot_result.dims + ('x', 'y') + + """ + if isinstance(other, Dataset): + raise NotImplementedError( + "dot products are not yet supported with Dataset objects." + ) + if not isinstance(other, DataArray): + raise TypeError("dot only operates on DataArrays.") + + return computation.dot(self, other, dim=dim) + + def sortby( + self, + variables: ( + Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]] + ), + ascending: bool = True, + ) -> Self: + """Sort object by labels or values (along an axis). + + Sorts the dataarray, either along specified dimensions, + or according to values of 1-D dataarrays that share dimension + with calling object. + + If the input variables are dataarrays, then the dataarrays are aligned + (via left-join) to the calling object prior to sorting by cell values. + NaNs are sorted to the end, following Numpy convention. + + If multiple sorts along the same dimension is + given, numpy's lexsort is performed along that dimension: + https://numpy.org/doc/stable/reference/generated/numpy.lexsort.html + and the FIRST key in the sequence is used as the primary sort key, + followed by the 2nd key, etc. + + Parameters + ---------- + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. + ascending : bool, default: True + Whether to sort by ascending or descending order. + + Returns + ------- + sorted : DataArray + A new dataarray where all the specified dims are sorted by dim + labels. + + See Also + -------- + Dataset.sortby + numpy.sort + pandas.sort_values + pandas.sort_index + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(5, 0, -1), + ... coords=[pd.date_range("1/1/2000", periods=5)], + ... dims="time", + ... ) + >>> da + Size: 40B + array([5, 4, 3, 2, 1]) + Coordinates: + * time (time) datetime64[ns] 40B 2000-01-01 2000-01-02 ... 2000-01-05 + + >>> da.sortby(da) + Size: 40B + array([1, 2, 3, 4, 5]) + Coordinates: + * time (time) datetime64[ns] 40B 2000-01-05 2000-01-04 ... 2000-01-01 + + >>> da.sortby(lambda x: x) + Size: 40B + array([1, 2, 3, 4, 5]) + Coordinates: + * time (time) datetime64[ns] 40B 2000-01-05 2000-01-04 ... 2000-01-01 + """ + # We need to convert the callable here rather than pass it through to the + # dataset method, since otherwise the dataset method would try to call the + # callable with the dataset as the object + if callable(variables): + variables = variables(self) + ds = self._to_temp_dataset().sortby(variables, ascending=ascending) + return self._from_temp_dataset(ds) + + @_deprecate_positional_args("v2023.10.0") + def quantile( + self, + q: ArrayLike, + dim: Dims = None, + *, + method: QuantileMethods = "linear", + keep_attrs: bool | None = None, + skipna: bool | None = None, + interpolation: QuantileMethods | None = None, + ) -> Self: + """Compute the qth quantile of the data along the specified dimension. + + Returns the qth quantiles(s) of the array elements. + + Parameters + ---------- + q : float or array-like of float + Quantile to compute, which must be between 0 and 1 inclusive. + dim : str or Iterable of Hashable, optional + Dimension(s) over which to apply quantile. + method : str, default: "linear" + This optional parameter specifies the interpolation method to use when the + desired quantile lies between two data points. The options sorted by their R + type as summarized in the H&F paper [1]_ are: + + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" + 7. "linear" (default) + 8. "median_unbiased" + 9. "normal_unbiased" + + The first three methods are discontiuous. The following discontinuous + variations of the default "linear" (7.) option are also available: + + * "lower" + * "higher" + * "midpoint" + * "nearest" + + See :py:func:`numpy.quantile` or [1]_ for details. The "method" argument + was previously called "interpolation", renamed in accordance with numpy + version 1.22.0. + + keep_attrs : bool or None, optional + If True, the dataset's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + quantiles : DataArray + If `q` is a single quantile, then the result + is a scalar. If multiple percentiles are given, first axis of + the result corresponds to the quantile and a quantile dimension + is added to the return array. The other dimensions are the + dimensions that remain after the reduction of the array. + + See Also + -------- + numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile + + Examples + -------- + >>> da = xr.DataArray( + ... data=[[0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]], + ... coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]}, + ... dims=("x", "y"), + ... ) + >>> da.quantile(0) # or da.quantile(0, dim=...) + Size: 8B + array(0.7) + Coordinates: + quantile float64 8B 0.0 + >>> da.quantile(0, dim="x") + Size: 32B + array([0.7, 4.2, 2.6, 1.5]) + Coordinates: + * y (y) float64 32B 1.0 1.5 2.0 2.5 + quantile float64 8B 0.0 + >>> da.quantile([0, 0.5, 1]) + Size: 24B + array([0.7, 3.4, 9.4]) + Coordinates: + * quantile (quantile) float64 24B 0.0 0.5 1.0 + >>> da.quantile([0, 0.5, 1], dim="x") + Size: 96B + array([[0.7 , 4.2 , 2.6 , 1.5 ], + [3.6 , 5.75, 6. , 1.7 ], + [6.5 , 7.3 , 9.4 , 1.9 ]]) + Coordinates: + * y (y) float64 32B 1.0 1.5 2.0 2.5 + * quantile (quantile) float64 24B 0.0 0.5 1.0 + + References + ---------- + .. [1] R. J. Hyndman and Y. Fan, + "Sample quantiles in statistical packages," + The American Statistician, 50(4), pp. 361-365, 1996 + """ + + ds = self._to_temp_dataset().quantile( + q, + dim=dim, + keep_attrs=keep_attrs, + method=method, + skipna=skipna, + interpolation=interpolation, + ) + return self._from_temp_dataset(ds) + + @_deprecate_positional_args("v2023.10.0") + def rank( + self, + dim: Hashable, + *, + pct: bool = False, + keep_attrs: bool | None = None, + ) -> Self: + """Ranks the data. + + Equal values are assigned a rank that is the average of the ranks that + would have been otherwise assigned to all of the values within that + set. Ranks begin at 1, not 0. If pct, computes percentage ranks. + + NaNs in the input array are returned as NaNs. + + The `bottleneck` library is required. + + Parameters + ---------- + dim : Hashable + Dimension over which to compute rank. + pct : bool, default: False + If True, compute percentage ranks, otherwise compute integer ranks. + keep_attrs : bool or None, optional + If True, the dataset's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + + Returns + ------- + ranked : DataArray + DataArray with the same coordinates and dtype 'float64'. + + Examples + -------- + >>> arr = xr.DataArray([5, 6, 7], dims="x") + >>> arr.rank("x") + Size: 24B + array([1., 2., 3.]) + Dimensions without coordinates: x + """ + + ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs) + return self._from_temp_dataset(ds) + + def differentiate( + self, + coord: Hashable, + edge_order: Literal[1, 2] = 1, + datetime_unit: DatetimeUnitOptions = None, + ) -> Self: + """Differentiate the array with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord : Hashable + The coordinate to be used to compute the gradient. + edge_order : {1, 2}, default: 1 + N-th order accurate differences at the boundaries. + datetime_unit : {"W", "D", "h", "m", "s", "ms", \ + "us", "ns", "ps", "fs", "as", None}, optional + Unit to compute gradient. Only valid for datetime coordinate. "Y" and "M" are not available as + datetime_unit. + + Returns + ------- + differentiated: DataArray + + See also + -------- + numpy.gradient: corresponding numpy function + + Examples + -------- + + >>> da = xr.DataArray( + ... np.arange(12).reshape(4, 3), + ... dims=["x", "y"], + ... coords={"x": [0, 0.1, 1.1, 1.2]}, + ... ) + >>> da + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 32B 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.differentiate("x") + Size: 96B + array([[30. , 30. , 30. ], + [27.54545455, 27.54545455, 27.54545455], + [27.54545455, 27.54545455, 27.54545455], + [30. , 30. , 30. ]]) + Coordinates: + * x (x) float64 32B 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit) + return self._from_temp_dataset(ds) + + def integrate( + self, + coord: Hashable | Sequence[Hashable] = None, + datetime_unit: DatetimeUnitOptions = None, + ) -> Self: + """Integrate along the given coordinate using the trapezoidal rule. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord : Hashable, or sequence of Hashable + Coordinate(s) used for the integration. + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as', None}, optional + Specify the unit if a datetime coordinate is used. + + Returns + ------- + integrated : DataArray + + See also + -------- + Dataset.integrate + numpy.trapz : corresponding numpy function + + Examples + -------- + + >>> da = xr.DataArray( + ... np.arange(12).reshape(4, 3), + ... dims=["x", "y"], + ... coords={"x": [0, 0.1, 1.1, 1.2]}, + ... ) + >>> da + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 32B 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.integrate("x") + Size: 24B + array([5.4, 6.6, 7.8]) + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().integrate(coord, datetime_unit) + return self._from_temp_dataset(ds) + + def cumulative_integrate( + self, + coord: Hashable | Sequence[Hashable] = None, + datetime_unit: DatetimeUnitOptions = None, + ) -> Self: + """Integrate cumulatively along the given coordinate using the trapezoidal rule. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + The first entry of the cumulative integral is always 0, in order to keep the + length of the dimension unchanged between input and output. + + Parameters + ---------- + coord : Hashable, or sequence of Hashable + Coordinate(s) used for the integration. + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as', None}, optional + Specify the unit if a datetime coordinate is used. + + Returns + ------- + integrated : DataArray + + See also + -------- + Dataset.cumulative_integrate + scipy.integrate.cumulative_trapezoid : corresponding scipy function + + Examples + -------- + + >>> da = xr.DataArray( + ... np.arange(12).reshape(4, 3), + ... dims=["x", "y"], + ... coords={"x": [0, 0.1, 1.1, 1.2]}, + ... ) + >>> da + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 32B 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.cumulative_integrate("x") + Size: 96B + array([[0. , 0. , 0. ], + [0.15, 0.25, 0.35], + [4.65, 5.75, 6.85], + [5.4 , 6.6 , 7.8 ]]) + Coordinates: + * x (x) float64 32B 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit) + return self._from_temp_dataset(ds) + + def unify_chunks(self) -> Self: + """Unify chunk size along all chunked dimensions of this DataArray. + + Returns + ------- + DataArray with consistent chunk sizes for all dask-array variables + + See Also + -------- + dask.array.core.unify_chunks + """ + + return unify_chunks(self)[0] + + def map_blocks( + self, + func: Callable[..., T_Xarray], + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] | None = None, + template: DataArray | Dataset | None = None, + ) -> T_Xarray: + """ + Apply a function to each block of this DataArray. + + .. warning:: + This method is experimental and its signature may change. + + Parameters + ---------- + func : callable + User-provided function that accepts a DataArray as its first + parameter. The function will receive a subset or 'block' of this DataArray (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_dataarray, *subset_args, **kwargs)``. + + This function must return either a single DataArray or a single Dataset. + + This function cannot add a new chunked dimension. + args : sequence + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with this object, otherwise an error is raised. + kwargs : mapping + Passed verbatim to func after unpacking. xarray objects, if any, will not be + subset to blocks. Passing dask collections in kwargs is not allowed. + template : DataArray or Dataset, optional + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like this object but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + + Returns + ------- + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. + + Notes + ----- + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. Each block is loaded into memory. In the more common case where + ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. + + If none of the variables in this object is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks + xarray.DataArray.map_blocks + + :doc:`xarray-tutorial:advanced/map_blocks/map_blocks` + Advanced Tutorial on map_blocks with dask + + Examples + -------- + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + ... + >>> time = xr.cftime_range("1990-01", "1992-01", freq="ME") + >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), + ... dims=["time"], + ... coords={"time": time, "month": month}, + ... ).chunk() + >>> array.map_blocks(calculate_anomaly, template=array).compute() + Size: 192B + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B 1 2 3 4 5 6 7 8 9 10 ... 3 4 5 6 7 8 9 10 11 12 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> array.map_blocks( + ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array + ... ) # doctest: +ELLIPSIS + Size: 192B + dask.array<-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> + Coordinates: + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B dask.array + """ + from xarray.core.parallel import map_blocks + + return map_blocks(func, self, args, kwargs, template) + + def polyfit( + self, + dim: Hashable, + deg: int, + skipna: bool | None = None, + rcond: float | None = None, + w: Hashable | Any | None = None, + full: bool = False, + cov: bool | Literal["unscaled"] = False, + ) -> Dataset: + """ + Least squares polynomial fit. + + This replicates the behaviour of `numpy.polyfit` but differs by skipping + invalid values when `skipna = True`. + + Parameters + ---------- + dim : Hashable + Coordinate along which to fit the polynomials. + deg : int + Degree of the fitting polynomial. + skipna : bool or None, optional + If True, removes all invalid values before fitting each 1D slices of the array. + Default is True if data is stored in a dask.array or if there is any + invalid values, False otherwise. + rcond : float or None, optional + Relative condition number to the fit. + w : Hashable, array-like or None, optional + Weights to apply to the y-coordinate of the sample points. + Can be an array-like object or the name of a coordinate in the dataset. + full : bool, default: False + Whether to return the residuals, matrix rank and singular values in addition + to the coefficients. + cov : bool or "unscaled", default: False + Whether to return to the covariance matrix in addition to the coefficients. + The matrix is not scaled if `cov='unscaled'`. + + Returns + ------- + polyfit_results : Dataset + A single dataset which contains: + + polyfit_coefficients + The coefficients of the best fit. + polyfit_residuals + The residuals of the least-square computation (only included if `full=True`). + When the matrix rank is deficient, np.nan is returned. + [dim]_matrix_rank + The effective rank of the scaled Vandermonde coefficient matrix (only included if `full=True`) + [dim]_singular_value + The singular values of the scaled Vandermonde coefficient matrix (only included if `full=True`) + polyfit_covariance + The covariance matrix of the polynomial coefficient estimates (only included if `full=False` and `cov=True`) + + See Also + -------- + numpy.polyfit + numpy.polyval + xarray.polyval + DataArray.curvefit + """ + return self._to_temp_dataset().polyfit( + dim, deg, skipna=skipna, rcond=rcond, w=w, full=full, cov=cov + ) + + def pad( + self, + pad_width: Mapping[Any, int | tuple[int, int]] | None = None, + mode: PadModeOptions = "constant", + stat_length: ( + int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None + ) = None, + constant_values: ( + float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None + ) = None, + end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, + reflect_type: PadReflectOptions = None, + keep_attrs: bool | None = None, + **pad_width_kwargs: Any, + ) -> Self: + """Pad this array along one or more dimensions. + + .. warning:: + This function is experimental and its behaviour is likely to change + especially regarding padding of dimension coordinates (or IndexVariables). + + When using one of the modes ("edge", "reflect", "symmetric", "wrap"), + coordinates will be padded with the same mode, otherwise coordinates + are padded using the "constant" mode with fill_value dtypes.NA. + + Parameters + ---------- + pad_width : mapping of Hashable to tuple of int + Mapping with the form of {dim: (pad_before, pad_after)} + describing the number of values padded along each dimension. + {dim: pad} is a shortcut for pad_before = pad_after = pad + mode : {"constant", "edge", "linear_ramp", "maximum", "mean", "median", \ + "minimum", "reflect", "symmetric", "wrap"}, default: "constant" + How to pad the DataArray (taken from numpy docs): + + - "constant": Pads with a constant value. + - "edge": Pads with the edge values of array. + - "linear_ramp": Pads with the linear ramp between end_value and the + array edge value. + - "maximum": Pads with the maximum value of all or part of the + vector along each axis. + - "mean": Pads with the mean value of all or part of the + vector along each axis. + - "median": Pads with the median value of all or part of the + vector along each axis. + - "minimum": Pads with the minimum value of all or part of the + vector along each axis. + - "reflect": Pads with the reflection of the vector mirrored on + the first and last values of the vector along each axis. + - "symmetric": Pads with the reflection of the vector mirrored + along the edge of the array. + - "wrap": Pads with the wrap of the vector along the axis. + The first values are used to pad the end and the + end values are used to pad the beginning. + + stat_length : int, tuple or mapping of Hashable to tuple, default: None + Used in 'maximum', 'mean', 'median', and 'minimum'. Number of + values at edge of each axis used to calculate the statistic value. + {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)} unique + statistic lengths along each dimension. + ((before, after),) yields same before and after statistic lengths + for each dimension. + (stat_length,) or int is a shortcut for before = after = statistic + length for all axes. + Default is ``None``, to use the entire axis. + constant_values : scalar, tuple or mapping of Hashable to tuple, default: 0 + Used in 'constant'. The values to set the padded values for each + axis. + ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique + pad constants along each dimension. + ``((before, after),)`` yields same before and after constants for each + dimension. + ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for + all dimensions. + Default is 0. + end_values : scalar, tuple or mapping of Hashable to tuple, default: 0 + Used in 'linear_ramp'. The values used for the ending value of the + linear_ramp and that will form the edge of the padded array. + ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique + end values along each dimension. + ``((before, after),)`` yields same before and after end values for each + axis. + ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for + all axes. + Default is 0. + reflect_type : {"even", "odd", None}, optional + Used in "reflect", and "symmetric". The "even" style is the + default with an unaltered reflection around the edge value. For + the "odd" style, the extended part of the array is created by + subtracting the reflected values from two times the edge value. + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the + original object to the new one. If False, the new object + will be returned without attributes. + **pad_width_kwargs + The keyword arguments form of ``pad_width``. + One of ``pad_width`` or ``pad_width_kwargs`` must be provided. + + Returns + ------- + padded : DataArray + DataArray with the padded coordinates and data. + + See Also + -------- + DataArray.shift, DataArray.roll, DataArray.bfill, DataArray.ffill, numpy.pad, dask.array.pad + + Notes + ----- + For ``mode="constant"`` and ``constant_values=None``, integer types will be + promoted to ``float`` and padded with ``np.nan``. + + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + + Examples + -------- + >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) + >>> arr.pad(x=(1, 2), constant_values=0) + Size: 48B + array([0, 5, 6, 7, 0, 0]) + Coordinates: + * x (x) float64 48B nan 0.0 1.0 2.0 nan nan + + >>> da = xr.DataArray( + ... [[0, 1, 2, 3], [10, 11, 12, 13]], + ... dims=["x", "y"], + ... coords={"x": [0, 1], "y": [10, 20, 30, 40], "z": ("x", [100, 200])}, + ... ) + >>> da.pad(x=1) + Size: 128B + array([[nan, nan, nan, nan], + [ 0., 1., 2., 3.], + [10., 11., 12., 13.], + [nan, nan, nan, nan]]) + Coordinates: + * x (x) float64 32B nan 0.0 1.0 nan + * y (y) int64 32B 10 20 30 40 + z (x) float64 32B nan 100.0 200.0 nan + + Careful, ``constant_values`` are coerced to the data type of the array which may + lead to a loss of precision: + + >>> da.pad(x=1, constant_values=1.23456789) + Size: 128B + array([[ 1, 1, 1, 1], + [ 0, 1, 2, 3], + [10, 11, 12, 13], + [ 1, 1, 1, 1]]) + Coordinates: + * x (x) float64 32B nan 0.0 1.0 nan + * y (y) int64 32B 10 20 30 40 + z (x) float64 32B nan 100.0 200.0 nan + """ + ds = self._to_temp_dataset().pad( + pad_width=pad_width, + mode=mode, + stat_length=stat_length, + constant_values=constant_values, + end_values=end_values, + reflect_type=reflect_type, + keep_attrs=keep_attrs, + **pad_width_kwargs, + ) + return self._from_temp_dataset(ds) + + @_deprecate_positional_args("v2023.10.0") + def idxmin( + self, + dim: Hashable | None = None, + *, + skipna: bool | None = None, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + ) -> Self: + """Return the coordinate label of the minimum value along a dimension. + + Returns a new `DataArray` named after the dimension with the values of + the coordinate labels along that dimension corresponding to minimum + values along that dimension. + + In comparison to :py:meth:`~DataArray.argmin`, this returns the + coordinate label while :py:meth:`~DataArray.argmin` returns the index. + + Parameters + ---------- + dim : str, optional + Dimension over which to apply `idxmin`. This is optional for 1D + arrays, but required for arrays with 2 or more dimensions. + skipna : bool or None, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for ``float``, ``complex``, and ``object`` + dtypes; other dtypes either do not have a sentinel missing value + (``int``) or ``skipna=True`` has not been implemented + (``datetime64`` or ``timedelta64``). + fill_value : Any, default: NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if ``skipna`` is False. + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the + original object to the new one. If False, the new object + will be returned without attributes. + + Returns + ------- + reduced : DataArray + New `DataArray` object with `idxmin` applied to its data and the + indicated dimension removed. + + See Also + -------- + Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin + + Examples + -------- + >>> array = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) + >>> array.min() + Size: 8B + array(-2) + >>> array.argmin(...) + {'x': Size: 8B + array(4)} + >>> array.idxmin() + Size: 4B + array('e', dtype='>> array = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, + ... ) + >>> array.min(dim="x") + Size: 24B + array([-2., -4., 1.]) + Coordinates: + * y (y) int64 24B -1 0 1 + >>> array.argmin(dim="x") + Size: 24B + array([4, 0, 2]) + Coordinates: + * y (y) int64 24B -1 0 1 + >>> array.idxmin(dim="x") + Size: 24B + array([16., 0., 4.]) + Coordinates: + * y (y) int64 24B -1 0 1 + """ + return computation._calc_idxminmax( + array=self, + func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs), + dim=dim, + skipna=skipna, + fill_value=fill_value, + keep_attrs=keep_attrs, + ) + + @_deprecate_positional_args("v2023.10.0") + def idxmax( + self, + dim: Hashable = None, + *, + skipna: bool | None = None, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + ) -> Self: + """Return the coordinate label of the maximum value along a dimension. + + Returns a new `DataArray` named after the dimension with the values of + the coordinate labels along that dimension corresponding to maximum + values along that dimension. + + In comparison to :py:meth:`~DataArray.argmax`, this returns the + coordinate label while :py:meth:`~DataArray.argmax` returns the index. + + Parameters + ---------- + dim : Hashable, optional + Dimension over which to apply `idxmax`. This is optional for 1D + arrays, but required for arrays with 2 or more dimensions. + skipna : bool or None, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for ``float``, ``complex``, and ``object`` + dtypes; other dtypes either do not have a sentinel missing value + (``int``) or ``skipna=True`` has not been implemented + (``datetime64`` or ``timedelta64``). + fill_value : Any, default: NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if ``skipna`` is False. + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the + original object to the new one. If False, the new object + will be returned without attributes. + + Returns + ------- + reduced : DataArray + New `DataArray` object with `idxmax` applied to its data and the + indicated dimension removed. + + See Also + -------- + Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax + + Examples + -------- + >>> array = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) + >>> array.max() + Size: 8B + array(2) + >>> array.argmax(...) + {'x': Size: 8B + array(1)} + >>> array.idxmax() + Size: 4B + array('b', dtype='>> array = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, + ... ) + >>> array.max(dim="x") + Size: 24B + array([2., 2., 1.]) + Coordinates: + * y (y) int64 24B -1 0 1 + >>> array.argmax(dim="x") + Size: 24B + array([0, 2, 2]) + Coordinates: + * y (y) int64 24B -1 0 1 + >>> array.idxmax(dim="x") + Size: 24B + array([0., 4., 4.]) + Coordinates: + * y (y) int64 24B -1 0 1 + """ + return computation._calc_idxminmax( + array=self, + func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs), + dim=dim, + skipna=skipna, + fill_value=fill_value, + keep_attrs=keep_attrs, + ) + + @_deprecate_positional_args("v2023.10.0") + def argmin( + self, + dim: Dims = None, + *, + axis: int | None = None, + keep_attrs: bool | None = None, + skipna: bool | None = None, + ) -> Self | dict[Hashable, Self]: + """Index or indices of the minimum of the DataArray over one or more dimensions. + + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a DataArray with dtype int. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : "...", str, Iterable of Hashable or None, optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int or None, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool or None, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : DataArray or dict of DataArray + + See Also + -------- + Variable.argmin, DataArray.idxmin + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.min() + Size: 8B + array(-1) + >>> array.argmin(...) + {'x': Size: 8B + array(2)} + >>> array.isel(array.argmin(...)) + Size: 8B + array(-1) + + >>> array = xr.DataArray( + ... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, -5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z"), + ... ) + >>> array.min(dim="x") + Size: 72B + array([[ 1, 2, 1], + [ 2, -5, 1], + [ 2, 1, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim="x") + Size: 72B + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim=["x"]) + {'x': Size: 72B + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z} + >>> array.min(dim=("x", "z")) + Size: 24B + array([ 1, -5, 1]) + Dimensions without coordinates: y + >>> array.argmin(dim=["x", "z"]) + {'x': Size: 24B + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': Size: 24B + array([2, 1, 1]) + Dimensions without coordinates: y} + >>> array.isel(array.argmin(dim=["x", "z"])) + Size: 24B + array([ 1, -5, 1]) + Dimensions without coordinates: y + """ + result = self.variable.argmin(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) + + @_deprecate_positional_args("v2023.10.0") + def argmax( + self, + dim: Dims = None, + *, + axis: int | None = None, + keep_attrs: bool | None = None, + skipna: bool | None = None, + ) -> Self | dict[Hashable, Self]: + """Index or indices of the maximum of the DataArray over one or more dimensions. + + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a DataArray with dtype int. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : "...", str, Iterable of Hashable or None, optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int or None, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool or None, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : DataArray or dict of DataArray + + See Also + -------- + Variable.argmax, DataArray.idxmax + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.max() + Size: 8B + array(3) + >>> array.argmax(...) + {'x': Size: 8B + array(3)} + >>> array.isel(array.argmax(...)) + Size: 8B + array(3) + + >>> array = xr.DataArray( + ... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, 5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z"), + ... ) + >>> array.max(dim="x") + Size: 72B + array([[3, 3, 2], + [3, 5, 2], + [2, 3, 3]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim="x") + Size: 72B + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim=["x"]) + {'x': Size: 72B + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z} + >>> array.max(dim=("x", "z")) + Size: 24B + array([3, 5, 3]) + Dimensions without coordinates: y + >>> array.argmax(dim=["x", "z"]) + {'x': Size: 24B + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': Size: 24B + array([0, 1, 2]) + Dimensions without coordinates: y} + >>> array.isel(array.argmax(dim=["x", "z"])) + Size: 24B + array([3, 5, 3]) + Dimensions without coordinates: y + """ + result = self.variable.argmax(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) + + def query( + self, + queries: Mapping[Any, Any] | None = None, + parser: QueryParserOptions = "pandas", + engine: QueryEngineOptions = None, + missing_dims: ErrorOptionsWithWarn = "raise", + **queries_kwargs: Any, + ) -> DataArray: + """Return a new data array indexed along the specified + dimension(s), where the indexers are given as strings containing + Python expressions to be evaluated against the values in the array. + + Parameters + ---------- + queries : dict-like or None, optional + A dict-like with keys matching dimensions and values given by strings + containing Python expressions to be evaluated against the data variables + in the dataset. The expressions will be evaluated using the pandas + eval() function, and can contain any valid Python expressions but cannot + contain any Python statements. + parser : {"pandas", "python"}, default: "pandas" + The parser to use to construct the syntax tree from the expression. + The default of 'pandas' parses code slightly different than standard + Python. Alternatively, you can parse an expression using the 'python' + parser to retain strict Python semantics. + engine : {"python", "numexpr", None}, default: None + The engine used to evaluate the expression. Supported engines are: + + - None: tries to use numexpr, falls back to python + - "numexpr": evaluates expressions using numexpr + - "python": performs operations as if you had eval’d in top level python + + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + DataArray: + + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + **queries_kwargs : {dim: query, ...}, optional + The keyword arguments form of ``queries``. + One of queries or queries_kwargs must be provided. + + Returns + ------- + obj : DataArray + A new DataArray with the same contents as this dataset, indexed by + the results of the appropriate queries. + + See Also + -------- + DataArray.isel + Dataset.query + pandas.eval + + Examples + -------- + >>> da = xr.DataArray(np.arange(0, 5, 1), dims="x", name="a") + >>> da + Size: 40B + array([0, 1, 2, 3, 4]) + Dimensions without coordinates: x + >>> da.query(x="a > 2") + Size: 16B + array([3, 4]) + Dimensions without coordinates: x + """ + + ds = self._to_dataset_whole(shallow_copy=True) + ds = ds.query( + queries=queries, + parser=parser, + engine=engine, + missing_dims=missing_dims, + **queries_kwargs, + ) + return ds[self.name] + + def curvefit( + self, + coords: str | DataArray | Iterable[str | DataArray], + func: Callable[..., Any], + reduce_dims: Dims = None, + skipna: bool = True, + p0: Mapping[str, float | DataArray] | None = None, + bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None, + param_names: Sequence[str] | None = None, + errors: ErrorOptions = "raise", + kwargs: dict[str, Any] | None = None, + ) -> Dataset: + """ + Curve fitting optimization for arbitrary functions. + + Wraps `scipy.optimize.curve_fit` with `apply_ufunc`. + + Parameters + ---------- + coords : Hashable, DataArray, or sequence of DataArray or Hashable + Independent coordinate(s) over which to perform the curve fitting. Must share + at least one dimension with the calling object. When fitting multi-dimensional + functions, supply `coords` as a sequence in the same order as arguments in + `func`. To fit along existing dimensions of the calling object, `coords` can + also be specified as a str or sequence of strs. + func : callable + User specified function in the form `f(x, *params)` which returns a numpy + array of length `len(x)`. `params` are the fittable parameters which are optimized + by scipy curve_fit. `x` can also be specified as a sequence containing multiple + coordinates, e.g. `f((x0, x1), *params)`. + reduce_dims : str, Iterable of Hashable or None, optional + Additional dimension(s) over which to aggregate while fitting. For example, + calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will + aggregate all lat and lon points and fit the specified function along the + time dimension. + skipna : bool, default: True + Whether to skip missing values when fitting. Default is True. + p0 : dict-like or None, optional + Optional dictionary of parameter names to initial guesses passed to the + `curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately + broadcast to the coordinates of the array. If none or only some parameters are + passed, the rest will be assigned initial values following the default scipy + behavior. + bounds : dict-like, optional + Optional dictionary of parameter names to tuples of bounding values passed to the + `curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be + appropriately broadcast to the coordinates of the array. If none or only some + parameters are passed, the rest will be unbounded following the default scipy + behavior. + param_names : sequence of Hashable or None, optional + Sequence of names for the fittable parameters of `func`. If not supplied, + this will be automatically determined by arguments of `func`. `param_names` + should be manually supplied when fitting a function that takes a variable + number of parameters. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. + **kwargs : optional + Additional keyword arguments to passed to scipy curve_fit. + + Returns + ------- + curvefit_results : Dataset + A single dataset which contains: + + [var]_curvefit_coefficients + The coefficients of the best fit. + [var]_curvefit_covariance + The covariance matrix of the coefficient estimates. + + Examples + -------- + Generate some exponentially decaying data, where the decay constant and amplitude are + different for different values of the coordinate ``x``: + + >>> rng = np.random.default_rng(seed=0) + >>> def exp_decay(t, time_constant, amplitude): + ... return np.exp(-t / time_constant) * amplitude + ... + >>> t = np.arange(11) + >>> da = xr.DataArray( + ... np.stack( + ... [ + ... exp_decay(t, 1, 0.1), + ... exp_decay(t, 2, 0.2), + ... exp_decay(t, 3, 0.3), + ... ] + ... ) + ... + rng.normal(size=(3, t.size)) * 0.01, + ... coords={"x": [0, 1, 2], "time": t}, + ... ) + >>> da + Size: 264B + array([[ 0.1012573 , 0.0354669 , 0.01993775, 0.00602771, -0.00352513, + 0.00428975, 0.01328788, 0.009562 , -0.00700381, -0.01264187, + -0.0062282 ], + [ 0.20041326, 0.09805582, 0.07138797, 0.03216692, 0.01974438, + 0.01097441, 0.00679441, 0.01015578, 0.01408826, 0.00093645, + 0.01501222], + [ 0.29334805, 0.21847449, 0.16305984, 0.11130396, 0.07164415, + 0.04744543, 0.03602333, 0.03129354, 0.01074885, 0.01284436, + 0.00910995]]) + Coordinates: + * x (x) int64 24B 0 1 2 + * time (time) int64 88B 0 1 2 3 4 5 6 7 8 9 10 + + Fit the exponential decay function to the data along the ``time`` dimension: + + >>> fit_result = da.curvefit("time", exp_decay) + >>> fit_result["curvefit_coefficients"].sel( + ... param="time_constant" + ... ) # doctest: +NUMBER + Size: 24B + array([1.05692036, 1.73549638, 2.94215771]) + Coordinates: + * x (x) int64 24B 0 1 2 + param >> fit_result["curvefit_coefficients"].sel(param="amplitude") + Size: 24B + array([0.1005489 , 0.19631423, 0.30003579]) + Coordinates: + * x (x) int64 24B 0 1 2 + param >> fit_result = da.curvefit( + ... "time", + ... exp_decay, + ... p0={ + ... "amplitude": 0.2, + ... "time_constant": xr.DataArray([1, 2, 3], coords=[da.x]), + ... }, + ... ) + >>> fit_result["curvefit_coefficients"].sel(param="time_constant") + Size: 24B + array([1.0569213 , 1.73550052, 2.94215733]) + Coordinates: + * x (x) int64 24B 0 1 2 + param >> fit_result["curvefit_coefficients"].sel(param="amplitude") + Size: 24B + array([0.10054889, 0.1963141 , 0.3000358 ]) + Coordinates: + * x (x) int64 24B 0 1 2 + param Self: + """Returns a new DataArray with duplicate dimension values removed. + + Parameters + ---------- + dim : dimension label or labels + Pass `...` to drop duplicates along all dimensions. + keep : {"first", "last", False}, default: "first" + Determines which duplicates (if any) to keep. + + - ``"first"`` : Drop duplicates except for the first occurrence. + - ``"last"`` : Drop duplicates except for the last occurrence. + - False : Drop all duplicates. + + Returns + ------- + DataArray + + See Also + -------- + Dataset.drop_duplicates + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(25).reshape(5, 5), + ... dims=("x", "y"), + ... coords={"x": np.array([0, 0, 1, 2, 3]), "y": np.array([0, 1, 2, 3, 3])}, + ... ) + >>> da + Size: 200B + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + * x (x) int64 40B 0 0 1 2 3 + * y (y) int64 40B 0 1 2 3 3 + + >>> da.drop_duplicates(dim="x") + Size: 160B + array([[ 0, 1, 2, 3, 4], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + * x (x) int64 32B 0 1 2 3 + * y (y) int64 40B 0 1 2 3 3 + + >>> da.drop_duplicates(dim="x", keep="last") + Size: 160B + array([[ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + * x (x) int64 32B 0 1 2 3 + * y (y) int64 40B 0 1 2 3 3 + + Drop all duplicate dimension values: + + >>> da.drop_duplicates(dim=...) + Size: 128B + array([[ 0, 1, 2, 3], + [10, 11, 12, 13], + [15, 16, 17, 18], + [20, 21, 22, 23]]) + Coordinates: + * x (x) int64 32B 0 1 2 3 + * y (y) int64 32B 0 1 2 3 + """ + deduplicated = self._to_temp_dataset().drop_duplicates(dim, keep=keep) + return self._from_temp_dataset(deduplicated) + + def convert_calendar( + self, + calendar: str, + dim: str = "time", + align_on: str | None = None, + missing: Any | None = None, + use_cftime: bool | None = None, + ) -> Self: + """Convert the DataArray to another calendar. + + Only converts the individual timestamps, does not modify any data except + in dropping invalid/surplus dates or inserting missing dates. + + If the source and target calendars are either no_leap, all_leap or a + standard type, only the type of the time array is modified. + When converting to a leap year from a non-leap year, the 29th of February + is removed from the array. In the other direction the 29th of February + will be missing in the output, unless `missing` is specified, + in which case that value is inserted. + + For conversions involving `360_day` calendars, see Notes. + + This method is safe to use with sub-daily data as it doesn't touch the + time part of the timestamps. + + Parameters + --------- + calendar : str + The target calendar name. + dim : str + Name of the time coordinate. + align_on : {None, 'date', 'year'} + Must be specified when either source or target is a `360_day` calendar, + ignored otherwise. See Notes. + missing : Optional[any] + By default, i.e. if the value is None, this method will simply attempt + to convert the dates in the source calendar to the same dates in the + target calendar, and drop any of those that are not possible to + represent. If a value is provided, a new time coordinate will be + created in the target calendar with the same frequency as the original + time coordinate; for any dates that are not present in the source, the + data will be filled with this value. Note that using this mode requires + that the source data have an inferable frequency; for more information + see :py:func:`xarray.infer_freq`. For certain frequency, source, and + target calendar combinations, this could result in many missing values, see notes. + use_cftime : boolean, optional + Whether to use cftime objects in the output, only used if `calendar` + is one of {"proleptic_gregorian", "gregorian" or "standard"}. + If True, the new time axis uses cftime objects. + If None (default), it uses :py:class:`numpy.datetime64` values if the + date range permits it, and :py:class:`cftime.datetime` objects if not. + If False, it uses :py:class:`numpy.datetime64` or fails. + + Returns + ------- + DataArray + Copy of the dataarray with the time coordinate converted to the + target calendar. If 'missing' was None (default), invalid dates in + the new calendar are dropped, but missing dates are not inserted. + If `missing` was given, the new data is reindexed to have a time axis + with the same frequency as the source, but in the new calendar; any + missing datapoints are filled with `missing`. + + Notes + ----- + Passing a value to `missing` is only usable if the source's time coordinate as an + inferable frequencies (see :py:func:`~xarray.infer_freq`) and is only appropriate + if the target coordinate, generated from this frequency, has dates equivalent to the + source. It is usually **not** appropriate to use this mode with: + + - Period-end frequencies : 'A', 'Y', 'Q' or 'M', in opposition to 'AS' 'YS', 'QS' and 'MS' + - Sub-monthly frequencies that do not divide a day evenly : 'W', 'nD' where `N != 1` + or 'mH' where 24 % m != 0). + + If one of the source or target calendars is `"360_day"`, `align_on` must + be specified and two options are offered. + + - "year" + The dates are translated according to their relative position in the year, + ignoring their original month and day information, meaning that the + missing/surplus days are added/removed at regular intervals. + + From a `360_day` to a standard calendar, the output will be missing the + following dates (day of year in parentheses): + + To a leap year: + January 31st (31), March 31st (91), June 1st (153), July 31st (213), + September 31st (275) and November 30th (335). + To a non-leap year: + February 6th (36), April 19th (109), July 2nd (183), + September 12th (255), November 25th (329). + + From a standard calendar to a `"360_day"`, the following dates in the + source array will be dropped: + + From a leap year: + January 31st (31), April 1st (92), June 1st (153), August 1st (214), + September 31st (275), December 1st (336) + From a non-leap year: + February 6th (37), April 20th (110), July 2nd (183), + September 13th (256), November 25th (329) + + This option is best used on daily and subdaily data. + + - "date" + The month/day information is conserved and invalid dates are dropped + from the output. This means that when converting from a `"360_day"` to a + standard calendar, all 31st (Jan, March, May, July, August, October and + December) will be missing as there is no equivalent dates in the + `"360_day"` calendar and the 29th (on non-leap years) and 30th of February + will be dropped as there are no equivalent dates in a standard calendar. + + This option is best used with data on a frequency coarser than daily. + """ + return convert_calendar( + self, + calendar, + dim=dim, + align_on=align_on, + missing=missing, + use_cftime=use_cftime, + ) + + def interp_calendar( + self, + target: pd.DatetimeIndex | CFTimeIndex | DataArray, + dim: str = "time", + ) -> Self: + """Interpolates the DataArray to another calendar based on decimal year measure. + + Each timestamp in `source` and `target` are first converted to their decimal + year equivalent then `source` is interpolated on the target coordinate. + The decimal year of a timestamp is its year plus its sub-year component + converted to the fraction of its year. For example "2000-03-01 12:00" is + 2000.1653 in a standard calendar or 2000.16301 in a `"noleap"` calendar. + + This method should only be used when the time (HH:MM:SS) information of + time coordinate is not important. + + Parameters + ---------- + target: DataArray or DatetimeIndex or CFTimeIndex + The target time coordinate of a valid dtype + (np.datetime64 or cftime objects) + dim : str + The time coordinate name. + + Return + ------ + DataArray + The source interpolated on the decimal years of target, + """ + return interp_calendar(self, target, dim=dim) + + def groupby( + self, + group: Hashable | DataArray | IndexVariable, + squeeze: bool | None = None, + restore_coord_dims: bool = False, + ) -> DataArrayGroupBy: + """Returns a DataArrayGroupBy object for performing grouped operations. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + Hashable, must be the name of a coordinate contained in this dataarray. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DataArrayGroupBy + A `DataArrayGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + + Examples + -------- + Calculate daily anomalies for daily data: + + >>> da = xr.DataArray( + ... np.linspace(0, 1826, num=1827), + ... coords=[pd.date_range("2000-01-01", "2004-12-31", freq="D")], + ... dims="time", + ... ) + >>> da + Size: 15kB + array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, + 1.826e+03]) + Coordinates: + * time (time) datetime64[ns] 15kB 2000-01-01 2000-01-02 ... 2004-12-31 + >>> da.groupby("time.dayofyear") - da.groupby("time.dayofyear").mean("time") + Size: 15kB + array([-730.8, -730.8, -730.8, ..., 730.2, 730.2, 730.5]) + Coordinates: + * time (time) datetime64[ns] 15kB 2000-01-01 2000-01-02 ... 2004-12-31 + dayofyear (time) int64 15kB 1 2 3 4 5 6 7 8 ... 360 361 362 363 364 365 366 + + See Also + -------- + :ref:`groupby` + Users guide explanation of how to group and bin data. + + :doc:`xarray-tutorial:intermediate/01-high-level-computation-patterns` + Tutorial on :py:func:`~xarray.DataArray.Groupby` for windowed computation + + :doc:`xarray-tutorial:fundamentals/03.2_groupby_with_xarray` + Tutorial on :py:func:`~xarray.DataArray.Groupby` demonstrating reductions, transformation and comparison with :py:func:`~xarray.DataArray.resample` + + DataArray.groupby_bins + Dataset.groupby + core.groupby.DataArrayGroupBy + DataArray.coarsen + pandas.DataFrame.groupby + Dataset.resample + DataArray.resample + """ + from xarray.core.groupby import ( + DataArrayGroupBy, + ResolvedGrouper, + UniqueGrouper, + _validate_groupby_squeeze, + ) + + _validate_groupby_squeeze(squeeze) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + return DataArrayGroupBy( + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def groupby_bins( + self, + group: Hashable | DataArray | IndexVariable, + bins: ArrayLike, + right: bool = True, + labels: ArrayLike | Literal[False] | None = None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool | None = None, + restore_coord_dims: bool = False, + ) -> DataArrayGroupBy: + """Returns a DataArrayGroupBy object for performing grouped operations. + + Rather than using all unique values of `group`, the values are discretized + first by applying `pandas.cut` [1]_ to `group`. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose binned values should be used to group this array. If a + Hashable, must be the name of a coordinate contained in this dataarray. + bins : int or array-like + If bins is an int, it defines the number of equal-width bins in the + range of x. However, in this case, the range of x is extended by .1% + on each side to include the min or max values of x. If bins is a + sequence it defines the bin edges allowing for non-uniform bin + width. No extension of the range of x is done in this case. + right : bool, default: True + Indicates whether the bins include the rightmost edge or not. If + right == True (the default), then the bins [1,2,3,4] indicate + (1,2], (2,3], (3,4]. + labels : array-like, False or None, default: None + Used as labels for the resulting bins. Must be of the same length as + the resulting bins. If False, string bin labels are assigned by + `pandas.cut`. + precision : int, default: 3 + The precision at which to store and display the bins labels. + include_lowest : bool, default: False + Whether the first interval should be left-inclusive or not. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DataArrayGroupBy + A `DataArrayGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + The name of the group has the added suffix `_bins` in order to + distinguish it from the original variable. + + See Also + -------- + :ref:`groupby` + Users guide explanation of how to group and bin data. + DataArray.groupby + Dataset.groupby_bins + core.groupby.DataArrayGroupBy + pandas.DataFrame.groupby + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html + """ + from xarray.core.groupby import ( + BinGrouper, + DataArrayGroupBy, + ResolvedGrouper, + _validate_groupby_squeeze, + ) + + _validate_groupby_squeeze(squeeze) + grouper = BinGrouper( + bins=bins, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + rgrouper = ResolvedGrouper(grouper, group, self) + + return DataArrayGroupBy( + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def weighted(self, weights: DataArray) -> DataArrayWeighted: + """ + Weighted DataArray operations. + + Parameters + ---------- + weights : DataArray + An array of weights associated with the values in this Dataset. + Each value in the data contributes to the reduction operation + according to its associated weight. + + Notes + ----- + ``weights`` must be a DataArray and cannot contain missing values. + Missing values can be replaced by ``weights.fillna(0)``. + + Returns + ------- + core.weighted.DataArrayWeighted + + See Also + -------- + Dataset.weighted + + :ref:`comput.weighted` + User guide on weighted array reduction using :py:func:`~xarray.DataArray.weighted` + + :doc:`xarray-tutorial:fundamentals/03.4_weighted` + Tutorial on Weighted Reduction using :py:func:`~xarray.DataArray.weighted` + + """ + from xarray.core.weighted import DataArrayWeighted + + return DataArrayWeighted(self, weights) + + def rolling( + self, + dim: Mapping[Any, int] | None = None, + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + **window_kwargs: int, + ) -> DataArrayRolling: + """ + Rolling window object for DataArrays. + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int or None, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or Mapping to int, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + core.rolling.DataArrayRolling + + Examples + -------- + Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON: + + >>> da = xr.DataArray( + ... np.linspace(0, 11, num=12), + ... coords=[ + ... pd.date_range( + ... "1999-12-15", + ... periods=12, + ... freq=pd.DateOffset(months=1), + ... ) + ... ], + ... dims="time", + ... ) + >>> da + Size: 96B + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + >>> da.rolling(time=3, center=True).mean() + Size: 96B + array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + + Remove the NaNs using ``dropna()``: + + >>> da.rolling(time=3, center=True).mean().dropna("time") + Size: 80B + array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + Coordinates: + * time (time) datetime64[ns] 80B 2000-01-15 2000-02-15 ... 2000-10-15 + + See Also + -------- + DataArray.cumulative + Dataset.rolling + core.rolling.DataArrayRolling + """ + from xarray.core.rolling import DataArrayRolling + + dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") + return DataArrayRolling(self, dim, min_periods=min_periods, center=center) + + def cumulative( + self, + dim: str | Iterable[Hashable], + min_periods: int = 1, + ) -> DataArrayRolling: + """ + Accumulating object for DataArrays. + + Parameters + ---------- + dims : iterable of hashable + The name(s) of the dimensions to create the cumulative window along + min_periods : int, default: 1 + Minimum number of observations in window required to have a value + (otherwise result is NA). The default is 1 (note this is different + from ``Rolling``, whose default is the size of the window). + + Returns + ------- + core.rolling.DataArrayRolling + + Examples + -------- + Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON: + + >>> da = xr.DataArray( + ... np.linspace(0, 11, num=12), + ... coords=[ + ... pd.date_range( + ... "1999-12-15", + ... periods=12, + ... freq=pd.DateOffset(months=1), + ... ) + ... ], + ... dims="time", + ... ) + + >>> da + Size: 96B + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + + >>> da.cumulative("time").sum() + Size: 96B + array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45., 55., 66.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + + See Also + -------- + DataArray.rolling + Dataset.cumulative + core.rolling.DataArrayRolling + """ + from xarray.core.rolling import DataArrayRolling + + # Could we abstract this "normalize and check 'dim'" logic? It's currently shared + # with the same method in Dataset. + if isinstance(dim, str): + if dim not in self.dims: + raise ValueError( + f"Dimension {dim} not found in data dimensions: {self.dims}" + ) + dim = {dim: self.sizes[dim]} + else: + missing_dims = set(dim) - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {missing_dims} not found in data dimensions: {self.dims}" + ) + dim = {d: self.sizes[d] for d in dim} + + return DataArrayRolling(self, dim, min_periods=min_periods, center=False) + + def coarsen( + self, + dim: Mapping[Any, int] | None = None, + boundary: CoarsenBoundaryOptions = "exact", + side: SideOptions | Mapping[Any, SideOptions] = "left", + coord_func: str | Callable | Mapping[Any, str | Callable] = "mean", + **window_kwargs: int, + ) -> DataArrayCoarsen: + """ + Coarsen object for DataArrays. + + Parameters + ---------- + dim : mapping of hashable to int, optional + Mapping from the dimension name to the window size. + boundary : {"exact", "trim", "pad"}, default: "exact" + If 'exact', a ValueError will be raised if dimension size is not a + multiple of the window size. If 'trim', the excess entries are + dropped. If 'pad', NA will be padded. + side : {"left", "right"} or mapping of str to {"left", "right"}, default: "left" + coord_func : str or mapping of hashable to str, default: "mean" + function (name) that is applied to the coordinates, + or a mapping from coordinate name to function (name). + + Returns + ------- + core.rolling.DataArrayCoarsen + + Examples + -------- + Coarsen the long time series by averaging over every three days. + + >>> da = xr.DataArray( + ... np.linspace(0, 364, num=364), + ... dims="time", + ... coords={"time": pd.date_range("1999-12-15", periods=364)}, + ... ) + >>> da # +doctest: ELLIPSIS + Size: 3kB + array([ 0. , 1.00275482, 2.00550964, 3.00826446, + 4.01101928, 5.0137741 , 6.01652893, 7.01928375, + 8.02203857, 9.02479339, 10.02754821, 11.03030303, + 12.03305785, 13.03581267, 14.03856749, 15.04132231, + 16.04407713, 17.04683196, 18.04958678, 19.0523416 , + 20.05509642, 21.05785124, 22.06060606, 23.06336088, + 24.0661157 , 25.06887052, 26.07162534, 27.07438017, + 28.07713499, 29.07988981, 30.08264463, 31.08539945, + 32.08815427, 33.09090909, 34.09366391, 35.09641873, + 36.09917355, 37.10192837, 38.1046832 , 39.10743802, + 40.11019284, 41.11294766, 42.11570248, 43.1184573 , + 44.12121212, 45.12396694, 46.12672176, 47.12947658, + 48.1322314 , 49.13498623, 50.13774105, 51.14049587, + 52.14325069, 53.14600551, 54.14876033, 55.15151515, + 56.15426997, 57.15702479, 58.15977961, 59.16253444, + 60.16528926, 61.16804408, 62.1707989 , 63.17355372, + 64.17630854, 65.17906336, 66.18181818, 67.184573 , + 68.18732782, 69.19008264, 70.19283747, 71.19559229, + 72.19834711, 73.20110193, 74.20385675, 75.20661157, + 76.20936639, 77.21212121, 78.21487603, 79.21763085, + ... + 284.78236915, 285.78512397, 286.78787879, 287.79063361, + 288.79338843, 289.79614325, 290.79889807, 291.80165289, + 292.80440771, 293.80716253, 294.80991736, 295.81267218, + 296.815427 , 297.81818182, 298.82093664, 299.82369146, + 300.82644628, 301.8292011 , 302.83195592, 303.83471074, + 304.83746556, 305.84022039, 306.84297521, 307.84573003, + 308.84848485, 309.85123967, 310.85399449, 311.85674931, + 312.85950413, 313.86225895, 314.86501377, 315.8677686 , + 316.87052342, 317.87327824, 318.87603306, 319.87878788, + 320.8815427 , 321.88429752, 322.88705234, 323.88980716, + 324.89256198, 325.8953168 , 326.89807163, 327.90082645, + 328.90358127, 329.90633609, 330.90909091, 331.91184573, + 332.91460055, 333.91735537, 334.92011019, 335.92286501, + 336.92561983, 337.92837466, 338.93112948, 339.9338843 , + 340.93663912, 341.93939394, 342.94214876, 343.94490358, + 344.9476584 , 345.95041322, 346.95316804, 347.95592287, + 348.95867769, 349.96143251, 350.96418733, 351.96694215, + 352.96969697, 353.97245179, 354.97520661, 355.97796143, + 356.98071625, 357.98347107, 358.9862259 , 359.98898072, + 360.99173554, 361.99449036, 362.99724518, 364. ]) + Coordinates: + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-12-12 + >>> da.coarsen(time=3, boundary="trim").mean() # +doctest: ELLIPSIS + Size: 968B + array([ 1.00275482, 4.01101928, 7.01928375, 10.02754821, + 13.03581267, 16.04407713, 19.0523416 , 22.06060606, + 25.06887052, 28.07713499, 31.08539945, 34.09366391, + 37.10192837, 40.11019284, 43.1184573 , 46.12672176, + 49.13498623, 52.14325069, 55.15151515, 58.15977961, + 61.16804408, 64.17630854, 67.184573 , 70.19283747, + 73.20110193, 76.20936639, 79.21763085, 82.22589532, + 85.23415978, 88.24242424, 91.25068871, 94.25895317, + 97.26721763, 100.27548209, 103.28374656, 106.29201102, + 109.30027548, 112.30853994, 115.31680441, 118.32506887, + 121.33333333, 124.3415978 , 127.34986226, 130.35812672, + 133.36639118, 136.37465565, 139.38292011, 142.39118457, + 145.39944904, 148.4077135 , 151.41597796, 154.42424242, + 157.43250689, 160.44077135, 163.44903581, 166.45730028, + 169.46556474, 172.4738292 , 175.48209366, 178.49035813, + 181.49862259, 184.50688705, 187.51515152, 190.52341598, + 193.53168044, 196.5399449 , 199.54820937, 202.55647383, + 205.56473829, 208.57300275, 211.58126722, 214.58953168, + 217.59779614, 220.60606061, 223.61432507, 226.62258953, + 229.63085399, 232.63911846, 235.64738292, 238.65564738, + 241.66391185, 244.67217631, 247.68044077, 250.68870523, + 253.6969697 , 256.70523416, 259.71349862, 262.72176309, + 265.73002755, 268.73829201, 271.74655647, 274.75482094, + 277.7630854 , 280.77134986, 283.77961433, 286.78787879, + 289.79614325, 292.80440771, 295.81267218, 298.82093664, + 301.8292011 , 304.83746556, 307.84573003, 310.85399449, + 313.86225895, 316.87052342, 319.87878788, 322.88705234, + 325.8953168 , 328.90358127, 331.91184573, 334.92011019, + 337.92837466, 340.93663912, 343.94490358, 346.95316804, + 349.96143251, 352.96969697, 355.97796143, 358.9862259 , + 361.99449036]) + Coordinates: + * time (time) datetime64[ns] 968B 1999-12-16 1999-12-19 ... 2000-12-10 + >>> + + See Also + -------- + core.rolling.DataArrayCoarsen + Dataset.coarsen + + :ref:`reshape.coarsen` + User guide describing :py:func:`~xarray.DataArray.coarsen` + + :ref:`compute.coarsen` + User guide on block arrgragation :py:func:`~xarray.DataArray.coarsen` + + :doc:`xarray-tutorial:fundamentals/03.3_windowed` + Tutorial on windowed computation using :py:func:`~xarray.DataArray.coarsen` + + """ + from xarray.core.rolling import DataArrayCoarsen + + dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") + return DataArrayCoarsen( + self, + dim, + boundary=boundary, + side=side, + coord_func=coord_func, + ) + + def resample( + self, + indexer: Mapping[Any, str] | None = None, + skipna: bool | None = None, + closed: SideOptions | None = None, + label: SideOptions | None = None, + base: int | None = None, + offset: pd.Timedelta | datetime.timedelta | str | None = None, + origin: str | DatetimeLike = "start_day", + loffset: datetime.timedelta | str | None = None, + restore_coord_dims: bool | None = None, + **indexer_kwargs: str, + ) -> DataArrayResample: + """Returns a Resample object for performing resampling operations. + + Handles both downsampling and upsampling. The resampled + dimension must be a datetime-like coordinate. If any intervals + contain no values from the original object, they will be given + the value ``NaN``. + + Parameters + ---------- + indexer : Mapping of Hashable to str, optional + Mapping from the dimension name to resample frequency [1]_. The + dimension must be datetime-like. + skipna : bool, optional + Whether to skip missing values when aggregating in downsampling. + closed : {"left", "right"}, optional + Side of each interval to treat as closed. + label : {"left", "right"}, optional + Side of each interval to use for labeling. + base : int, optional + For frequencies that evenly subdivide 1 day, the "origin" of the + aggregated intervals. For example, for "24H" frequency, base could + range from 0 through 23. + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pd.Timestamp, datetime.datetime, np.datetime64, or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : pd.Timedelta, datetime.timedelta, or str, default is None + An offset timedelta added to the origin. + loffset : timedelta or str, optional + Offset used to adjust the resampled time labels. Some pandas date + offset strings are supported. + + .. deprecated:: 2023.03.0 + Following pandas, the ``loffset`` parameter is deprecated in favor + of using time offset arithmetic, and will be removed in a future + version of xarray. + + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + **indexer_kwargs : str + The keyword arguments form of ``indexer``. + One of indexer or indexer_kwargs must be provided. + + Returns + ------- + resampled : core.resample.DataArrayResample + This object resampled. + + Examples + -------- + Downsample monthly time-series data to seasonal data: + + >>> da = xr.DataArray( + ... np.linspace(0, 11, num=12), + ... coords=[ + ... pd.date_range( + ... "1999-12-15", + ... periods=12, + ... freq=pd.DateOffset(months=1), + ... ) + ... ], + ... dims="time", + ... ) + >>> da + Size: 96B + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + >>> da.resample(time="QS-DEC").mean() + Size: 32B + array([ 1., 4., 7., 10.]) + Coordinates: + * time (time) datetime64[ns] 32B 1999-12-01 2000-03-01 ... 2000-09-01 + + Upsample monthly time-series data to daily data: + + >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS + Size: 3kB + array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, + 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, + 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , + 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, + 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, + 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, + 0.96774194, 1. , 1.03225806, 1.06451613, 1.09677419, + 1.12903226, 1.16129032, 1.19354839, 1.22580645, 1.25806452, + 1.29032258, 1.32258065, 1.35483871, 1.38709677, 1.41935484, + 1.4516129 , 1.48387097, 1.51612903, 1.5483871 , 1.58064516, + 1.61290323, 1.64516129, 1.67741935, 1.70967742, 1.74193548, + 1.77419355, 1.80645161, 1.83870968, 1.87096774, 1.90322581, + 1.93548387, 1.96774194, 2. , 2.03448276, 2.06896552, + 2.10344828, 2.13793103, 2.17241379, 2.20689655, 2.24137931, + 2.27586207, 2.31034483, 2.34482759, 2.37931034, 2.4137931 , + 2.44827586, 2.48275862, 2.51724138, 2.55172414, 2.5862069 , + 2.62068966, 2.65517241, 2.68965517, 2.72413793, 2.75862069, + 2.79310345, 2.82758621, 2.86206897, 2.89655172, 2.93103448, + 2.96551724, 3. , 3.03225806, 3.06451613, 3.09677419, + 3.12903226, 3.16129032, 3.19354839, 3.22580645, 3.25806452, + ... + 7.87096774, 7.90322581, 7.93548387, 7.96774194, 8. , + 8.03225806, 8.06451613, 8.09677419, 8.12903226, 8.16129032, + 8.19354839, 8.22580645, 8.25806452, 8.29032258, 8.32258065, + 8.35483871, 8.38709677, 8.41935484, 8.4516129 , 8.48387097, + 8.51612903, 8.5483871 , 8.58064516, 8.61290323, 8.64516129, + 8.67741935, 8.70967742, 8.74193548, 8.77419355, 8.80645161, + 8.83870968, 8.87096774, 8.90322581, 8.93548387, 8.96774194, + 9. , 9.03333333, 9.06666667, 9.1 , 9.13333333, + 9.16666667, 9.2 , 9.23333333, 9.26666667, 9.3 , + 9.33333333, 9.36666667, 9.4 , 9.43333333, 9.46666667, + 9.5 , 9.53333333, 9.56666667, 9.6 , 9.63333333, + 9.66666667, 9.7 , 9.73333333, 9.76666667, 9.8 , + 9.83333333, 9.86666667, 9.9 , 9.93333333, 9.96666667, + 10. , 10.03225806, 10.06451613, 10.09677419, 10.12903226, + 10.16129032, 10.19354839, 10.22580645, 10.25806452, 10.29032258, + 10.32258065, 10.35483871, 10.38709677, 10.41935484, 10.4516129 , + 10.48387097, 10.51612903, 10.5483871 , 10.58064516, 10.61290323, + 10.64516129, 10.67741935, 10.70967742, 10.74193548, 10.77419355, + 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, + 10.96774194, 11. ]) + Coordinates: + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 + + Limit scope of upsampling method + + >>> da.resample(time="1D").nearest(tolerance="1D") + Size: 3kB + array([ 0., 0., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 1., 1., 1., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 2., 2., 2., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 3., + 3., 3., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, 5., 5., 5., nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + 6., 6., 6., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 7., 7., 7., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 8., 8., 8., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, 9., 9., 9., nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, 10., 10., 10., nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 11., 11.]) + Coordinates: + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 + + See Also + -------- + Dataset.resample + pandas.Series.resample + pandas.DataFrame.resample + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases + """ + from xarray.core.resample import DataArrayResample + + return self._resample( + resample_cls=DataArrayResample, + indexer=indexer, + skipna=skipna, + closed=closed, + label=label, + base=base, + offset=offset, + origin=origin, + loffset=loffset, + restore_coord_dims=restore_coord_dims, + **indexer_kwargs, + ) + + def to_dask_dataframe( + self, + dim_order: Sequence[Hashable] | None = None, + set_index: bool = False, + ) -> DaskDataFrame: + """Convert this array into a dask.dataframe.DataFrame. + + Parameters + ---------- + dim_order : Sequence of Hashable or None , optional + Hierarchical dimension order for the resulting dataframe. + Array content is transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major influence + on which operations are efficient on the resulting dask dataframe. + set_index : bool, default: False + If set_index=True, the dask DataFrame is indexed by this dataset's + coordinate. Since dask DataFrames do not support multi-indexes, + set_index only works if the dataset only contains one dimension. + + Returns + ------- + dask.dataframe.DataFrame + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(4 * 2 * 2).reshape(4, 2, 2), + ... dims=("time", "lat", "lon"), + ... coords={ + ... "time": np.arange(4), + ... "lat": [-30, -20], + ... "lon": [120, 130], + ... }, + ... name="eg_dataarray", + ... attrs={"units": "Celsius", "description": "Random temperature data"}, + ... ) + >>> da.to_dask_dataframe(["lat", "lon", "time"]).compute() + lat lon time eg_dataarray + 0 -30 120 0 0 + 1 -30 120 1 4 + 2 -30 120 2 8 + 3 -30 120 3 12 + 4 -30 130 0 1 + 5 -30 130 1 5 + 6 -30 130 2 9 + 7 -30 130 3 13 + 8 -20 120 0 2 + 9 -20 120 1 6 + 10 -20 120 2 10 + 11 -20 120 3 14 + 12 -20 130 0 3 + 13 -20 130 1 7 + 14 -20 130 2 11 + 15 -20 130 3 15 + """ + if self.name is None: + raise ValueError( + "Cannot convert an unnamed DataArray to a " + "dask dataframe : use the ``.rename`` method to assign a name." + ) + name = self.name + ds = self._to_dataset_whole(name, shallow_copy=False) + return ds.to_dask_dataframe(dim_order, set_index) + + # this needs to be at the end, or mypy will confuse with `str` + # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names + str = utils.UncachedAccessor(StringAccessor["DataArray"]) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/dataset.py b/test/fixtures/whole_applications/xarray/xarray/core/dataset.py new file mode 100644 index 0000000..872cb48 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/dataset.py @@ -0,0 +1,10680 @@ +from __future__ import annotations + +import copy +import datetime +import inspect +import itertools +import math +import sys +import warnings +from collections import defaultdict +from collections.abc import ( + Collection, + Hashable, + Iterable, + Iterator, + Mapping, + MutableMapping, + Sequence, +) +from functools import partial +from html import escape +from numbers import Number +from operator import methodcaller +from os import PathLike +from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload + +import numpy as np +from pandas.api.types import is_extension_array_dtype + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + +import pandas as pd + +from xarray.coding.calendar_ops import convert_calendar, interp_calendar +from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings +from xarray.core import ( + alignment, + duck_array_ops, + formatting, + formatting_html, + ops, + utils, +) +from xarray.core import dtypes as xrdtypes +from xarray.core._aggregations import DatasetAggregations +from xarray.core.alignment import ( + _broadcast_helper, + _get_broadcast_dims_map_common_coords, + align, +) +from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.common import ( + DataWithCoords, + _contains_datetime_like_objects, + get_chunksizes, +) +from xarray.core.computation import unify_chunks +from xarray.core.coordinates import ( + Coordinates, + DatasetCoordinates, + assert_coordinate_consistent, + create_coords_with_default_indexes, +) +from xarray.core.duck_array_ops import datetime_to_numeric +from xarray.core.indexes import ( + Index, + Indexes, + PandasIndex, + PandasMultiIndex, + assert_no_index_corrupted, + create_default_index_implicit, + filter_indexes_from_coords, + isel_indexes, + remove_unused_levels_categories, + roll_indexes, +) +from xarray.core.indexing import is_fancy_indexer, map_index_queries +from xarray.core.merge import ( + dataset_merge_method, + dataset_update_method, + merge_coordinates_without_align, + merge_core, +) +from xarray.core.missing import get_clean_interp_index +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import ( + NetcdfWriteModes, + QuantileMethods, + Self, + T_ChunkDim, + T_Chunks, + T_DataArray, + T_DataArrayOrSet, + T_Dataset, + ZarrWriteModes, +) +from xarray.core.utils import ( + Default, + Frozen, + FrozenMappingWarningOnValuesAccess, + HybridMappingProxy, + OrderedSet, + _default, + decode_numpy_dict_values, + drop_dims_from_indexers, + either_dict_or_kwargs, + emit_user_level_warning, + infix_dims, + is_dict_like, + is_duck_array, + is_duck_dask_array, + is_scalar, + maybe_wrap_array, +) +from xarray.core.variable import ( + IndexVariable, + Variable, + as_variable, + broadcast_variables, + calculate_dimensions, +) +from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager +from xarray.namedarray.pycompat import array_type, is_chunked_array +from xarray.plot.accessor import DatasetPlotAccessor +from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims + +if TYPE_CHECKING: + from dask.dataframe import DataFrame as DaskDataFrame + from dask.delayed import Delayed + from numpy.typing import ArrayLike + + from xarray.backends import AbstractDataStore, ZarrStore + from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes + from xarray.core.dataarray import DataArray + from xarray.core.groupby import DatasetGroupBy + from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult + from xarray.core.resample import DatasetResample + from xarray.core.rolling import DatasetCoarsen, DatasetRolling + from xarray.core.types import ( + CFCalendar, + CoarsenBoundaryOptions, + CombineAttrsOptions, + CompatOptions, + DataVars, + DatetimeLike, + DatetimeUnitOptions, + Dims, + DsCompatible, + ErrorOptions, + ErrorOptionsWithWarn, + InterpOptions, + JoinOptions, + PadModeOptions, + PadReflectOptions, + QueryEngineOptions, + QueryParserOptions, + ReindexMethodOptions, + SideOptions, + T_Xarray, + ) + from xarray.core.weighted import DatasetWeighted + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + +# list of attributes of pd.DatetimeIndex that are ndarrays of time info +_DATETIMEINDEX_COMPONENTS = [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "date", + "time", + "dayofyear", + "weekofyear", + "dayofweek", + "quarter", +] + + +def _get_virtual_variable( + variables, key: Hashable, dim_sizes: Mapping | None = None +) -> tuple[Hashable, Hashable, Variable]: + """Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable + objects (if possible) + + """ + from xarray.core.dataarray import DataArray + + if dim_sizes is None: + dim_sizes = {} + + if key in dim_sizes: + data = pd.Index(range(dim_sizes[key]), name=key) + variable = IndexVariable((key,), data) + return key, key, variable + + if not isinstance(key, str): + raise KeyError(key) + + split_key = key.split(".", 1) + if len(split_key) != 2: + raise KeyError(key) + + ref_name, var_name = split_key + ref_var = variables[ref_name] + + if _contains_datetime_like_objects(ref_var): + ref_var = DataArray(ref_var) + data = getattr(ref_var.dt, var_name).data + else: + data = getattr(ref_var, var_name).data + virtual_var = Variable(ref_var.dims, data) + + return ref_name, var_name, virtual_var + + +def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): + """ + Return map from each dim to chunk sizes, accounting for backend's preferred chunks. + """ + + if isinstance(var, IndexVariable): + return {} + dims = var.dims + shape = var.shape + + # Determine the explicit requested chunks. + preferred_chunks = var.encoding.get("preferred_chunks", {}) + preferred_chunk_shape = tuple( + preferred_chunks.get(dim, size) for dim, size in zip(dims, shape) + ) + if isinstance(chunks, Number) or (chunks == "auto"): + chunks = dict.fromkeys(dims, chunks) + chunk_shape = tuple( + chunks.get(dim, None) or preferred_chunk_sizes + for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape) + ) + + chunk_shape = chunkmanager.normalize_chunks( + chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape + ) + + # Warn where requested chunks break preferred chunks, provided that the variable + # contains data. + if var.size: + for dim, size, chunk_sizes in zip(dims, shape, chunk_shape): + try: + preferred_chunk_sizes = preferred_chunks[dim] + except KeyError: + continue + # Determine the stop indices of the preferred chunks, but omit the last stop + # (equal to the dim size). In particular, assume that when a sequence + # expresses the preferred chunks, the sequence sums to the size. + preferred_stops = ( + range(preferred_chunk_sizes, size, preferred_chunk_sizes) + if isinstance(preferred_chunk_sizes, int) + else itertools.accumulate(preferred_chunk_sizes[:-1]) + ) + # Gather any stop indices of the specified chunks that are not a stop index + # of a preferred chunk. Again, omit the last stop, assuming that it equals + # the dim size. + breaks = set(itertools.accumulate(chunk_sizes[:-1])).difference( + preferred_stops + ) + if breaks: + warnings.warn( + "The specified chunks separate the stored chunks along " + f'dimension "{dim}" starting at index {min(breaks)}. This could ' + "degrade performance. Instead, consider rechunking after loading." + ) + + return dict(zip(dims, chunk_shape)) + + +def _maybe_chunk( + name, + var, + chunks, + token=None, + lock=None, + name_prefix="xarray-", + overwrite_encoded_chunks=False, + inline_array=False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, +): + + from xarray.namedarray.daskmanager import DaskManager + + if chunks is not None: + chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks} + + if var.ndim: + chunked_array_type = guess_chunkmanager( + chunked_array_type + ) # coerce string to ChunkManagerEntrypoint type + if isinstance(chunked_array_type, DaskManager): + from dask.base import tokenize + + # when rechunking by different amounts, make sure dask names change + # by providing chunks as an input to tokenize. + # subtle bugs result otherwise. see GH3350 + # we use str() for speed, and use the name for the final array name on the next line + token2 = tokenize(token if token else var._data, str(chunks)) + name2 = f"{name_prefix}{name}-{token2}" + + from_array_kwargs = utils.consolidate_dask_from_array_kwargs( + from_array_kwargs, + name=name2, + lock=lock, + inline_array=inline_array, + ) + + var = var.chunk( + chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) + + if overwrite_encoded_chunks and var.chunks is not None: + var.encoding["chunks"] = tuple(x[0] for x in var.chunks) + return var + else: + return var + + +def as_dataset(obj: Any) -> Dataset: + """Cast the given object to a Dataset. + + Handles Datasets, DataArrays and dictionaries of variables. A new Dataset + object is only created if the provided object is not already one. + """ + if hasattr(obj, "to_dataset"): + obj = obj.to_dataset() + if not isinstance(obj, Dataset): + obj = Dataset(obj) + return obj + + +def _get_func_args(func, param_names): + """Use `inspect.signature` to try accessing `func` args. Otherwise, ensure + they are provided by user. + """ + try: + func_args = inspect.signature(func).parameters + except ValueError: + func_args = {} + if not param_names: + raise ValueError( + "Unable to inspect `func` signature, and `param_names` was not provided." + ) + if param_names: + params = param_names + else: + params = list(func_args)[1:] + if any( + [(p.kind in [p.VAR_POSITIONAL, p.VAR_KEYWORD]) for p in func_args.values()] + ): + raise ValueError( + "`param_names` must be provided because `func` takes variable length arguments." + ) + return params, func_args + + +def _initialize_curvefit_params(params, p0, bounds, func_args): + """Set initial guess and bounds for curvefit. + Priority: 1) passed args 2) func signature 3) scipy defaults + """ + from xarray.core.computation import where + + def _initialize_feasible(lb, ub): + # Mimics functionality of scipy.optimize.minpack._initialize_feasible + lb_finite = np.isfinite(lb) + ub_finite = np.isfinite(ub) + p0 = where( + lb_finite, + where( + ub_finite, + 0.5 * (lb + ub), # both bounds finite + lb + 1, # lower bound finite, upper infinite + ), + where( + ub_finite, + ub - 1, # lower bound infinite, upper finite + 0, # both bounds infinite + ), + ) + return p0 + + param_defaults = {p: 1 for p in params} + bounds_defaults = {p: (-np.inf, np.inf) for p in params} + for p in params: + if p in func_args and func_args[p].default is not func_args[p].empty: + param_defaults[p] = func_args[p].default + if p in bounds: + lb, ub = bounds[p] + bounds_defaults[p] = (lb, ub) + param_defaults[p] = where( + (param_defaults[p] < lb) | (param_defaults[p] > ub), + _initialize_feasible(lb, ub), + param_defaults[p], + ) + if p in p0: + param_defaults[p] = p0[p] + return param_defaults, bounds_defaults + + +def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult: + """Used in Dataset.__init__.""" + if isinstance(coords, Coordinates): + coords = coords.copy() + else: + coords = create_coords_with_default_indexes(coords, data_vars) + + # exclude coords from alignment (all variables in a Coordinates object should + # already be aligned together) and use coordinates' indexes to align data_vars + return merge_core( + [data_vars, coords], + compat="broadcast_equals", + join="outer", + explicit_coords=tuple(coords), + indexes=coords.xindexes, + priority_arg=1, + skip_align_args=[1], + ) + + +class DataVariables(Mapping[Any, "DataArray"]): + __slots__ = ("_dataset",) + + def __init__(self, dataset: Dataset): + self._dataset = dataset + + def __iter__(self) -> Iterator[Hashable]: + return ( + key + for key in self._dataset._variables + if key not in self._dataset._coord_names + ) + + def __len__(self) -> int: + length = len(self._dataset._variables) - len(self._dataset._coord_names) + assert length >= 0, "something is wrong with Dataset._coord_names" + return length + + def __contains__(self, key: Hashable) -> bool: + return key in self._dataset._variables and key not in self._dataset._coord_names + + def __getitem__(self, key: Hashable) -> DataArray: + if key not in self._dataset._coord_names: + return self._dataset[key] + raise KeyError(key) + + def __repr__(self) -> str: + return formatting.data_vars_repr(self) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + all_variables = self._dataset.variables + return Frozen({k: all_variables[k] for k in self}) + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from data variable names to dtypes. + + Cannot be modified directly, but is updated when adding new variables. + + See Also + -------- + Dataset.dtype + """ + return self._dataset.dtypes + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + return [ + key + for key in self._dataset._ipython_key_completions_() + if key not in self._dataset._coord_names + ] + + +class _LocIndexer(Generic[T_Dataset]): + __slots__ = ("dataset",) + + def __init__(self, dataset: T_Dataset): + self.dataset = dataset + + def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset: + if not utils.is_dict_like(key): + raise TypeError("can only lookup dictionaries from Dataset.loc") + return self.dataset.sel(key) + + def __setitem__(self, key, value) -> None: + if not utils.is_dict_like(key): + raise TypeError( + "can only set locations defined by dictionaries from Dataset.loc." + f" Got: {key}" + ) + + # set new values + dim_indexers = map_index_queries(self.dataset, key).dim_indexers + self.dataset[dim_indexers] = value + + +class Dataset( + DataWithCoords, + DatasetAggregations, + DatasetArithmetic, + Mapping[Hashable, "DataArray"], +): + """A multi-dimensional, in memory, array database. + + A dataset resembles an in-memory representation of a NetCDF file, + and consists of variables, coordinates and attributes which + together form a self describing dataset. + + Dataset implements the mapping interface with keys given by variable + names and values given by DataArray objects for each variable name. + + By default, pandas indexes are created for one dimensional variables with + name equal to their dimension (i.e., :term:`Dimension coordinate`) so those + variables can be readily used as coordinates for label based indexing. When a + :py:class:`~xarray.Coordinates` object is passed to ``coords``, any existing + index(es) built from those coordinates will be added to the Dataset. + + To load data from a file or file-like object, use the `open_dataset` + function. + + Parameters + ---------- + data_vars : dict-like, optional + A mapping from variable names to :py:class:`~xarray.DataArray` + objects, :py:class:`~xarray.Variable` objects or to tuples of + the form ``(dims, data[, attrs])`` which can be used as + arguments to create a new ``Variable``. Each dimension must + have the same length in all variables in which it appears. + + The following notations are accepted: + + - mapping {var name: DataArray} + - mapping {var name: Variable} + - mapping {var name: (dimension name, array-like)} + - mapping {var name: (tuple of dimension names, array-like)} + - mapping {dimension name: array-like} + (if array-like is not a scalar it will be automatically moved to coords, + see below) + + Each dimension must have the same length in all variables in + which it appears. + coords : :py:class:`~xarray.Coordinates` or dict-like, optional + A :py:class:`~xarray.Coordinates` object or another mapping in + similar form as the `data_vars` argument, except that each item + is saved on the dataset as a "coordinate". + These variables have an associated meaning: they describe + constant/fixed/independent quantities, unlike the + varying/measured/dependent quantities that belong in + `variables`. + + The following notations are accepted for arbitrary mappings: + + - mapping {coord name: DataArray} + - mapping {coord name: Variable} + - mapping {coord name: (dimension name, array-like)} + - mapping {coord name: (tuple of dimension names, array-like)} + - mapping {dimension name: array-like} + (the dimension name is implicitly set to be the same as the + coord name) + + The last notation implies either that the coordinate value is a scalar + or that it is a 1-dimensional array and the coord name is the same as + the dimension name (i.e., a :term:`Dimension coordinate`). In the latter + case, the 1-dimensional array will be assumed to give index values + along the dimension with the same name. + + Alternatively, a :py:class:`~xarray.Coordinates` object may be used in + order to explicitly pass indexes (e.g., a multi-index or any custom + Xarray index) or to bypass the creation of a default index for any + :term:`Dimension coordinate` included in that object. + + attrs : dict-like, optional + Global attributes to save on this dataset. + + Examples + -------- + In this example dataset, we will represent measurements of the temperature + and pressure that were made under various conditions: + + * the measurements were made on four different days; + * they were made at two separate locations, which we will represent using + their latitude and longitude; and + * they were made using three instrument developed by three different + manufacturers, which we will refer to using the strings `'manufac1'`, + `'manufac2'`, and `'manufac3'`. + + >>> np.random.seed(0) + >>> temperature = 15 + 8 * np.random.randn(2, 3, 4) + >>> precipitation = 10 * np.random.rand(2, 3, 4) + >>> lon = [-99.83, -99.32] + >>> lat = [42.25, 42.21] + >>> instruments = ["manufac1", "manufac2", "manufac3"] + >>> time = pd.date_range("2014-09-06", periods=4) + >>> reference_time = pd.Timestamp("2014-09-05") + + Here, we initialize the dataset with multiple dimensions. We use the string + `"loc"` to represent the location dimension of the data, the string + `"instrument"` to represent the instrument manufacturer dimension, and the + string `"time"` for the time dimension. + + >>> ds = xr.Dataset( + ... data_vars=dict( + ... temperature=(["loc", "instrument", "time"], temperature), + ... precipitation=(["loc", "instrument", "time"], precipitation), + ... ), + ... coords=dict( + ... lon=("loc", lon), + ... lat=("loc", lat), + ... instrument=instruments, + ... time=time, + ... reference_time=reference_time, + ... ), + ... attrs=dict(description="Weather related data."), + ... ) + >>> ds + Size: 552B + Dimensions: (loc: 2, instrument: 3, time: 4) + Coordinates: + lon (loc) float64 16B -99.83 -99.32 + lat (loc) float64 16B 42.25 42.21 + * instrument (instrument) >> ds.isel(ds.temperature.argmin(...)) + Size: 80B + Dimensions: () + Coordinates: + lon float64 8B -99.32 + lat float64 8B 42.21 + instrument None: + if data_vars is None: + data_vars = {} + if coords is None: + coords = {} + + both_data_and_coords = set(data_vars) & set(coords) + if both_data_and_coords: + raise ValueError( + f"variables {both_data_and_coords!r} are found in both data_vars and coords" + ) + + if isinstance(coords, Dataset): + coords = coords._variables + + variables, coord_names, dims, indexes, _ = merge_data_and_coords( + data_vars, coords + ) + + self._attrs = dict(attrs) if attrs else None + self._close = None + self._encoding = None + self._variables = variables + self._coord_names = coord_names + self._dims = dims + self._indexes = indexes + + # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping + # related to https://github.com/python/mypy/issues/9319? + def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override] + return super().__eq__(other) + + @classmethod + def load_store(cls, store, decoder=None) -> Self: + """Create a new dataset from the contents of a backends.*DataStore + object + """ + variables, attributes = store.load() + if decoder: + variables, attributes = decoder(variables, attributes) + obj = cls(variables, attrs=attributes) + obj.set_close(store.close) + return obj + + @property + def variables(self) -> Frozen[Hashable, Variable]: + """Low level interface to Dataset contents as dict of Variable objects. + + This ordered dictionary is frozen to prevent mutation that could + violate Dataset invariants. It contains all variable objects + constituting the Dataset, including both data variables and + coordinates. + """ + return Frozen(self._variables) + + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary of global attributes on this dataset""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) if value else None + + @property + def encoding(self) -> dict[Any, Any]: + """Dictionary of global encoding attributes on this dataset""" + if self._encoding is None: + self._encoding = {} + return self._encoding + + @encoding.setter + def encoding(self, value: Mapping[Any, Any]) -> None: + self._encoding = dict(value) + + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: + """Return a new Dataset without encoding on the dataset or any of its + variables/coords.""" + variables = {k: v.drop_encoding() for k, v in self.variables.items()} + return self._replace(variables=variables, encoding={}) + + @property + def dims(self) -> Frozen[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + Note that type of this object differs from `DataArray.dims`. + See `Dataset.sizes` and `DataArray.sizes` for consistently named + properties. This property will be changed to return a type more consistent with + `DataArray.dims` in the future, i.e. a set of dimension names. + + See Also + -------- + Dataset.sizes + DataArray.dims + """ + return FrozenMappingWarningOnValuesAccess(self._dims) + + @property + def sizes(self) -> Frozen[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + This is an alias for `Dataset.dims` provided for the benefit of + consistency with `DataArray.sizes`. + + See Also + -------- + DataArray.sizes + """ + return Frozen(self._dims) + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from data variable names to dtypes. + + Cannot be modified directly, but is updated when adding new variables. + + See Also + -------- + DataArray.dtype + """ + return Frozen( + { + n: v.dtype + for n, v in self._variables.items() + if n not in self._coord_names + } + ) + + def load(self, **kwargs) -> Self: + """Manually trigger loading and/or computation of this dataset's data + from disk or a remote source into memory and return this dataset. + Unlike compute, the original dataset is modified and returned. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + dask.compute + """ + # access .data to coerce everything to numpy or dask arrays + lazy_data = { + k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) + } + if lazy_data: + chunkmanager = get_chunked_array_type(*lazy_data.values()) + + # evaluate all the chunked arrays simultaneously + evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( + *lazy_data.values(), **kwargs + ) + + for k, data in zip(lazy_data, evaluated_data): + self.variables[k].data = data + + # load everything else sequentially + for k, v in self.variables.items(): + if k not in lazy_data: + v.load() + + return self + + def __dask_tokenize__(self) -> object: + from dask.base import normalize_token + + return normalize_token( + (type(self), self._variables, self._coord_names, self._attrs or None) + ) + + def __dask_graph__(self): + graphs = {k: v.__dask_graph__() for k, v in self.variables.items()} + graphs = {k: v for k, v in graphs.items() if v is not None} + if not graphs: + return None + else: + try: + from dask.highlevelgraph import HighLevelGraph + + return HighLevelGraph.merge(*graphs.values()) + except ImportError: + from dask import sharedict + + return sharedict.merge(*graphs.values()) + + def __dask_keys__(self): + import dask + + return [ + v.__dask_keys__() + for v in self.variables.values() + if dask.is_dask_collection(v) + ] + + def __dask_layers__(self): + import dask + + return sum( + ( + v.__dask_layers__() + for v in self.variables.values() + if dask.is_dask_collection(v) + ), + (), + ) + + @property + def __dask_optimize__(self): + import dask.array as da + + return da.Array.__dask_optimize__ + + @property + def __dask_scheduler__(self): + import dask.array as da + + return da.Array.__dask_scheduler__ + + def __dask_postcompute__(self): + return self._dask_postcompute, () + + def __dask_postpersist__(self): + return self._dask_postpersist, () + + def _dask_postcompute(self, results: Iterable[Variable]) -> Self: + import dask + + variables = {} + results_iter = iter(results) + + for k, v in self._variables.items(): + if dask.is_dask_collection(v): + rebuild, args = v.__dask_postcompute__() + v = rebuild(next(results_iter), *args) + variables[k] = v + + return type(self)._construct_direct( + variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) + + def _dask_postpersist( + self, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> Self: + from dask import is_dask_collection + from dask.highlevelgraph import HighLevelGraph + from dask.optimization import cull + + variables = {} + + for k, v in self._variables.items(): + if not is_dask_collection(v): + variables[k] = v + continue + + if isinstance(dsk, HighLevelGraph): + # dask >= 2021.3 + # __dask_postpersist__() was called by dask.highlevelgraph. + # Don't use dsk.cull(), as we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + layers = v.__dask_layers__() + if rename: + layers = [rename.get(k, k) for k in layers] + dsk2 = dsk.cull_layers(layers) + elif rename: # pragma: nocover + # At the moment of writing, this is only for forward compatibility. + # replace_name_in_key requires dask >= 2021.3. + from dask.base import flatten, replace_name_in_key + + keys = [ + replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__()) + ] + dsk2, _ = cull(dsk, keys) + else: + # __dask_postpersist__() was called by dask.optimize or dask.persist + dsk2, _ = cull(dsk, v.__dask_keys__()) + + rebuild, args = v.__dask_postpersist__() + # rename was added in dask 2021.3 + kwargs = {"rename": rename} if rename else {} + variables[k] = rebuild(dsk2, *args, **kwargs) + + return type(self)._construct_direct( + variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) + + def compute(self, **kwargs) -> Self: + """Manually trigger loading and/or computation of this dataset's data + from disk or a remote source into memory and return a new dataset. + Unlike load, the original dataset is left unaltered. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + Returns + ------- + object : Dataset + New object with lazy data variables and coordinates as in-memory arrays. + + See Also + -------- + dask.compute + """ + new = self.copy(deep=False) + return new.load(**kwargs) + + def _persist_inplace(self, **kwargs) -> Self: + """Persist all Dask arrays in memory""" + # access .data to coerce everything to numpy or dask arrays + lazy_data = { + k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data) + } + if lazy_data: + import dask + + # evaluate all the dask arrays simultaneously + evaluated_data = dask.persist(*lazy_data.values(), **kwargs) + + for k, data in zip(lazy_data, evaluated_data): + self.variables[k].data = data + + return self + + def persist(self, **kwargs) -> Self: + """Trigger computation, keeping data as dask arrays + + This operation can be used to trigger computation on underlying dask + arrays, similar to ``.compute()`` or ``.load()``. However this + operation keeps the data as dask arrays. This is particularly useful + when using the dask.distributed scheduler and you want to load a large + amount of data into distributed memory. + Like compute (but unlike load), the original dataset is left unaltered. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.persist``. + + Returns + ------- + object : Dataset + New object with all dask-backed coordinates and data variables as persisted dask arrays. + + See Also + -------- + dask.persist + """ + new = self.copy(deep=False) + return new._persist_inplace(**kwargs) + + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, + close: Callable[[], None] | None = None, + ) -> Self: + """Shortcut around __init__ for internal use when we want to skip + costly validation + """ + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + obj = object.__new__(cls) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + return obj + + def _replace( + self, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] | None = None, + encoding: dict | None | Default = _default, + inplace: bool = False, + ) -> Self: + """Fastpath constructor for internal use. + + Returns an object with optionally with replaced attributes. + + Explicitly passed arguments are *not* copied when placed on the new + dataset. It is up to the caller to ensure that they have the right type + and are not used elsewhere. + """ + if inplace: + if variables is not None: + self._variables = variables + if coord_names is not None: + self._coord_names = coord_names + if dims is not None: + self._dims = dims + if attrs is not _default: + self._attrs = attrs + if indexes is not None: + self._indexes = indexes + if encoding is not _default: + self._encoding = encoding + obj = self + else: + if variables is None: + variables = self._variables.copy() + if coord_names is None: + coord_names = self._coord_names.copy() + if dims is None: + dims = self._dims.copy() + if attrs is _default: + attrs = copy.copy(self._attrs) + if indexes is None: + indexes = self._indexes.copy() + if encoding is _default: + encoding = copy.copy(self._encoding) + obj = self._construct_direct( + variables, coord_names, dims, attrs, indexes, encoding + ) + return obj + + def _replace_with_new_dims( + self, + variables: dict[Hashable, Variable], + coord_names: set | None = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] | None = None, + inplace: bool = False, + ) -> Self: + """Replace variables with recalculated dimensions.""" + dims = calculate_dimensions(variables) + return self._replace( + variables, coord_names, dims, attrs, indexes, inplace=inplace + ) + + def _replace_vars_and_dims( + self, + variables: dict[Hashable, Variable], + coord_names: set | None = None, + dims: dict[Hashable, int] | None = None, + attrs: dict[Hashable, Any] | None | Default = _default, + inplace: bool = False, + ) -> Self: + """Deprecated version of _replace_with_new_dims(). + + Unlike _replace_with_new_dims(), this method always recalculates + indexes from variables. + """ + if dims is None: + dims = calculate_dimensions(variables) + return self._replace( + variables, coord_names, dims, attrs, indexes=None, inplace=inplace + ) + + def _overwrite_indexes( + self, + indexes: Mapping[Hashable, Index], + variables: Mapping[Hashable, Variable] | None = None, + drop_variables: list[Hashable] | None = None, + drop_indexes: list[Hashable] | None = None, + rename_dims: Mapping[Hashable, Hashable] | None = None, + ) -> Self: + """Maybe replace indexes. + + This function may do a lot more depending on index query + results. + + """ + if not indexes: + return self + + if variables is None: + variables = {} + if drop_variables is None: + drop_variables = [] + if drop_indexes is None: + drop_indexes = [] + + new_variables = self._variables.copy() + new_coord_names = self._coord_names.copy() + new_indexes = dict(self._indexes) + + index_variables = {} + no_index_variables = {} + for name, var in variables.items(): + old_var = self._variables.get(name) + if old_var is not None: + var.attrs.update(old_var.attrs) + var.encoding.update(old_var.encoding) + if name in indexes: + index_variables[name] = var + else: + no_index_variables[name] = var + + for name in indexes: + new_indexes[name] = indexes[name] + + for name, var in index_variables.items(): + new_coord_names.add(name) + new_variables[name] = var + + # append no-index variables at the end + for k in no_index_variables: + new_variables.pop(k) + new_variables.update(no_index_variables) + + for name in drop_indexes: + new_indexes.pop(name) + + for name in drop_variables: + new_variables.pop(name) + new_indexes.pop(name, None) + new_coord_names.remove(name) + + replaced = self._replace( + variables=new_variables, coord_names=new_coord_names, indexes=new_indexes + ) + + if rename_dims: + # skip rename indexes: they should already have the right name(s) + dims = replaced._rename_dims(rename_dims) + new_variables, new_coord_names = replaced._rename_vars({}, rename_dims) + return replaced._replace( + variables=new_variables, coord_names=new_coord_names, dims=dims + ) + else: + return replaced + + def copy(self, deep: bool = False, data: DataVars | None = None) -> Self: + """Returns a copy of this dataset. + + If `deep=True`, a deep copy is made of each of the component variables. + Otherwise, a shallow copy of each of the component variable is made, so + that the underlying memory region of the new dataset is the same as in + the original dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: False + Whether each component variable is loaded into memory and copied onto + the new object. Default is False. + data : dict-like or None, optional + Data to use in the new object. Each item in `data` must have same + shape as corresponding data variable in original. When `data` is + used, `deep` is ignored for the data variables and only used for + coords. + + Returns + ------- + object : Dataset + New object with dimensions, attributes, coordinates, name, encoding, + and optionally data copied from original. + + Examples + -------- + Shallow copy versus deep copy + + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset( + ... {"foo": da, "bar": ("x", [-1, 2])}, + ... coords={"x": ["one", "two"]}, + ... ) + >>> ds.copy() + Size: 88B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds_0 = ds.copy(deep=False) + >>> ds_0["foo"][0, 0] = 7 + >>> ds_0 + Size: 88B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + Size: 88B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds.copy(data={"foo": np.arange(6).reshape(2, 3), "bar": ["a", "b"]}) + Size: 80B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + Size: 88B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) Self: + if data is None: + data = {} + elif not utils.is_dict_like(data): + raise ValueError("Data must be dict-like") + + if data: + var_keys = set(self.data_vars.keys()) + data_keys = set(data.keys()) + keys_not_in_vars = data_keys - var_keys + if keys_not_in_vars: + raise ValueError( + "Data must only contain variables in original " + f"dataset. Extra variables: {keys_not_in_vars}" + ) + keys_missing_from_data = var_keys - data_keys + if keys_missing_from_data: + raise ValueError( + "Data must contain all variables in original " + f"dataset. Data is missing {keys_missing_from_data}" + ) + + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + variables = {} + for k, v in self._variables.items(): + if k in index_vars: + variables[k] = index_vars[k] + else: + variables[k] = v._copy(deep=deep, data=data.get(k), memo=memo) + + attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) + encoding = ( + copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding) + ) + + return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + def as_numpy(self) -> Self: + """ + Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. + + See also + -------- + DataArray.as_numpy + DataArray.to_numpy : Returns only the data as a numpy.ndarray object. + """ + numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} + return self._replace(variables=numpy_variables) + + def _copy_listed(self, names: Iterable[Hashable]) -> Self: + """Create a new Dataset with the listed variables from this dataset and + the all relevant coordinates. Skips all validation. + """ + variables: dict[Hashable, Variable] = {} + coord_names = set() + indexes: dict[Hashable, Index] = {} + + for name in names: + try: + variables[name] = self._variables[name] + except KeyError: + ref_name, var_name, var = _get_virtual_variable( + self._variables, name, self.sizes + ) + variables[var_name] = var + if ref_name in self._coord_names or ref_name in self.dims: + coord_names.add(var_name) + if (var_name,) == var.dims: + index, index_vars = create_default_index_implicit(var, names) + indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) + + needed_dims: OrderedSet[Hashable] = OrderedSet() + for v in variables.values(): + needed_dims.update(v.dims) + + dims = {k: self.sizes[k] for k in needed_dims} + + # preserves ordering of coordinates + for k in self._variables: + if k not in self._coord_names: + continue + + if set(self.variables[k].dims) <= needed_dims: + variables[k] = self._variables[k] + coord_names.add(k) + + indexes.update(filter_indexes_from_coords(self._indexes, coord_names)) + + return self._replace(variables, coord_names, dims, indexes=indexes) + + def _construct_dataarray(self, name: Hashable) -> DataArray: + """Construct a DataArray by indexing this dataset""" + from xarray.core.dataarray import DataArray + + try: + variable = self._variables[name] + except KeyError: + _, name, variable = _get_virtual_variable(self._variables, name, self.sizes) + + needed_dims = set(variable.dims) + + coords: dict[Hashable, Variable] = {} + # preserve ordering + for k in self._variables: + if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: + coords[k] = self._variables[k] + + indexes = filter_indexes_from_coords(self._indexes, set(coords)) + + return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + + @property + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from self._item_sources + yield self.attrs + + @property + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-completion""" + yield self.data_vars + yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + + # virtual coordinates + yield HybridMappingProxy(keys=self.sizes, mapping=self) + + def __contains__(self, key: object) -> bool: + """The 'in' operator will return true or false depending on whether + 'key' is an array in the dataset or not. + """ + return key in self._variables + + def __len__(self) -> int: + return len(self.data_vars) + + def __bool__(self) -> bool: + return bool(self.data_vars) + + def __iter__(self) -> Iterator[Hashable]: + return iter(self.data_vars) + + if TYPE_CHECKING: + # needed because __getattr__ is returning Any and otherwise + # this class counts as part of the SupportsArray Protocol + __array__ = None # type: ignore[var-annotated,unused-ignore] + + else: + + def __array__(self, dtype=None, copy=None): + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_dataarray()` method." + ) + + @property + def nbytes(self) -> int: + """ + Total bytes consumed by the data arrays of all variables in this dataset. + + If the backend array for any variable does not include ``nbytes``, estimates + the total bytes for that array based on the ``size`` and ``dtype``. + """ + return sum(v.nbytes for v in self.variables.values()) + + @property + def loc(self) -> _LocIndexer[Self]: + """Attribute for location based indexing. Only supports __getitem__, + and only when the key is a dict of the form {dim: labels}. + """ + return _LocIndexer(self) + + @overload + def __getitem__(self, key: Hashable) -> DataArray: ... + + # Mapping is Iterable + @overload + def __getitem__(self, key: Iterable[Hashable]) -> Self: ... + + def __getitem__( + self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] + ) -> Self | DataArray: + """Access variables or coordinates of this dataset as a + :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset. + + Indexing with a list of names will return a new ``Dataset`` object. + """ + from xarray.core.formatting import shorten_list_repr + + if utils.is_dict_like(key): + return self.isel(**key) + if utils.hashable(key): + try: + return self._construct_dataarray(key) + except KeyError as e: + raise KeyError( + f"No variable named {key!r}. Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}" + ) from e + + if utils.iterable_of_hashable(key): + return self._copy_listed(key) + raise ValueError(f"Unsupported key-type {type(key)}") + + def __setitem__( + self, key: Hashable | Iterable[Hashable] | Mapping, value: Any + ) -> None: + """Add an array to this dataset. + Multiple arrays can be added at the same time, in which case each of + the following operations is applied to the respective value. + + If key is dict-like, update all variables in the dataset + one by one with the given value at the given location. + If the given value is also a dataset, select corresponding variables + in the given value and in the dataset to be changed. + + If value is a ` + from .dataarray import DataArray`, call its `select_vars()` method, rename it + to `key` and merge the contents of the resulting dataset into this + dataset. + + If value is a `Variable` object (or tuple of form + ``(dims, data[, attrs])``), add it to this dataset as a new + variable. + """ + from xarray.core.dataarray import DataArray + + if utils.is_dict_like(key): + # check for consistency and convert value to dataset + value = self._setitem_check(key, value) + # loop over dataset variables and set new values + processed = [] + for name, var in self.items(): + try: + var[key] = value[name] + processed.append(name) + except Exception as e: + if processed: + raise RuntimeError( + "An error occurred while setting values of the" + f" variable '{name}'. The following variables have" + f" been successfully updated:\n{processed}" + ) from e + else: + raise e + + elif utils.hashable(key): + if isinstance(value, Dataset): + raise TypeError( + "Cannot assign a Dataset to a single key - only a DataArray or Variable " + "object can be stored under a single key." + ) + self.update({key: value}) + + elif utils.iterable_of_hashable(key): + keylist = list(key) + if len(keylist) == 0: + raise ValueError("Empty list of variables to be set") + if len(keylist) == 1: + self.update({keylist[0]: value}) + else: + if len(keylist) != len(value): + raise ValueError( + f"Different lengths of variables to be set " + f"({len(keylist)}) and data used as input for " + f"setting ({len(value)})" + ) + if isinstance(value, Dataset): + self.update(dict(zip(keylist, value.data_vars.values()))) + elif isinstance(value, DataArray): + raise ValueError("Cannot assign single DataArray to multiple keys") + else: + self.update(dict(zip(keylist, value))) + + else: + raise ValueError(f"Unsupported key-type {type(key)}") + + def _setitem_check(self, key, value): + """Consistency check for __setitem__ + + When assigning values to a subset of a Dataset, do consistency check beforehand + to avoid leaving the dataset in a partially updated state when an error occurs. + """ + from xarray.core.alignment import align + from xarray.core.dataarray import DataArray + + if isinstance(value, Dataset): + missing_vars = [ + name for name in value.data_vars if name not in self.data_vars + ] + if missing_vars: + raise ValueError( + f"Variables {missing_vars} in new values" + f" not available in original dataset:\n{self}" + ) + elif not any([isinstance(value, t) for t in [DataArray, Number, str]]): + raise TypeError( + "Dataset assignment only accepts DataArrays, Datasets, and scalars." + ) + + new_value = Dataset() + for name, var in self.items(): + # test indexing + try: + var_k = var[key] + except Exception as e: + raise ValueError( + f"Variable '{name}': indexer {key} not available" + ) from e + + if isinstance(value, Dataset): + val = value[name] + else: + val = value + + if isinstance(val, DataArray): + # check consistency of dimensions + for dim in val.dims: + if dim not in var_k.dims: + raise KeyError( + f"Variable '{name}': dimension '{dim}' appears in new values " + f"but not in the indexed original data" + ) + dims = tuple(dim for dim in var_k.dims if dim in val.dims) + if dims != val.dims: + raise ValueError( + f"Variable '{name}': dimension order differs between" + f" original and new data:\n{dims}\nvs.\n{val.dims}" + ) + else: + val = np.array(val) + + # type conversion + new_value[name] = duck_array_ops.astype(val, dtype=var_k.dtype, copy=False) + + # check consistency of dimension sizes and dimension coordinates + if isinstance(value, DataArray) or isinstance(value, Dataset): + align(self[key], value, join="exact", copy=False) + + return new_value + + def __delitem__(self, key: Hashable) -> None: + """Remove a variable from this dataset.""" + assert_no_index_corrupted(self.xindexes, {key}) + + if key in self._indexes: + del self._indexes[key] + del self._variables[key] + self._coord_names.discard(key) + self._dims = calculate_dimensions(self._variables) + + # mutable objects should not be hashable + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore[assignment] + + def _all_compat(self, other: Self, compat_str: str) -> bool: + """Helper function for equals and identical""" + + # some stores (e.g., scipy) do not seem to preserve order, so don't + # require matching order for equality + def compat(x: Variable, y: Variable) -> bool: + return getattr(x, compat_str)(y) + + return self._coord_names == other._coord_names and utils.dict_equiv( + self._variables, other._variables, compat=compat + ) + + def broadcast_equals(self, other: Self) -> bool: + """Two Datasets are broadcast equal if they are equal after + broadcasting all variables against each other. + + For example, variables that are scalar in one dataset but non-scalar in + the other dataset can still be broadcast equal if the the non-scalar + variable is a constant. + + Examples + -------- + + # 2D array with shape (1, 3) + + >>> data = np.array([[1, 2, 3]]) + >>> a = xr.Dataset( + ... {"variable_name": (("space", "time"), data)}, + ... coords={"space": [0], "time": [0, 1, 2]}, + ... ) + >>> a + Size: 56B + Dimensions: (space: 1, time: 3) + Coordinates: + * space (space) int64 8B 0 + * time (time) int64 24B 0 1 2 + Data variables: + variable_name (space, time) int64 24B 1 2 3 + + # 2D array with shape (3, 1) + + >>> data = np.array([[1], [2], [3]]) + >>> b = xr.Dataset( + ... {"variable_name": (("time", "space"), data)}, + ... coords={"time": [0, 1, 2], "space": [0]}, + ... ) + >>> b + Size: 56B + Dimensions: (time: 3, space: 1) + Coordinates: + * time (time) int64 24B 0 1 2 + * space (space) int64 8B 0 + Data variables: + variable_name (time, space) int64 24B 1 2 3 + + .equals returns True if two Datasets have the same values, dimensions, and coordinates. .broadcast_equals returns True if the + results of broadcasting two Datasets against each other have the same values, dimensions, and coordinates. + + >>> a.equals(b) + False + + >>> a.broadcast_equals(b) + True + + >>> a2, b2 = xr.broadcast(a, b) + >>> a2.equals(b2) + True + + See Also + -------- + Dataset.equals + Dataset.identical + Dataset.broadcast + """ + try: + return self._all_compat(other, "broadcast_equals") + except (TypeError, AttributeError): + return False + + def equals(self, other: Self) -> bool: + """Two Datasets are equal if they have matching variables and + coordinates, all of which are equal. + + Datasets can still be equal (like pandas objects) if they have NaN + values in the same locations. + + This method is necessary because `v1 == v2` for ``Dataset`` + does element-wise comparisons (like numpy.ndarrays). + + Examples + -------- + + # 2D array with shape (1, 3) + + >>> data = np.array([[1, 2, 3]]) + >>> dataset1 = xr.Dataset( + ... {"variable_name": (("space", "time"), data)}, + ... coords={"space": [0], "time": [0, 1, 2]}, + ... ) + >>> dataset1 + Size: 56B + Dimensions: (space: 1, time: 3) + Coordinates: + * space (space) int64 8B 0 + * time (time) int64 24B 0 1 2 + Data variables: + variable_name (space, time) int64 24B 1 2 3 + + # 2D array with shape (3, 1) + + >>> data = np.array([[1], [2], [3]]) + >>> dataset2 = xr.Dataset( + ... {"variable_name": (("time", "space"), data)}, + ... coords={"time": [0, 1, 2], "space": [0]}, + ... ) + >>> dataset2 + Size: 56B + Dimensions: (time: 3, space: 1) + Coordinates: + * time (time) int64 24B 0 1 2 + * space (space) int64 8B 0 + Data variables: + variable_name (time, space) int64 24B 1 2 3 + >>> dataset1.equals(dataset2) + False + + >>> dataset1.broadcast_equals(dataset2) + True + + .equals returns True if two Datasets have the same values, dimensions, and coordinates. .broadcast_equals returns True if the + results of broadcasting two Datasets against each other have the same values, dimensions, and coordinates. + + Similar for missing values too: + + >>> ds1 = xr.Dataset( + ... { + ... "temperature": (["x", "y"], [[1, np.nan], [3, 4]]), + ... }, + ... coords={"x": [0, 1], "y": [0, 1]}, + ... ) + + >>> ds2 = xr.Dataset( + ... { + ... "temperature": (["x", "y"], [[1, np.nan], [3, 4]]), + ... }, + ... coords={"x": [0, 1], "y": [0, 1]}, + ... ) + >>> ds1.equals(ds2) + True + + See Also + -------- + Dataset.broadcast_equals + Dataset.identical + """ + try: + return self._all_compat(other, "equals") + except (TypeError, AttributeError): + return False + + def identical(self, other: Self) -> bool: + """Like equals, but also checks all dataset attributes and the + attributes on all variables and coordinates. + + Example + ------- + + >>> a = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "m"}, + ... ) + >>> b = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "m"}, + ... ) + >>> c = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "ft"}, + ... ) + >>> a + Size: 48B + Dimensions: (X: 3) + Coordinates: + * X (X) int64 24B 1 2 3 + Data variables: + Width (X) int64 24B 1 2 3 + Attributes: + units: m + + >>> b + Size: 48B + Dimensions: (X: 3) + Coordinates: + * X (X) int64 24B 1 2 3 + Data variables: + Width (X) int64 24B 1 2 3 + Attributes: + units: m + + >>> c + Size: 48B + Dimensions: (X: 3) + Coordinates: + * X (X) int64 24B 1 2 3 + Data variables: + Width (X) int64 24B 1 2 3 + Attributes: + units: ft + + >>> a.equals(b) + True + + >>> a.identical(b) + True + + >>> a.equals(c) + True + + >>> a.identical(c) + False + + See Also + -------- + Dataset.broadcast_equals + Dataset.equals + """ + try: + return utils.dict_equiv(self.attrs, other.attrs) and self._all_compat( + other, "identical" + ) + except (TypeError, AttributeError): + return False + + @property + def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Dataset has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + Dataset.xindexes + + """ + return self.xindexes.to_pandas_indexes() + + @property + def xindexes(self) -> Indexes[Index]: + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + + @property + def coords(self) -> DatasetCoordinates: + """Mapping of :py:class:`~xarray.DataArray` objects corresponding to + coordinate variables. + + See Also + -------- + Coordinates + """ + return DatasetCoordinates(self) + + @property + def data_vars(self) -> DataVariables: + """Dictionary of DataArray objects corresponding to data variables""" + return DataVariables(self) + + def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self: + """Given names of one or more variables, set them as coordinates + + Parameters + ---------- + names : hashable or iterable of hashable + Name(s) of variables in this dataset to convert into coordinates. + + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "pressure": ("time", [1.013, 1.2, 3.5]), + ... "time": pd.date_range("2023-01-01", periods=3), + ... } + ... ) + >>> dataset + Size: 48B + Dimensions: (time: 3) + Coordinates: + * time (time) datetime64[ns] 24B 2023-01-01 2023-01-02 2023-01-03 + Data variables: + pressure (time) float64 24B 1.013 1.2 3.5 + + >>> dataset.set_coords("pressure") + Size: 48B + Dimensions: (time: 3) + Coordinates: + pressure (time) float64 24B 1.013 1.2 3.5 + * time (time) datetime64[ns] 24B 2023-01-01 2023-01-02 2023-01-03 + Data variables: + *empty* + + On calling ``set_coords`` , these data variables are converted to coordinates, as shown in the final dataset. + + Returns + ------- + Dataset + + See Also + -------- + Dataset.swap_dims + Dataset.assign_coords + """ + # TODO: allow inserting new coordinates with this method, like + # DataFrame.set_index? + # nb. check in self._variables, not self.data_vars to insure that the + # operation is idempotent + if isinstance(names, str) or not isinstance(names, Iterable): + names = [names] + else: + names = list(names) + self._assert_all_in_dataset(names) + obj = self.copy() + obj._coord_names.update(names) + return obj + + def reset_coords( + self, + names: Dims = None, + drop: bool = False, + ) -> Self: + """Given names of coordinates, reset them to become variables + + Parameters + ---------- + names : str, Iterable of Hashable or None, optional + Name(s) of non-index coordinates in this dataset to reset into + variables. By default, all non-index coordinates are reset. + drop : bool, default: False + If True, remove coordinates instead of converting them into + variables. + + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "lat", "lon"], + ... [[[25, 26], [27, 28]], [[29, 30], [31, 32]]], + ... ), + ... "precipitation": ( + ... ["time", "lat", "lon"], + ... [[[0.5, 0.8], [0.2, 0.4]], [[0.3, 0.6], [0.7, 0.9]]], + ... ), + ... }, + ... coords={ + ... "time": pd.date_range(start="2023-01-01", periods=2), + ... "lat": [40, 41], + ... "lon": [-80, -79], + ... "altitude": 1000, + ... }, + ... ) + + # Dataset before resetting coordinates + + >>> dataset + Size: 184B + Dimensions: (time: 2, lat: 2, lon: 2) + Coordinates: + * time (time) datetime64[ns] 16B 2023-01-01 2023-01-02 + * lat (lat) int64 16B 40 41 + * lon (lon) int64 16B -80 -79 + altitude int64 8B 1000 + Data variables: + temperature (time, lat, lon) int64 64B 25 26 27 28 29 30 31 32 + precipitation (time, lat, lon) float64 64B 0.5 0.8 0.2 0.4 0.3 0.6 0.7 0.9 + + # Reset the 'altitude' coordinate + + >>> dataset_reset = dataset.reset_coords("altitude") + + # Dataset after resetting coordinates + + >>> dataset_reset + Size: 184B + Dimensions: (time: 2, lat: 2, lon: 2) + Coordinates: + * time (time) datetime64[ns] 16B 2023-01-01 2023-01-02 + * lat (lat) int64 16B 40 41 + * lon (lon) int64 16B -80 -79 + Data variables: + temperature (time, lat, lon) int64 64B 25 26 27 28 29 30 31 32 + precipitation (time, lat, lon) float64 64B 0.5 0.8 0.2 0.4 0.3 0.6 0.7 0.9 + altitude int64 8B 1000 + + Returns + ------- + Dataset + + See Also + -------- + Dataset.set_coords + """ + if names is None: + names = self._coord_names - set(self._indexes) + else: + if isinstance(names, str) or not isinstance(names, Iterable): + names = [names] + else: + names = list(names) + self._assert_all_in_dataset(names) + bad_coords = set(names) & set(self._indexes) + if bad_coords: + raise ValueError( + f"cannot remove index coordinates with reset_coords: {bad_coords}" + ) + obj = self.copy() + obj._coord_names.difference_update(names) + if drop: + for name in names: + del obj._variables[name] + return obj + + def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None: + """Store dataset contents to a backends.*DataStore object.""" + from xarray.backends.api import dump_to_store + + # TODO: rename and/or cleanup this method to make it more consistent + # with to_netcdf() + dump_to_store(self, store, **kwargs) + + # path=None writes to bytes + @overload + def to_netcdf( + self, + path: None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> bytes: ... + + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: ... + + # default return None + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: Literal[True] = True, + invalid_netcdf: bool = False, + ) -> None: ... + + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> Delayed | None: ... + + def to_netcdf( + self, + path: str | PathLike | None = None, + mode: NetcdfWriteModes = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> bytes | Delayed | None: + """Write dataset contents to a netCDF file. + + Parameters + ---------- + path : str, path-like or file-like, optional + Path to which to save this dataset. File-like objects are only + supported by the scipy engine. If no path is provided, this + function returns the resulting netCDF file as bytes; in this case, + we need to use scipy, which does not support netCDF version 4 (the + default format becomes NETCDF3_64BIT). + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. If mode='a', existing variables + will be overwritten. + format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_CLASSIC"}, optional + File format for the resulting netCDF file: + + * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API + features. + * NETCDF4_CLASSIC: Data is stored in an HDF5 file, using only + netCDF 3 compatible API features. + * NETCDF3_64BIT: 64-bit offset version of the netCDF 3 file format, + which fully supports 2+ GB files, but is only compatible with + clients linked against netCDF version 3.6.0 or later. + * NETCDF3_CLASSIC: The classic netCDF 3 file format. It does not + handle 2+ GB files very well. + + All formats are supported by the netCDF4-python library. + scipy.io.netcdf only supports the last two formats. + + The default format is NETCDF4 if you are saving a file to disk and + have the netCDF4-python library available. Otherwise, xarray falls + back to using scipy to write netCDF files and defaults to the + NETCDF3_64BIT format (scipy does not support netCDF4). + group : str, optional + Path to the netCDF4 group in the given file to open (only works for + format='NETCDF4'). The group(s) will be created if necessary. + engine : {"netcdf4", "scipy", "h5netcdf"}, optional + Engine to use when writing netCDF files. If not provided, the + default engine is chosen based on available dependencies, with a + preference for 'netcdf4' if writing to a file on disk. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}``. + If ``encoding`` is specified the original encoding of the variables of + the dataset is ignored. + + The `h5netcdf` engine supports both the NetCDF4-style compression + encoding parameters ``{"zlib": True, "complevel": 9}`` and the h5py + ones ``{"compression": "gzip", "compression_opts": 9}``. + This allows using any compression plugin installed in the HDF5 + library, e.g. LZF. + + unlimited_dims : iterable of hashable, optional + Dimension(s) that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. + compute: bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. + invalid_netcdf: bool, default: False + Only valid along with ``engine="h5netcdf"``. If True, allow writing + hdf5 files which are invalid netcdf as described in + https://github.com/h5netcdf/h5netcdf. + + Returns + ------- + * ``bytes`` if path is None + * ``dask.delayed.Delayed`` if compute is False + * None otherwise + + See Also + -------- + DataArray.to_netcdf + """ + if encoding is None: + encoding = {} + from xarray.backends.api import to_netcdf + + return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( + self, + path, + mode=mode, + format=format, + group=group, + engine=engine, + encoding=encoding, + unlimited_dims=unlimited_dims, + compute=compute, + multifile=False, + invalid_netcdf=invalid_netcdf, + ) + + # compute=True (default) returns ZarrStore + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[True] = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, + ) -> ZarrStore: ... + + # compute=False returns dask.Delayed + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[False], + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, + ) -> Delayed: ... + + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: bool = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, + ) -> ZarrStore | Delayed: + """Write dataset contents to a zarr group. + + Zarr chunks are determined in the following way: + + - From the ``chunks`` attribute in each variable's ``encoding`` + (can be set via `Dataset.chunk`). + - If the variable is a Dask array, from the dask chunks + - If neither Dask chunks nor encoding chunks are present, chunks will + be determined automatically by Zarr + - If both Dask chunks and encoding chunks are present, encoding chunks + will be used, provided that there is a many-to-one relationship between + encoding chunks and dask chunks (i.e. Dask chunks are bigger than and + evenly divide encoding chunks); otherwise raise a ``ValueError``. + This restriction ensures that no synchronization / locks are required + when writing. To disable this restriction, use ``safe_chunks=False``. + + Parameters + ---------- + store : MutableMapping, str or path-like, optional + Store or path to directory in local or remote file system. + chunk_store : MutableMapping, str or path-like, optional + Store or path to directory in local or remote file system only for Zarr + array chunks. Requires zarr-python v2.4.0 or later. + mode : {"w", "w-", "a", "a-", r+", None}, optional + Persistence mode: "w" means create (overwrite if exists); + "w-" means create (fail if exists); + "a" means override all existing variables including dimension coordinates (create if does not exist); + "a-" means only append those variables that have ``append_dim``. + "r+" means modify existing array *values* only (raise an error if + any metadata or shapes would change). + The default mode is "a" if ``append_dim`` is set. Otherwise, it is + "r+" if ``region`` is set and ``w-`` otherwise. + synchronizer : object, optional + Zarr array synchronizer. + group : str, optional + Group path. (a.k.a. `path` in zarr terminology.) + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}`` + compute : bool, default: True + If True write array data immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed to write + array data later. Metadata is always updated eagerly. + consolidated : bool, optional + If True, apply :func:`zarr.convenience.consolidate_metadata` + after writing metadata and read existing stores with consolidated + metadata; if False, do not. The default (`consolidated=None`) means + write consolidated metadata and attempt to read consolidated + metadata for existing stores (falling back to non-consolidated). + + When the experimental ``zarr_version=3``, ``consolidated`` must be + either be ``None`` or ``False``. + append_dim : hashable, optional + If set, the dimension along which the data will be appended. All + other dimensions on overridden variables must remain the same size. + region : dict or "auto", optional + Optional mapping from dimension names to integer slices along + dataset dimensions to indicate the region of existing zarr array(s) + in which to write this dataset's data. For example, + ``{'x': slice(0, 1000), 'y': slice(10000, 11000)}`` would indicate + that values should be written to the region ``0:1000`` along ``x`` + and ``10000:11000`` along ``y``. + + Can also specify ``"auto"``, in which case the existing store will be + opened and the region inferred by matching the new data's coordinates. + ``"auto"`` can be used as a single string, which will automatically infer + the region for all dimensions, or as dictionary values for specific + dimensions mixed together with explicit slices for other dimensions. + + Two restrictions apply to the use of ``region``: + + - If ``region`` is set, _all_ variables in a dataset must have at + least one dimension in common with the region. Other variables + should be written in a separate call to ``to_zarr()``. + - Dimensions cannot be included in both ``region`` and + ``append_dim`` at the same time. To create empty arrays to fill + in with ``region``, use a separate call to ``to_zarr()`` with + ``compute=False``. See "Appending to existing Zarr stores" in + the reference documentation for full details. + + Users are expected to ensure that the specified region aligns with + Zarr chunk boundaries, and that dask chunks are also aligned. + Xarray makes limited checks that these multiple chunk boundaries line up. + It is possible to write incomplete chunks and corrupt the data with this + option if you are not careful. + safe_chunks : bool, default: True + If True, only allow writes to when there is a many-to-one relationship + between Zarr chunks (specified in encoding) and Dask chunks. + Set False to override this restriction; however, data may become corrupted + if Zarr arrays are written in parallel. This option may be useful in combination + with ``compute=False`` to initialize a Zarr from an existing + Dataset with arbitrary chunk structure. + storage_options : dict, optional + Any additional parameters for the storage backend (ignored for local + paths). + zarr_version : int or None, optional + The desired zarr spec version to target (currently 2 or 3). The + default of None will attempt to determine the zarr version from + ``store`` when possible, otherwise defaulting to 2. + write_empty_chunks : bool or None, optional + If True, all chunks will be stored regardless of their + contents. If False, each chunk is compared to the array's fill value + prior to storing. If a chunk is uniformly equal to the fill value, then + that chunk is not be stored, and the store entry for that chunk's key + is deleted. This setting enables sparser storage, as only chunks with + non-fill-value data are stored, at the expense of overhead associated + with checking the data of each chunk. If None (default) fall back to + specification(s) in ``encoding`` or Zarr defaults. A ``ValueError`` + will be raised if the value of this (if not None) differs with + ``encoding``. + chunkmanager_store_kwargs : dict, optional + Additional keyword arguments passed on to the `ChunkManager.store` method used to store + chunked arrays. For example for a dask array additional kwargs will be passed eventually to + :py:func:`dask.array.store()`. Experimental API that should not be relied upon. + + Returns + ------- + * ``dask.delayed.Delayed`` if compute is False + * ZarrStore otherwise + + References + ---------- + https://zarr.readthedocs.io/ + + Notes + ----- + Zarr chunking behavior: + If chunks are found in the encoding argument or attribute + corresponding to any DataArray, those chunks are used. + If a DataArray is a dask array, it is written with those chunks. + If not other chunks are found, Zarr uses its own heuristics to + choose automatic chunk sizes. + + encoding: + The encoding attribute (if exists) of the DataArray(s) will be + used. Override any existing encodings by providing the ``encoding`` kwarg. + + See Also + -------- + :ref:`io.zarr` + The I/O user guide, with more details and examples. + """ + from xarray.backends.api import to_zarr + + return to_zarr( # type: ignore[call-overload,misc] + self, + store=store, + chunk_store=chunk_store, + storage_options=storage_options, + mode=mode, + synchronizer=synchronizer, + group=group, + encoding=encoding, + compute=compute, + consolidated=consolidated, + append_dim=append_dim, + region=region, + safe_chunks=safe_chunks, + zarr_version=zarr_version, + write_empty_chunks=write_empty_chunks, + chunkmanager_store_kwargs=chunkmanager_store_kwargs, + ) + + def __repr__(self) -> str: + return formatting.dataset_repr(self) + + def _repr_html_(self) -> str: + if OPTIONS["display_style"] == "text": + return f"
{escape(repr(self))}
" + return formatting_html.dataset_repr(self) + + def info(self, buf: IO | None = None) -> None: + """ + Concise summary of a Dataset variables and attributes. + + Parameters + ---------- + buf : file-like, default: sys.stdout + writable buffer + + See Also + -------- + pandas.DataFrame.assign + ncdump : netCDF's ncdump + """ + if buf is None: # pragma: no cover + buf = sys.stdout + + lines = [] + lines.append("xarray.Dataset {") + lines.append("dimensions:") + for name, size in self.sizes.items(): + lines.append(f"\t{name} = {size} ;") + lines.append("\nvariables:") + for name, da in self.variables.items(): + dims = ", ".join(map(str, da.dims)) + lines.append(f"\t{da.dtype} {name}({dims}) ;") + for k, v in da.attrs.items(): + lines.append(f"\t\t{name}:{k} = {v} ;") + lines.append("\n// global attributes:") + for k, v in self.attrs.items(): + lines.append(f"\t:{k} = {v} ;") + lines.append("}") + + buf.write("\n".join(lines)) + + @property + def chunks(self) -> Mapping[Hashable, tuple[int, ...]]: + """ + Mapping from dimension names to block lengths for this dataset's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Same as Dataset.chunksizes, but maintained for backwards compatibility. + + See Also + -------- + Dataset.chunk + Dataset.chunksizes + xarray.unify_chunks + """ + return get_chunksizes(self.variables.values()) + + @property + def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: + """ + Mapping from dimension names to block lengths for this dataset's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Same as Dataset.chunks. + + See Also + -------- + Dataset.chunk + Dataset.chunks + xarray.unify_chunks + """ + return get_chunksizes(self.variables.values()) + + def chunk( + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + name_prefix: str = "xarray-", + token: str | None = None, + lock: bool = False, + inline_array: bool = False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + **chunks_kwargs: T_ChunkDim, + ) -> Self: + """Coerce all arrays in this dataset into dask arrays with the given + chunks. + + Non-dask arrays in this dataset will be converted to dask arrays. Dask + arrays will be rechunked to the given chunk sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Parameters + ---------- + chunks : int, tuple of int, "auto" or mapping of hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or + ``{"x": 5, "y": 5}``. + name_prefix : str, default: "xarray-" + Prefix for the name of any new dask arrays. + token : str, optional + Token uniquely identifying this dataset. + lock : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + inline_array: bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided + + Returns + ------- + chunked : xarray.Dataset + + See Also + -------- + Dataset.chunks + Dataset.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + if chunks is None and not chunks_kwargs: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=DeprecationWarning, + ) + chunks = {} + chunks_mapping: Mapping[Any, Any] + if not isinstance(chunks, Mapping) and chunks is not None: + if isinstance(chunks, (tuple, list)): + utils.emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimensions as keys.", + category=DeprecationWarning, + ) + chunks_mapping = dict.fromkeys(self.dims, chunks) + else: + chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + + bad_dims = chunks_mapping.keys() - self.sizes.keys() + if bad_dims: + raise ValueError( + f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}" + ) + + chunkmanager = guess_chunkmanager(chunked_array_type) + if from_array_kwargs is None: + from_array_kwargs = {} + + variables = { + k: _maybe_chunk( + k, + v, + chunks_mapping, + token, + lock, + name_prefix, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), + ) + for k, v in self.variables.items() + } + return self._replace(variables) + + def _validate_indexers( + self, indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise" + ) -> Iterator[tuple[Hashable, int | slice | np.ndarray | Variable]]: + """Here we make sure + + indexer has a valid keys + + indexer is in a valid data type + + string indexers are cast to the appropriate date type if the + associated index is a DatetimeIndex or CFTimeIndex + """ + from xarray.coding.cftimeindex import CFTimeIndex + from xarray.core.dataarray import DataArray + + indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims) + + # all indexers should be int, slice, np.ndarrays, or Variable + for k, v in indexers.items(): + if isinstance(v, (int, slice, Variable)): + yield k, v + elif isinstance(v, DataArray): + yield k, v.variable + elif isinstance(v, tuple): + yield k, as_variable(v) + elif isinstance(v, Dataset): + raise TypeError("cannot use a Dataset as an indexer") + elif isinstance(v, Sequence) and len(v) == 0: + yield k, np.empty((0,), dtype="int64") + else: + if not is_duck_array(v): + v = np.asarray(v) + + if v.dtype.kind in "US": + index = self._indexes[k].to_pandas_index() + if isinstance(index, pd.DatetimeIndex): + v = duck_array_ops.astype(v, dtype="datetime64[ns]") + elif isinstance(index, CFTimeIndex): + v = _parse_array_of_cftime_strings(v, index.date_type) + + if v.ndim > 1: + raise IndexError( + "Unlabeled multi-dimensional array cannot be " + f"used for indexing: {k}" + ) + yield k, v + + def _validate_interp_indexers( + self, indexers: Mapping[Any, Any] + ) -> Iterator[tuple[Hashable, Variable]]: + """Variant of _validate_indexers to be used for interpolation""" + for k, v in self._validate_indexers(indexers): + if isinstance(v, Variable): + if v.ndim == 1: + yield k, v.to_index_variable() + else: + yield k, v + elif isinstance(v, int): + yield k, Variable((), v, attrs=self.coords[k].attrs) + elif isinstance(v, np.ndarray): + if v.ndim == 0: + yield k, Variable((), v, attrs=self.coords[k].attrs) + elif v.ndim == 1: + yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs) + else: + raise AssertionError() # Already tested by _validate_indexers + else: + raise TypeError(type(v)) + + def _get_indexers_coords_and_indexes(self, indexers): + """Extract coordinates and indexes from indexers. + + Only coordinate with a name different from any of self.variables will + be attached. + """ + from xarray.core.dataarray import DataArray + + coords_list = [] + for k, v in indexers.items(): + if isinstance(v, DataArray): + if v.dtype.kind == "b": + if v.ndim != 1: # we only support 1-d boolean array + raise ValueError( + f"{v.ndim:d}d-boolean array is used for indexing along " + f"dimension {k!r}, but only 1d boolean arrays are " + "supported." + ) + # Make sure in case of boolean DataArray, its + # coordinate also should be indexed. + v_coords = v[v.values.nonzero()[0]].coords + else: + v_coords = v.coords + coords_list.append(v_coords) + + # we don't need to call align() explicitly or check indexes for + # alignment, because merge_variables already checks for exact alignment + # between dimension coordinates + coords, indexes = merge_coordinates_without_align(coords_list) + assert_coordinate_consistent(self, coords) + + # silently drop the conflicted variables. + attached_coords = {k: v for k, v in coords.items() if k not in self._variables} + attached_indexes = { + k: v for k, v in indexes.items() if k not in self._variables + } + return attached_coords, attached_indexes + + def isel( + self, + indexers: Mapping[Any, Any] | None = None, + drop: bool = False, + missing_dims: ErrorOptionsWithWarn = "raise", + **indexers_kwargs: Any, + ) -> Self: + """Returns a new dataset with each array indexed along the specified + dimension(s). + + This method selects values from each array using its `__getitem__` + method, except this method does not require knowing the order of + each array's dimensions. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by integers, slice objects or arrays. + indexer can be a integer, slice, array-like or DataArray. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + drop : bool, default: False + If ``drop=True``, drop coordinates variables indexed by integers + instead of making them scalar. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : Dataset + A new Dataset with the same contents as this dataset, except each + array and dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this dataset, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 92], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [93, 96, 91]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # A specific element from the dataset is selected + + >>> dataset.isel(student=1, test=0) + Size: 68B + Dimensions: () + Coordinates: + student >> slice_of_data = dataset.isel(student=slice(0, 2), test=slice(0, 2)) + >>> slice_of_data + Size: 168B + Dimensions: (student: 2, test: 2) + Coordinates: + * student (student) >> index_array = xr.DataArray([0, 2], dims="student") + >>> indexed_data = dataset.isel(student=index_array) + >>> indexed_data + Size: 224B + Dimensions: (student: 2, test: 3) + Coordinates: + * student (student) Self: + valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) + + variables: dict[Hashable, Variable] = {} + indexes, index_variables = isel_indexes(self.xindexes, valid_indexers) + + for name, var in self.variables.items(): + if name in index_variables: + new_var = index_variables[name] + else: + var_indexers = { + k: v for k, v in valid_indexers.items() if k in var.dims + } + if var_indexers: + new_var = var.isel(indexers=var_indexers) + # drop scalar coordinates + # https://github.com/pydata/xarray/issues/6554 + if name in self.coords and drop and new_var.ndim == 0: + continue + else: + new_var = var.copy(deep=False) + if name not in indexes: + new_var = new_var.to_base_variable() + variables[name] = new_var + + coord_names = self._coord_names & variables.keys() + selected = self._replace_with_new_dims(variables, coord_names, indexes) + + # Extract coordinates from indexers + coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers) + variables.update(coord_vars) + indexes.update(new_indexes) + coord_names = self._coord_names & variables.keys() | coord_vars.keys() + return self._replace_with_new_dims(variables, coord_names, indexes=indexes) + + def sel( + self, + indexers: Mapping[Any, Any] | None = None, + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, + drop: bool = False, + **indexers_kwargs: Any, + ) -> Self: + """Returns a new dataset with each array indexed by tick labels + along the specified dimension(s). + + In contrast to `Dataset.isel`, indexers for this method should use + labels instead of integers. + + Under the hood, this method is powered by using pandas's powerful Index + objects. This makes label based indexing essentially just as fast as + using integer indexing. + + It also means this method uses pandas's (well documented) logic for + indexing. This means you can use string shortcuts for datetime indexes + (e.g., '2000-01' to select all values in January 2000). It also means + that slices are treated as inclusive of both the start and stop values, + unlike normal Python indexing. + + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by scalars, slices or arrays of tick labels. For dimensions with + multi-index, the indexer may also be a dict-like object with keys + matching index level names. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method to use for inexact matches: + + * None (default): only exact matches + * pad / ffill: propagate last valid index value forward + * backfill / bfill: propagate next valid index value backward + * nearest: use nearest valid index value + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + drop : bool, optional + If ``drop=True``, drop coordinates variables in `indexers` instead + of making them scalar. + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : Dataset + A new Dataset with the same contents as this dataset, except each + variable and dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this dataset, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + + See Also + -------- + Dataset.isel + DataArray.sel + + :doc:`xarray-tutorial:intermediate/indexing/indexing` + Tutorial material on indexing with Xarray objects + + :doc:`xarray-tutorial:fundamentals/02.1_indexing_Basic` + Tutorial material on basics of indexing + + """ + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") + query_results = map_index_queries( + self, indexers=indexers, method=method, tolerance=tolerance + ) + + if drop: + no_scalar_variables = {} + for k, v in query_results.variables.items(): + if v.dims: + no_scalar_variables[k] = v + else: + if k in self._coord_names: + query_results.drop_coords.append(k) + query_results.variables = no_scalar_variables + + result = self.isel(indexers=query_results.dim_indexers, drop=drop) + return result._overwrite_indexes(*query_results.as_tuple()[1:]) + + def head( + self, + indexers: Mapping[Any, int] | int | None = None, + **indexers_kwargs: Any, + ) -> Self: + """Returns a new dataset with the first `n` values of each array + for the specified dimension(s). + + Parameters + ---------- + indexers : dict or int, default: 5 + A dict with keys matching dimensions and integer values `n` + or a single integer `n` applied over all dimensions. + One of indexers or indexers_kwargs must be provided. + **indexers_kwargs : {dim: n, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Examples + -------- + >>> dates = pd.date_range(start="2023-01-01", periods=5) + >>> pageviews = [1200, 1500, 900, 1800, 2000] + >>> visitors = [800, 1000, 600, 1200, 1500] + >>> dataset = xr.Dataset( + ... { + ... "pageviews": (("date"), pageviews), + ... "visitors": (("date"), visitors), + ... }, + ... coords={"date": dates}, + ... ) + >>> busiest_days = dataset.sortby("pageviews", ascending=False) + >>> busiest_days.head() + Size: 120B + Dimensions: (date: 5) + Coordinates: + * date (date) datetime64[ns] 40B 2023-01-05 2023-01-04 ... 2023-01-03 + Data variables: + pageviews (date) int64 40B 2000 1800 1500 1200 900 + visitors (date) int64 40B 1500 1200 1000 800 600 + + # Retrieve the 3 most busiest days in terms of pageviews + + >>> busiest_days.head(3) + Size: 72B + Dimensions: (date: 3) + Coordinates: + * date (date) datetime64[ns] 24B 2023-01-05 2023-01-04 2023-01-02 + Data variables: + pageviews (date) int64 24B 2000 1800 1500 + visitors (date) int64 24B 1500 1200 1000 + + # Using a dictionary to specify the number of elements for specific dimensions + + >>> busiest_days.head({"date": 3}) + Size: 72B + Dimensions: (date: 3) + Coordinates: + * date (date) datetime64[ns] 24B 2023-01-05 2023-01-04 2023-01-02 + Data variables: + pageviews (date) int64 24B 2000 1800 1500 + visitors (date) int64 24B 1500 1200 1000 + + See Also + -------- + Dataset.tail + Dataset.thin + DataArray.head + """ + if not indexers_kwargs: + if indexers is None: + indexers = 5 + if not isinstance(indexers, int) and not is_dict_like(indexers): + raise TypeError("indexers must be either dict-like or a single integer") + if isinstance(indexers, int): + indexers = {dim: indexers for dim in self.dims} + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "head") + for k, v in indexers.items(): + if not isinstance(v, int): + raise TypeError( + "expected integer type indexer for " + f"dimension {k!r}, found {type(v)!r}" + ) + elif v < 0: + raise ValueError( + "expected positive integer as indexer " + f"for dimension {k!r}, found {v}" + ) + indexers_slices = {k: slice(val) for k, val in indexers.items()} + return self.isel(indexers_slices) + + def tail( + self, + indexers: Mapping[Any, int] | int | None = None, + **indexers_kwargs: Any, + ) -> Self: + """Returns a new dataset with the last `n` values of each array + for the specified dimension(s). + + Parameters + ---------- + indexers : dict or int, default: 5 + A dict with keys matching dimensions and integer values `n` + or a single integer `n` applied over all dimensions. + One of indexers or indexers_kwargs must be provided. + **indexers_kwargs : {dim: n, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Examples + -------- + >>> activity_names = ["Walking", "Running", "Cycling", "Swimming", "Yoga"] + >>> durations = [30, 45, 60, 45, 60] # in minutes + >>> energies = [150, 300, 250, 400, 100] # in calories + >>> dataset = xr.Dataset( + ... { + ... "duration": (["activity"], durations), + ... "energy_expenditure": (["activity"], energies), + ... }, + ... coords={"activity": activity_names}, + ... ) + >>> sorted_dataset = dataset.sortby("energy_expenditure", ascending=False) + >>> sorted_dataset + Size: 240B + Dimensions: (activity: 5) + Coordinates: + * activity (activity) >> sorted_dataset.tail(3) + Size: 144B + Dimensions: (activity: 3) + Coordinates: + * activity (activity) >> sorted_dataset.tail({"activity": 3}) + Size: 144B + Dimensions: (activity: 3) + Coordinates: + * activity (activity) Self: + """Returns a new dataset with each array indexed along every `n`-th + value for the specified dimension(s) + + Parameters + ---------- + indexers : dict or int + A dict with keys matching dimensions and integer values `n` + or a single integer `n` applied over all dimensions. + One of indexers or indexers_kwargs must be provided. + **indexers_kwargs : {dim: n, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Examples + -------- + >>> x_arr = np.arange(0, 26) + >>> x_arr + array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25]) + >>> x = xr.DataArray( + ... np.reshape(x_arr, (2, 13)), + ... dims=("x", "y"), + ... coords={"x": [0, 1], "y": np.arange(0, 13)}, + ... ) + >>> x_ds = xr.Dataset({"foo": x}) + >>> x_ds + Size: 328B + Dimensions: (x: 2, y: 13) + Coordinates: + * x (x) int64 16B 0 1 + * y (y) int64 104B 0 1 2 3 4 5 6 7 8 9 10 11 12 + Data variables: + foo (x, y) int64 208B 0 1 2 3 4 5 6 7 8 ... 17 18 19 20 21 22 23 24 25 + + >>> x_ds.thin(3) + Size: 88B + Dimensions: (x: 1, y: 5) + Coordinates: + * x (x) int64 8B 0 + * y (y) int64 40B 0 3 6 9 12 + Data variables: + foo (x, y) int64 40B 0 3 6 9 12 + >>> x.thin({"x": 2, "y": 5}) + Size: 24B + array([[ 0, 5, 10]]) + Coordinates: + * x (x) int64 8B 0 + * y (y) int64 24B 0 5 10 + + See Also + -------- + Dataset.head + Dataset.tail + DataArray.thin + """ + if ( + not indexers_kwargs + and not isinstance(indexers, int) + and not is_dict_like(indexers) + ): + raise TypeError("indexers must be either dict-like or a single integer") + if isinstance(indexers, int): + indexers = {dim: indexers for dim in self.dims} + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "thin") + for k, v in indexers.items(): + if not isinstance(v, int): + raise TypeError( + "expected integer type indexer for " + f"dimension {k!r}, found {type(v)!r}" + ) + elif v < 0: + raise ValueError( + "expected positive integer as indexer " + f"for dimension {k!r}, found {v}" + ) + elif v == 0: + raise ValueError("step cannot be zero") + indexers_slices = {k: slice(None, None, val) for k, val in indexers.items()} + return self.isel(indexers_slices) + + def broadcast_like( + self, + other: T_DataArrayOrSet, + exclude: Iterable[Hashable] | None = None, + ) -> Self: + """Broadcast this DataArray against another Dataset or DataArray. + This is equivalent to xr.broadcast(other, self)[1] + + Parameters + ---------- + other : Dataset or DataArray + Object against which to broadcast this array. + exclude : iterable of hashable, optional + Dimensions that must not be broadcasted + + """ + if exclude is None: + exclude = set() + else: + exclude = set(exclude) + args = align(other, self, join="outer", copy=False, exclude=exclude) + + dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) + + return _broadcast_helper(args[1], exclude, dims_map, common_coords) + + def _reindex_callback( + self, + aligner: alignment.Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: + """Callback called from ``Aligner`` to create a new reindexed Dataset.""" + + new_variables = variables.copy() + new_indexes = indexes.copy() + + # re-assign variable metadata + for name, new_var in new_variables.items(): + var = self._variables.get(name) + if var is not None: + new_var.attrs = var.attrs + new_var.encoding = var.encoding + + # pass through indexes from excluded dimensions + # no extra check needed for multi-coordinate indexes, potential conflicts + # should already have been detected when aligning the indexes + for name, idx in self._indexes.items(): + var = self._variables[name] + if set(var.dims) <= exclude_dims: + new_indexes[name] = idx + new_variables[name] = var + + if not dim_pos_indexers: + # fast path for no reindexing necessary + if set(new_indexes) - set(self._indexes): + # this only adds new indexes and their coordinate variables + reindexed = self._overwrite_indexes(new_indexes, new_variables) + else: + reindexed = self.copy(deep=aligner.copy) + else: + to_reindex = { + k: v + for k, v in self.variables.items() + if k not in variables and k not in exclude_vars + } + reindexed_vars = alignment.reindex_variables( + to_reindex, + dim_pos_indexers, + copy=aligner.copy, + fill_value=fill_value, + sparse=aligner.sparse, + ) + new_variables.update(reindexed_vars) + new_coord_names = self._coord_names | set(new_indexes) + reindexed = self._replace_with_new_dims( + new_variables, new_coord_names, indexes=new_indexes + ) + + reindexed.encoding = self.encoding + + return reindexed + + def reindex_like( + self, + other: T_Xarray, + method: ReindexMethodOptions = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = xrdtypes.NA, + ) -> Self: + """ + Conform this object onto the indexes of another object, for indexes which the + objects share. Missing values are filled with ``fill_value``. The default fill + value is NaN. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to pandas.Index objects, which provides coordinates upon + which to index the variables in this dataset. The indexes on this + other object need not be the same as the indexes on this + dataset. Any mis-matched index values will be filled in with + NaN, and any mis-matched dimension names will simply be ignored. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill", None}, optional + Method to use for filling index values from other not found in this + dataset: + + - None (default): don't fill gaps + - "pad" / "ffill": propagate last valid index value forward + - "backfill" / "bfill": propagate next valid index value backward + - "nearest": use nearest valid index value + + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like must be the same size as the index and its dtype + must exactly match the index’s type. + copy : bool, default: True + If ``copy=True``, data in the return value is always copied. If + ``copy=False`` and reindexing is unnecessary, or can be performed + with only slice operations, then the output may share memory with + the input. In either case, a new xarray object is always returned. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like maps + variable names to fill values. + + Returns + ------- + reindexed : Dataset + Another dataset, with this dataset's data but coordinates from the + other object. + + See Also + -------- + Dataset.reindex + DataArray.reindex_like + align + + """ + return alignment.reindex_like( + self, + other=other, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) + + def reindex( + self, + indexers: Mapping[Any, Any] | None = None, + method: ReindexMethodOptions = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = xrdtypes.NA, + **indexers_kwargs: Any, + ) -> Self: + """Conform this object onto a new set of indexes, filling in + missing values with ``fill_value``. The default fill value is NaN. + + Parameters + ---------- + indexers : dict, optional + Dictionary with keys given by dimension names and values given by + arrays of coordinates tick labels. Any mis-matched coordinate + values will be filled in with NaN, and any mis-matched dimension + names will simply be ignored. + One of indexers or indexers_kwargs must be provided. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill", None}, optional + Method to use for filling index values in ``indexers`` not found in + this dataset: + + - None (default): don't fill gaps + - "pad" / "ffill": propagate last valid index value forward + - "backfill" / "bfill": propagate next valid index value backward + - "nearest": use nearest valid index value + + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like must be the same size as the index and its dtype + must exactly match the index’s type. + copy : bool, default: True + If ``copy=True``, data in the return value is always copied. If + ``copy=False`` and reindexing is unnecessary, or can be performed + with only slice operations, then the output may share memory with + the input. In either case, a new xarray object is always returned. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, + maps variable names (including coordinates) to fill values. + sparse : bool, default: False + use sparse-array. + **indexers_kwargs : {dim: indexer, ...}, optional + Keyword arguments in the same form as ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + reindexed : Dataset + Another dataset, with this dataset's data but replaced coordinates. + + See Also + -------- + Dataset.reindex_like + align + pandas.Index.get_indexer + + Examples + -------- + Create a dataset with some fictional data. + + >>> x = xr.Dataset( + ... { + ... "temperature": ("station", 20 * np.random.rand(4)), + ... "pressure": ("station", 500 * np.random.rand(4)), + ... }, + ... coords={"station": ["boston", "nyc", "seattle", "denver"]}, + ... ) + >>> x + Size: 176B + Dimensions: (station: 4) + Coordinates: + * station (station) >> x.indexes + Indexes: + station Index(['boston', 'nyc', 'seattle', 'denver'], dtype='object', name='station') + + Create a new index and reindex the dataset. By default values in the new index that + do not have corresponding records in the dataset are assigned `NaN`. + + >>> new_index = ["boston", "austin", "seattle", "lincoln"] + >>> x.reindex({"station": new_index}) + Size: 176B + Dimensions: (station: 4) + Coordinates: + * station (station) >> x.reindex({"station": new_index}, fill_value=0) + Size: 176B + Dimensions: (station: 4) + Coordinates: + * station (station) >> x.reindex( + ... {"station": new_index}, fill_value={"temperature": 0, "pressure": 100} + ... ) + Size: 176B + Dimensions: (station: 4) + Coordinates: + * station (station) >> x.reindex({"station": new_index}, method="nearest") + Traceback (most recent call last): + ... + raise ValueError('index must be monotonic increasing or decreasing') + ValueError: index must be monotonic increasing or decreasing + + To further illustrate the filling functionality in reindex, we will create a + dataset with a monotonically increasing index (for example, a sequence of dates). + + >>> x2 = xr.Dataset( + ... { + ... "temperature": ( + ... "time", + ... [15.57, 12.77, np.nan, 0.3081, 16.59, 15.12], + ... ), + ... "pressure": ("time", 500 * np.random.rand(6)), + ... }, + ... coords={"time": pd.date_range("01/01/2019", periods=6, freq="D")}, + ... ) + >>> x2 + Size: 144B + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2019-01-01 2019-01-02 ... 2019-01-06 + Data variables: + temperature (time) float64 48B 15.57 12.77 nan 0.3081 16.59 15.12 + pressure (time) float64 48B 481.8 191.7 395.9 264.4 284.0 462.8 + + Suppose we decide to expand the dataset to cover a wider date range. + + >>> time_index2 = pd.date_range("12/29/2018", periods=10, freq="D") + >>> x2.reindex({"time": time_index2}) + Size: 240B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2018-12-29 2018-12-30 ... 2019-01-07 + Data variables: + temperature (time) float64 80B nan nan nan 15.57 ... 0.3081 16.59 15.12 nan + pressure (time) float64 80B nan nan nan 481.8 ... 264.4 284.0 462.8 nan + + The index entries that did not have a value in the original data frame (for example, `2018-12-29`) + are by default filled with NaN. If desired, we can fill in the missing values using one of several options. + + For example, to back-propagate the last valid value to fill the `NaN` values, + pass `bfill` as an argument to the `method` keyword. + + >>> x3 = x2.reindex({"time": time_index2}, method="bfill") + >>> x3 + Size: 240B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2018-12-29 2018-12-30 ... 2019-01-07 + Data variables: + temperature (time) float64 80B 15.57 15.57 15.57 15.57 ... 16.59 15.12 nan + pressure (time) float64 80B 481.8 481.8 481.8 481.8 ... 284.0 462.8 nan + + Please note that the `NaN` value present in the original dataset (at index value `2019-01-03`) + will not be filled by any of the value propagation schemes. + + >>> x2.where(x2.temperature.isnull(), drop=True) + Size: 24B + Dimensions: (time: 1) + Coordinates: + * time (time) datetime64[ns] 8B 2019-01-03 + Data variables: + temperature (time) float64 8B nan + pressure (time) float64 8B 395.9 + >>> x3.where(x3.temperature.isnull(), drop=True) + Size: 48B + Dimensions: (time: 2) + Coordinates: + * time (time) datetime64[ns] 16B 2019-01-03 2019-01-07 + Data variables: + temperature (time) float64 16B nan nan + pressure (time) float64 16B 395.9 nan + + This is because filling while reindexing does not look at dataset values, but only compares + the original and desired indexes. If you do want to fill in the `NaN` values present in the + original dataset, use the :py:meth:`~Dataset.fillna()` method. + + """ + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) + + def _reindex( + self, + indexers: Mapping[Any, Any] | None = None, + method: str | None = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = xrdtypes.NA, + sparse: bool = False, + **indexers_kwargs: Any, + ) -> Self: + """ + Same as reindex but supports sparse option. + """ + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + sparse=sparse, + ) + + def interp( + self, + coords: Mapping[Any, Any] | None = None, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] | None = None, + method_non_numeric: str = "nearest", + **coords_kwargs: Any, + ) -> Self: + """Interpolate a Dataset onto new coordinates + + Performs univariate or multivariate interpolation of a Dataset onto + new coordinates using scipy's interpolation routines. If interpolating + along an existing dimension, :py:class:`scipy.interpolate.interp1d` is + called. When interpolating along multiple existing dimensions, an + attempt is made to decompose the interpolation into multiple + 1-dimensional interpolations. If this is possible, + :py:class:`scipy.interpolate.interp1d` is called. Otherwise, + :py:func:`scipy.interpolate.interpn` is called. + + Parameters + ---------- + coords : dict, optional + Mapping from dimension names to the new coordinates. + New coordinate can be a scalar, array-like or DataArray. + If DataArrays are passed as new coordinates, their dimensions are + used for the broadcasting. Missing values are skipped. + method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" + String indicating which method to use for interpolation: + + - 'linear': linear interpolation. Additional keyword + arguments are passed to :py:func:`numpy.interp` + - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial': + are passed to :py:func:`scipy.interpolate.interp1d`. If + ``method='polynomial'``, the ``order`` keyword argument must also be + provided. + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their + respective :py:class:`scipy.interpolate` classes. + + assume_sorted : bool, default: False + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs : dict, optional + Additional keyword arguments passed to scipy's interpolator. Valid + options and their behavior depend whether ``interp1d`` or + ``interpn`` is used. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. + **coords_kwargs : {dim: coordinate, ...}, optional + The keyword arguments form of ``coords``. + One of coords or coords_kwargs must be provided. + + Returns + ------- + interpolated : Dataset + New dataset on the new coordinates. + + Notes + ----- + scipy is required. + + See Also + -------- + scipy.interpolate.interp1d + scipy.interpolate.interpn + + :doc:`xarray-tutorial:fundamentals/02.2_manipulating_dimensions` + Tutorial material on manipulating data resolution using :py:func:`~xarray.Dataset.interp` + + Examples + -------- + >>> ds = xr.Dataset( + ... data_vars={ + ... "a": ("x", [5, 7, 4]), + ... "b": ( + ... ("x", "y"), + ... [[1, 4, 2, 9], [2, 7, 6, np.nan], [6, np.nan, 5, 8]], + ... ), + ... }, + ... coords={"x": [0, 1, 2], "y": [10, 12, 14, 16]}, + ... ) + >>> ds + Size: 176B + Dimensions: (x: 3, y: 4) + Coordinates: + * x (x) int64 24B 0 1 2 + * y (y) int64 32B 10 12 14 16 + Data variables: + a (x) int64 24B 5 7 4 + b (x, y) float64 96B 1.0 4.0 2.0 9.0 2.0 7.0 6.0 nan 6.0 nan 5.0 8.0 + + 1D interpolation with the default method (linear): + + >>> ds.interp(x=[0, 0.75, 1.25, 1.75]) + Size: 224B + Dimensions: (x: 4, y: 4) + Coordinates: + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 + Data variables: + a (x) float64 32B 5.0 6.5 6.25 4.75 + b (x, y) float64 128B 1.0 4.0 2.0 nan 1.75 ... nan 5.0 nan 5.25 nan + + 1D interpolation with a different method: + + >>> ds.interp(x=[0, 0.75, 1.25, 1.75], method="nearest") + Size: 224B + Dimensions: (x: 4, y: 4) + Coordinates: + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 + Data variables: + a (x) float64 32B 5.0 7.0 7.0 4.0 + b (x, y) float64 128B 1.0 4.0 2.0 9.0 2.0 7.0 ... nan 6.0 nan 5.0 8.0 + + 1D extrapolation: + + >>> ds.interp( + ... x=[1, 1.5, 2.5, 3.5], + ... method="linear", + ... kwargs={"fill_value": "extrapolate"}, + ... ) + Size: 224B + Dimensions: (x: 4, y: 4) + Coordinates: + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 1.0 1.5 2.5 3.5 + Data variables: + a (x) float64 32B 7.0 5.5 2.5 -0.5 + b (x, y) float64 128B 2.0 7.0 6.0 nan 4.0 ... nan 12.0 nan 3.5 nan + + 2D interpolation: + + >>> ds.interp(x=[0, 0.75, 1.25, 1.75], y=[11, 13, 15], method="linear") + Size: 184B + Dimensions: (x: 4, y: 3) + Coordinates: + * x (x) float64 32B 0.0 0.75 1.25 1.75 + * y (y) int64 24B 11 13 15 + Data variables: + a (x) float64 32B 5.0 6.5 6.25 4.75 + b (x, y) float64 96B 2.5 3.0 nan 4.0 5.625 ... nan nan nan nan nan + """ + from xarray.core import missing + + if kwargs is None: + kwargs = {} + + coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") + indexers = dict(self._validate_interp_indexers(coords)) + + if coords: + # This avoids broadcasting over coordinates that are both in + # the original array AND in the indexing array. It essentially + # forces interpolation along the shared coordinates. + sdims = ( + set(self.dims) + .intersection(*[set(nx.dims) for nx in indexers.values()]) + .difference(coords.keys()) + ) + indexers.update({d: self.variables[d] for d in sdims}) + + obj = self if assume_sorted else self.sortby([k for k in coords]) + + def maybe_variable(obj, k): + # workaround to get variable for dimension without coordinate. + try: + return obj._variables[k] + except KeyError: + return as_variable((k, range(obj.sizes[k]))) + + def _validate_interp_indexer(x, new_x): + # In the case of datetimes, the restrictions placed on indexers + # used with interp are stronger than those which are placed on + # isel, so we need an additional check after _validate_indexers. + if _contains_datetime_like_objects( + x + ) and not _contains_datetime_like_objects(new_x): + raise TypeError( + "When interpolating over a datetime-like " + "coordinate, the coordinates to " + "interpolate to must be either datetime " + "strings or datetimes. " + f"Instead got\n{new_x}" + ) + return x, new_x + + validated_indexers = { + k: _validate_interp_indexer(maybe_variable(obj, k), v) + for k, v in indexers.items() + } + + # optimization: subset to coordinate range of the target index + if method in ["linear", "nearest"]: + for k, v in validated_indexers.items(): + obj, newidx = missing._localize(obj, {k: v}) + validated_indexers[k] = newidx[k] + + # optimization: create dask coordinate arrays once per Dataset + # rather than once per Variable when dask.array.unify_chunks is called later + # GH4739 + if obj.__dask_graph__(): + dask_indexers = { + k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk()) + for k, (index, dest) in validated_indexers.items() + } + + variables: dict[Hashable, Variable] = {} + reindex: bool = False + for name, var in obj._variables.items(): + if name in indexers: + continue + + if is_duck_dask_array(var.data): + use_indexers = dask_indexers + else: + use_indexers = validated_indexers + + dtype_kind = var.dtype.kind + if dtype_kind in "uifc": + # For normal number types do the interpolation: + var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims} + variables[name] = missing.interp(var, var_indexers, method, **kwargs) + elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): + # For types that we do not understand do stepwise + # interpolation to avoid modifying the elements. + # reindex the variable instead because it supports + # booleans and objects and retains the dtype but inside + # this loop there might be some duplicate code that slows it + # down, therefore collect these signals and run it later: + reindex = True + elif all(d not in indexers for d in var.dims): + # For anything else we can only keep variables if they + # are not dependent on any coords that are being + # interpolated along: + variables[name] = var + + if reindex: + reindex_indexers = { + k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,) + } + reindexed = alignment.reindex( + obj, + indexers=reindex_indexers, + method=method_non_numeric, + exclude_vars=variables.keys(), + ) + indexes = dict(reindexed._indexes) + variables.update(reindexed.variables) + else: + # Get the indexes that are not being interpolated along + indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} + + # Get the coords that also exist in the variables: + coord_names = obj._coord_names & variables.keys() + selected = self._replace_with_new_dims( + variables.copy(), coord_names, indexes=indexes + ) + + # Attach indexer as coordinate + for k, v in indexers.items(): + assert isinstance(v, Variable) + if v.dims == (k,): + index = PandasIndex(v, k, coord_dtype=v.dtype) + index_vars = index.create_variables({k: v}) + indexes[k] = index + variables.update(index_vars) + else: + variables[k] = v + + # Extract coordinates from indexers + coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) + variables.update(coord_vars) + indexes.update(new_indexes) + + coord_names = obj._coord_names & variables.keys() | coord_vars.keys() + return self._replace_with_new_dims(variables, coord_names, indexes=indexes) + + def interp_like( + self, + other: T_Xarray, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] | None = None, + method_non_numeric: str = "nearest", + ) -> Self: + """Interpolate this object onto the coordinates of another object, + filling the out of range values with NaN. + + If interpolating along a single existing dimension, + :py:class:`scipy.interpolate.interp1d` is called. When interpolating + along multiple existing dimensions, an attempt is made to decompose the + interpolation into multiple 1-dimensional interpolations. If this is + possible, :py:class:`scipy.interpolate.interp1d` is called. Otherwise, + :py:func:`scipy.interpolate.interpn` is called. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to an 1d array-like, which provides coordinates upon + which to index the variables in this dataset. Missing values are skipped. + method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" + String indicating which method to use for interpolation: + + - 'linear': linear interpolation. Additional keyword + arguments are passed to :py:func:`numpy.interp` + - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial': + are passed to :py:func:`scipy.interpolate.interp1d`. If + ``method='polynomial'``, the ``order`` keyword argument must also be + provided. + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their + respective :py:class:`scipy.interpolate` classes. + + assume_sorted : bool, default: False + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs : dict, optional + Additional keyword passed to scipy's interpolator. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. + + Returns + ------- + interpolated : Dataset + Another dataset by interpolating this dataset's data along the + coordinates of the other object. + + Notes + ----- + scipy is required. + If the dataset has object-type coordinates, reindex is used for these + coordinates instead of the interpolation. + + See Also + -------- + Dataset.interp + Dataset.reindex_like + """ + if kwargs is None: + kwargs = {} + + # pick only dimension coordinates with a single index + coords = {} + other_indexes = other.xindexes + for dim in self.dims: + other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") + if len(other_dim_coords) == 1: + coords[dim] = other_dim_coords[dim] + + numeric_coords: dict[Hashable, pd.Index] = {} + object_coords: dict[Hashable, pd.Index] = {} + for k, v in coords.items(): + if v.dtype.kind in "uifcMm": + numeric_coords[k] = v + else: + object_coords[k] = v + + ds = self + if object_coords: + # We do not support interpolation along object coordinate. + # reindex instead. + ds = self.reindex(object_coords) + return ds.interp( + coords=numeric_coords, + method=method, + assume_sorted=assume_sorted, + kwargs=kwargs, + method_non_numeric=method_non_numeric, + ) + + # Helper methods for rename() + def _rename_vars( + self, name_dict, dims_dict + ) -> tuple[dict[Hashable, Variable], set[Hashable]]: + variables = {} + coord_names = set() + for k, v in self.variables.items(): + var = v.copy(deep=False) + var.dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) + name = name_dict.get(k, k) + if name in variables: + raise ValueError(f"the new name {name!r} conflicts") + variables[name] = var + if k in self._coord_names: + coord_names.add(name) + return variables, coord_names + + def _rename_dims(self, name_dict: Mapping[Any, Hashable]) -> dict[Hashable, int]: + return {name_dict.get(k, k): v for k, v in self.sizes.items()} + + def _rename_indexes( + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] + ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + if not self._indexes: + return {}, {} + + indexes = {} + variables = {} + + for index, coord_names in self.xindexes.group_by_index(): + new_index = index.rename(name_dict, dims_dict) + new_coord_names = [name_dict.get(k, k) for k in coord_names] + indexes.update({k: new_index for k in new_coord_names}) + new_index_vars = new_index.create_variables( + { + new: self._variables[old] + for old, new in zip(coord_names, new_coord_names) + } + ) + variables.update(new_index_vars) + + return indexes, variables + + def _rename_all( + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] + ) -> tuple[ + dict[Hashable, Variable], + set[Hashable], + dict[Hashable, int], + dict[Hashable, Index], + ]: + variables, coord_names = self._rename_vars(name_dict, dims_dict) + dims = self._rename_dims(dims_dict) + + indexes, index_vars = self._rename_indexes(name_dict, dims_dict) + variables = {k: index_vars.get(k, v) for k, v in variables.items()} + + return variables, coord_names, dims, indexes + + def _rename( + self, + name_dict: Mapping[Any, Hashable] | None = None, + **names: Hashable, + ) -> Self: + """Also used internally by DataArray so that the warning (if any) + is raised at the right stack level. + """ + name_dict = either_dict_or_kwargs(name_dict, names, "rename") + for k in name_dict.keys(): + if k not in self and k not in self.dims: + raise ValueError( + f"cannot rename {k!r} because it is not a " + "variable or dimension in this dataset" + ) + + create_dim_coord = False + new_k = name_dict[k] + + if k == new_k: + continue # Same name, nothing to do + + if k in self.dims and new_k in self._coord_names: + coord_dims = self._variables[name_dict[k]].dims + if coord_dims == (k,): + create_dim_coord = True + elif k in self._coord_names and new_k in self.dims: + coord_dims = self._variables[k].dims + if coord_dims == (new_k,): + create_dim_coord = True + + if create_dim_coord: + warnings.warn( + f"rename {k!r} to {name_dict[k]!r} does not create an index " + "anymore. Try using swap_dims instead or use set_index " + "after rename to create an indexed coordinate.", + UserWarning, + stacklevel=3, + ) + + variables, coord_names, dims, indexes = self._rename_all( + name_dict=name_dict, dims_dict=name_dict + ) + return self._replace(variables, coord_names, dims=dims, indexes=indexes) + + def rename( + self, + name_dict: Mapping[Any, Hashable] | None = None, + **names: Hashable, + ) -> Self: + """Returns a new object with renamed variables, coordinates and dimensions. + + Parameters + ---------- + name_dict : dict-like, optional + Dictionary whose keys are current variable, coordinate or dimension names and + whose values are the desired names. + **names : optional + Keyword form of ``name_dict``. + One of name_dict or names must be provided. + + Returns + ------- + renamed : Dataset + Dataset with renamed variables, coordinates and dimensions. + + See Also + -------- + Dataset.swap_dims + Dataset.rename_vars + Dataset.rename_dims + DataArray.rename + """ + return self._rename(name_dict=name_dict, **names) + + def rename_dims( + self, + dims_dict: Mapping[Any, Hashable] | None = None, + **dims: Hashable, + ) -> Self: + """Returns a new object with renamed dimensions only. + + Parameters + ---------- + dims_dict : dict-like, optional + Dictionary whose keys are current dimension names and + whose values are the desired names. The desired names must + not be the name of an existing dimension or Variable in the Dataset. + **dims : optional + Keyword form of ``dims_dict``. + One of dims_dict or dims must be provided. + + Returns + ------- + renamed : Dataset + Dataset with renamed dimensions. + + See Also + -------- + Dataset.swap_dims + Dataset.rename + Dataset.rename_vars + DataArray.rename + """ + dims_dict = either_dict_or_kwargs(dims_dict, dims, "rename_dims") + for k, v in dims_dict.items(): + if k not in self.dims: + raise ValueError( + f"cannot rename {k!r} because it is not found " + f"in the dimensions of this dataset {tuple(self.dims)}" + ) + if v in self.dims or v in self: + raise ValueError( + f"Cannot rename {k} to {v} because {v} already exists. " + "Try using swap_dims instead." + ) + + variables, coord_names, sizes, indexes = self._rename_all( + name_dict={}, dims_dict=dims_dict + ) + return self._replace(variables, coord_names, dims=sizes, indexes=indexes) + + def rename_vars( + self, + name_dict: Mapping[Any, Hashable] | None = None, + **names: Hashable, + ) -> Self: + """Returns a new object with renamed variables including coordinates + + Parameters + ---------- + name_dict : dict-like, optional + Dictionary whose keys are current variable or coordinate names and + whose values are the desired names. + **names : optional + Keyword form of ``name_dict``. + One of name_dict or names must be provided. + + Returns + ------- + renamed : Dataset + Dataset with renamed variables including coordinates + + See Also + -------- + Dataset.swap_dims + Dataset.rename + Dataset.rename_dims + DataArray.rename + """ + name_dict = either_dict_or_kwargs(name_dict, names, "rename_vars") + for k in name_dict: + if k not in self: + raise ValueError( + f"cannot rename {k!r} because it is not a " + "variable or coordinate in this dataset" + ) + variables, coord_names, dims, indexes = self._rename_all( + name_dict=name_dict, dims_dict={} + ) + return self._replace(variables, coord_names, dims=dims, indexes=indexes) + + def swap_dims( + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs + ) -> Self: + """Returns a new object with swapped dimensions. + + Parameters + ---------- + dims_dict : dict-like + Dictionary whose keys are current dimension names and whose values + are new names. + **dims_kwargs : {existing_dim: new_dim, ...}, optional + The keyword arguments form of ``dims_dict``. + One of dims_dict or dims_kwargs must be provided. + + Returns + ------- + swapped : Dataset + Dataset with swapped dimensions. + + Examples + -------- + >>> ds = xr.Dataset( + ... data_vars={"a": ("x", [5, 7]), "b": ("x", [0.1, 2.4])}, + ... coords={"x": ["a", "b"], "y": ("x", [0, 1])}, + ... ) + >>> ds + Size: 56B + Dimensions: (x: 2) + Coordinates: + * x (x) >> ds.swap_dims({"x": "y"}) + Size: 56B + Dimensions: (y: 2) + Coordinates: + x (y) >> ds.swap_dims({"x": "z"}) + Size: 56B + Dimensions: (z: 2) + Coordinates: + x (z) Self: + """Return a new object with an additional axis (or axes) inserted at + the corresponding position in the array shape. The new object is a + view into the underlying array, not a copy. + + If dim is already a scalar coordinate, it will be promoted to a 1D + coordinate consisting of a single value. + + The automatic creation of indexes to back new 1D coordinate variables + controlled by the create_index_for_new_dim kwarg. + + Parameters + ---------- + dim : hashable, sequence of hashable, mapping, or None + Dimensions to include on the new variable. If provided as hashable + or sequence of hashable, then dimensions are inserted with length + 1. If provided as a mapping, then the keys are the new dimensions + and the values are either integers (giving the length of the new + dimensions) or array-like (giving the coordinates of the new + dimensions). + axis : int, sequence of int, or None, default: None + Axis position(s) where new axis is to be inserted (position(s) on + the result array). If a sequence of integers is passed, + multiple axes are inserted. In this case, dim arguments should be + same length list. If axis=None is passed, all the axes will be + inserted to the start of the result array. + create_index_for_new_dim : bool, default: True + Whether to create new ``PandasIndex`` objects when the object being expanded contains scalar variables with names in ``dim``. + **dim_kwargs : int or sequence or ndarray + The keywords are arbitrary dimensions being inserted and the values + are either the lengths of the new dims (if int is given), or their + coordinates. Note, this is an alternative to passing a dict to the + dim kwarg and will only be used if dim is None. + + Returns + ------- + expanded : Dataset + This object, but with additional dimension(s). + + Examples + -------- + >>> dataset = xr.Dataset({"temperature": ([], 25.0)}) + >>> dataset + Size: 8B + Dimensions: () + Data variables: + temperature float64 8B 25.0 + + # Expand the dataset with a new dimension called "time" + + >>> dataset.expand_dims(dim="time") + Size: 8B + Dimensions: (time: 1) + Dimensions without coordinates: time + Data variables: + temperature (time) float64 8B 25.0 + + # 1D data + + >>> temperature_1d = xr.DataArray([25.0, 26.5, 24.8], dims="x") + >>> dataset_1d = xr.Dataset({"temperature": temperature_1d}) + >>> dataset_1d + Size: 24B + Dimensions: (x: 3) + Dimensions without coordinates: x + Data variables: + temperature (x) float64 24B 25.0 26.5 24.8 + + # Expand the dataset with a new dimension called "time" using axis argument + + >>> dataset_1d.expand_dims(dim="time", axis=0) + Size: 24B + Dimensions: (time: 1, x: 3) + Dimensions without coordinates: time, x + Data variables: + temperature (time, x) float64 24B 25.0 26.5 24.8 + + # 2D data + + >>> temperature_2d = xr.DataArray(np.random.rand(3, 4), dims=("y", "x")) + >>> dataset_2d = xr.Dataset({"temperature": temperature_2d}) + >>> dataset_2d + Size: 96B + Dimensions: (y: 3, x: 4) + Dimensions without coordinates: y, x + Data variables: + temperature (y, x) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289 + + # Expand the dataset with a new dimension called "time" using axis argument + + >>> dataset_2d.expand_dims(dim="time", axis=2) + Size: 96B + Dimensions: (y: 3, x: 4, time: 1) + Dimensions without coordinates: y, x, time + Data variables: + temperature (y, x, time) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289 + + # Expand a scalar variable along a new dimension of the same name with and without creating a new index + + >>> ds = xr.Dataset(coords={"x": 0}) + >>> ds + Size: 8B + Dimensions: () + Coordinates: + x int64 8B 0 + Data variables: + *empty* + + >>> ds.expand_dims("x") + Size: 8B + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 0 + Data variables: + *empty* + + >>> ds.expand_dims("x").indexes + Indexes: + x Index([0], dtype='int64', name='x') + + >>> ds.expand_dims("x", create_index_for_new_dim=False).indexes + Indexes: + *empty* + + See Also + -------- + DataArray.expand_dims + """ + if dim is None: + pass + elif isinstance(dim, Mapping): + # We're later going to modify dim in place; don't tamper with + # the input + dim = dict(dim) + elif isinstance(dim, int): + raise TypeError( + "dim should be hashable or sequence of hashables or mapping" + ) + elif isinstance(dim, str) or not isinstance(dim, Sequence): + dim = {dim: 1} + elif isinstance(dim, Sequence): + if len(dim) != len(set(dim)): + raise ValueError("dims should not contain duplicate values.") + dim = {d: 1 for d in dim} + + dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") + assert isinstance(dim, MutableMapping) + + if axis is None: + axis = list(range(len(dim))) + elif not isinstance(axis, Sequence): + axis = [axis] + + if len(dim) != len(axis): + raise ValueError("lengths of dim and axis should be identical.") + for d in dim: + if d in self.dims: + raise ValueError(f"Dimension {d} already exists.") + if d in self._variables and not utils.is_scalar(self._variables[d]): + raise ValueError( + f"{d} already exists as coordinate or" " variable name." + ) + + variables: dict[Hashable, Variable] = {} + indexes: dict[Hashable, Index] = dict(self._indexes) + coord_names = self._coord_names.copy() + # If dim is a dict, then ensure that the values are either integers + # or iterables. + for k, v in dim.items(): + if hasattr(v, "__iter__"): + # If the value for the new dimension is an iterable, then + # save the coordinates to the variables dict, and set the + # value within the dim dict to the length of the iterable + # for later use. + + if create_index_for_new_dim: + index = PandasIndex(v, k) + indexes[k] = index + name_and_new_1d_var = index.create_variables() + else: + name_and_new_1d_var = {k: Variable(data=v, dims=k)} + variables.update(name_and_new_1d_var) + coord_names.add(k) + dim[k] = variables[k].size + elif isinstance(v, int): + pass # Do nothing if the dimensions value is just an int + else: + raise TypeError( + f"The value of new dimension {k} must be " "an iterable or an int" + ) + + for k, v in self._variables.items(): + if k not in dim: + if k in coord_names: # Do not change coordinates + variables[k] = v + else: + result_ndim = len(v.dims) + len(axis) + for a in axis: + if a < -result_ndim or result_ndim - 1 < a: + raise IndexError( + f"Axis {a} of variable {k} is out of bounds of the " + f"expanded dimension size {result_ndim}" + ) + + axis_pos = [a if a >= 0 else result_ndim + a for a in axis] + if len(axis_pos) != len(set(axis_pos)): + raise ValueError("axis should not contain duplicate values") + # We need to sort them to make sure `axis` equals to the + # axis positions of the result array. + zip_axis_dim = sorted(zip(axis_pos, dim.items())) + + all_dims = list(zip(v.dims, v.shape)) + for d, c in zip_axis_dim: + all_dims.insert(d, c) + variables[k] = v.set_dims(dict(all_dims)) + else: + if k not in variables: + if k in coord_names and create_index_for_new_dim: + # If dims includes a label of a non-dimension coordinate, + # it will be promoted to a 1D coordinate with a single value. + index, index_vars = create_default_index_implicit(v.set_dims(k)) + indexes[k] = index + variables.update(index_vars) + else: + if create_index_for_new_dim: + warnings.warn( + f"No index created for dimension {k} because variable {k} is not a coordinate. " + f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.", + UserWarning, + ) + + # create 1D variable without creating a new index + new_1d_var = v.set_dims(k) + variables.update({k: new_1d_var}) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def set_index( + self, + indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, + append: bool = False, + **indexes_kwargs: Hashable | Sequence[Hashable], + ) -> Self: + """Set Dataset (multi-)indexes using one or more existing coordinates + or variables. + + This legacy method is limited to pandas (multi-)indexes and + 1-dimensional "dimension" coordinates. See + :py:meth:`~Dataset.set_xindex` for setting a pandas or a custom + Xarray-compatible index from one or more arbitrary coordinates. + + Parameters + ---------- + indexes : {dim: index, ...} + Mapping from names matching dimensions and values given + by (lists of) the names of existing coordinates or variables to set + as new (multi-)index. + append : bool, default: False + If True, append the supplied index(es) to the existing index(es). + Otherwise replace the existing index(es) (default). + **indexes_kwargs : optional + The keyword arguments form of ``indexes``. + One of indexes or indexes_kwargs must be provided. + + Returns + ------- + obj : Dataset + Another dataset, with this dataset's data but replaced coordinates. + + Examples + -------- + >>> arr = xr.DataArray( + ... data=np.ones((2, 3)), + ... dims=["x", "y"], + ... coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ... ) + >>> ds = xr.Dataset({"v": arr}) + >>> ds + Size: 104B + Dimensions: (x: 2, y: 3) + Coordinates: + * x (x) int64 16B 0 1 + * y (y) int64 24B 0 1 2 + a (x) int64 16B 3 4 + Data variables: + v (x, y) float64 48B 1.0 1.0 1.0 1.0 1.0 1.0 + >>> ds.set_index(x="a") + Size: 88B + Dimensions: (x: 2, y: 3) + Coordinates: + * x (x) int64 16B 3 4 + * y (y) int64 24B 0 1 2 + Data variables: + v (x, y) float64 48B 1.0 1.0 1.0 1.0 1.0 1.0 + + See Also + -------- + Dataset.reset_index + Dataset.set_xindex + Dataset.swap_dims + """ + dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") + + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, Variable] = {} + drop_indexes: set[Hashable] = set() + drop_variables: set[Hashable] = set() + replace_dims: dict[Hashable, Hashable] = {} + all_var_names: set[Hashable] = set() + + for dim, _var_names in dim_coords.items(): + if isinstance(_var_names, str) or not isinstance(_var_names, Sequence): + var_names = [_var_names] + else: + var_names = list(_var_names) + + invalid_vars = set(var_names) - set(self._variables) + if invalid_vars: + raise ValueError( + ", ".join([str(v) for v in invalid_vars]) + + " variable(s) do not exist" + ) + + all_var_names.update(var_names) + drop_variables.update(var_names) + + # drop any pre-existing index involved and its corresponding coordinates + index_coord_names = self.xindexes.get_all_coords(dim, errors="ignore") + all_index_coord_names = set(index_coord_names) + for k in var_names: + all_index_coord_names.update( + self.xindexes.get_all_coords(k, errors="ignore") + ) + + drop_indexes.update(all_index_coord_names) + drop_variables.update(all_index_coord_names) + + if len(var_names) == 1 and (not append or dim not in self._indexes): + var_name = var_names[0] + var = self._variables[var_name] + # an error with a better message will be raised for scalar variables + # when creating the PandasIndex + if var.ndim > 0 and var.dims != (dim,): + raise ValueError( + f"dimension mismatch: try setting an index for dimension {dim!r} with " + f"variable {var_name!r} that has dimensions {var.dims}" + ) + idx = PandasIndex.from_variables({dim: var}, options={}) + idx_vars = idx.create_variables({var_name: var}) + + # trick to preserve coordinate order in this case + if dim in self._coord_names: + drop_variables.remove(dim) + else: + if append: + current_variables = { + k: self._variables[k] for k in index_coord_names + } + else: + current_variables = {} + idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand( + dim, + current_variables, + {k: self._variables[k] for k in var_names}, + ) + for n in idx.index.names: + replace_dims[n] = dim + + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + # re-add deindexed coordinates (convert to base variables) + for k in drop_variables: + if ( + k not in new_variables + and k not in all_var_names + and k in self._coord_names + ): + new_variables[k] = self._variables[k].to_base_variable() + + indexes_: dict[Any, Index] = { + k: v for k, v in self._indexes.items() if k not in drop_indexes + } + indexes_.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + # update dimensions if necessary, GH: 3512 + for k, v in variables.items(): + if any(d in replace_dims for d in v.dims): + new_dims = [replace_dims.get(d, d) for d in v.dims] + variables[k] = v._replace(dims=new_dims) + + coord_names = self._coord_names - drop_variables | set(new_variables) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes_ + ) + + @_deprecate_positional_args("v2023.10.0") + def reset_index( + self, + dims_or_levels: Hashable | Sequence[Hashable], + *, + drop: bool = False, + ) -> Self: + """Reset the specified index(es) or multi-index level(s). + + This legacy method is specific to pandas (multi-)indexes and + 1-dimensional "dimension" coordinates. See the more generic + :py:meth:`~Dataset.drop_indexes` and :py:meth:`~Dataset.set_xindex` + method to respectively drop and set pandas or custom indexes for + arbitrary coordinates. + + Parameters + ---------- + dims_or_levels : Hashable or Sequence of Hashable + Name(s) of the dimension(s) and/or multi-index level(s) that will + be reset. + drop : bool, default: False + If True, remove the specified indexes and/or multi-index levels + instead of extracting them as new coordinates (default: False). + + Returns + ------- + obj : Dataset + Another dataset, with this dataset's data but replaced coordinates. + + See Also + -------- + Dataset.set_index + Dataset.set_xindex + Dataset.drop_indexes + """ + if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): + dims_or_levels = [dims_or_levels] + + invalid_coords = set(dims_or_levels) - set(self._indexes) + if invalid_coords: + raise ValueError( + f"{tuple(invalid_coords)} are not coordinates with an index" + ) + + drop_indexes: set[Hashable] = set() + drop_variables: set[Hashable] = set() + seen: set[Index] = set() + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, Variable] = {} + + def drop_or_convert(var_names): + if drop: + drop_variables.update(var_names) + else: + base_vars = { + k: self._variables[k].to_base_variable() for k in var_names + } + new_variables.update(base_vars) + + for name in dims_or_levels: + index = self._indexes[name] + + if index in seen: + continue + seen.add(index) + + idx_var_names = set(self.xindexes.get_all_coords(name)) + drop_indexes.update(idx_var_names) + + if isinstance(index, PandasMultiIndex): + # special case for pd.MultiIndex + level_names = index.index.names + keep_level_vars = { + k: self._variables[k] + for k in level_names + if k not in dims_or_levels + } + + if index.dim not in dims_or_levels and keep_level_vars: + # do not drop the multi-index completely + # instead replace it by a new (multi-)index with dropped level(s) + idx = index.keep_levels(keep_level_vars) + idx_vars = idx.create_variables(keep_level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + if not isinstance(idx, PandasMultiIndex): + # multi-index reduced to single index + # backward compatibility: unique level coordinate renamed to dimension + drop_variables.update(keep_level_vars) + drop_or_convert( + [k for k in level_names if k not in keep_level_vars] + ) + else: + # always drop the multi-index dimension variable + drop_variables.add(index.dim) + drop_or_convert(level_names) + else: + drop_or_convert(idx_var_names) + + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} + indexes.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + coord_names = self._coord_names - drop_variables + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def set_xindex( + self, + coord_names: str | Sequence[Hashable], + index_cls: type[Index] | None = None, + **options, + ) -> Self: + """Set a new, Xarray-compatible index from one or more existing + coordinate(s). + + Parameters + ---------- + coord_names : str or list + Name(s) of the coordinate(s) used to build the index. + If several names are given, their order matters. + index_cls : subclass of :class:`~xarray.indexes.Index`, optional + The type of index to create. By default, try setting + a ``PandasIndex`` if ``len(coord_names) == 1``, + otherwise a ``PandasMultiIndex``. + **options + Options passed to the index constructor. + + Returns + ------- + obj : Dataset + Another dataset, with this dataset's data and with a new index. + + """ + # the Sequence check is required for mypy + if is_scalar(coord_names) or not isinstance(coord_names, Sequence): + coord_names = [coord_names] + + if index_cls is None: + if len(coord_names) == 1: + index_cls = PandasIndex + else: + index_cls = PandasMultiIndex + else: + if not issubclass(index_cls, Index): + raise TypeError(f"{index_cls} is not a subclass of xarray.Index") + + invalid_coords = set(coord_names) - self._coord_names + + if invalid_coords: + msg = ["invalid coordinate(s)"] + no_vars = invalid_coords - set(self._variables) + data_vars = invalid_coords - no_vars + if no_vars: + msg.append(f"those variables don't exist: {no_vars}") + if data_vars: + msg.append( + f"those variables are data variables: {data_vars}, use `set_coords` first" + ) + raise ValueError("\n".join(msg)) + + # we could be more clever here (e.g., drop-in index replacement if index + # coordinates do not conflict), but let's not allow this for now + indexed_coords = set(coord_names) & set(self._indexes) + + if indexed_coords: + raise ValueError( + f"those coordinates already have an index: {indexed_coords}" + ) + + coord_vars = {name: self._variables[name] for name in coord_names} + + index = index_cls.from_variables(coord_vars, options=options) + + new_coord_vars = index.create_variables(coord_vars) + + # special case for setting a pandas multi-index from level coordinates + # TODO: remove it once we depreciate pandas multi-index dimension (tuple + # elements) coordinate + if isinstance(index, PandasMultiIndex): + coord_names = [index.dim] + list(coord_names) + + variables: dict[Hashable, Variable] + indexes: dict[Hashable, Index] + + if len(coord_names) == 1: + variables = self._variables.copy() + indexes = self._indexes.copy() + + name = list(coord_names).pop() + if name in new_coord_vars: + variables[name] = new_coord_vars[name] + indexes[name] = index + else: + # reorder variables and indexes so that coordinates having the same + # index are next to each other + variables = {} + for name, var in self._variables.items(): + if name not in coord_names: + variables[name] = var + + indexes = {} + for name, idx in self._indexes.items(): + if name not in coord_names: + indexes[name] = idx + + for name in coord_names: + try: + variables[name] = new_coord_vars[name] + except KeyError: + variables[name] = self._variables[name] + indexes[name] = index + + return self._replace( + variables=variables, + coord_names=self._coord_names | set(coord_names), + indexes=indexes, + ) + + def reorder_levels( + self, + dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, + **dim_order_kwargs: Sequence[int | Hashable], + ) -> Self: + """Rearrange index levels using input order. + + Parameters + ---------- + dim_order : dict-like of Hashable to Sequence of int or Hashable, optional + Mapping from names matching dimensions and values given + by lists representing new level orders. Every given dimension + must have a multi-index. + **dim_order_kwargs : Sequence of int or Hashable, optional + The keyword arguments form of ``dim_order``. + One of dim_order or dim_order_kwargs must be provided. + + Returns + ------- + obj : Dataset + Another dataset, with this dataset's data but replaced + coordinates. + """ + dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") + variables = self._variables.copy() + indexes = dict(self._indexes) + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + + for dim, order in dim_order.items(): + index = self._indexes[dim] + + if not isinstance(index, PandasMultiIndex): + raise ValueError(f"coordinate {dim} has no MultiIndex") + + level_vars = {k: self._variables[k] for k in order} + idx = index.reorder_levels(level_vars) + idx_vars = idx.create_variables(level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + indexes = {k: v for k, v in self._indexes.items() if k not in new_indexes} + indexes.update(new_indexes) + + variables = {k: v for k, v in self._variables.items() if k not in new_variables} + variables.update(new_variables) + + return self._replace(variables, indexes=indexes) + + def _get_stack_index( + self, + dim, + multi=False, + create_index=False, + ) -> tuple[Index | None, dict[Hashable, Variable]]: + """Used by stack and unstack to get one pandas (multi-)index among + the indexed coordinates along dimension `dim`. + + If exactly one index is found, return it with its corresponding + coordinate variables(s), otherwise return None and an empty dict. + + If `create_index=True`, create a new index if none is found or raise + an error if multiple indexes are found. + + """ + stack_index: Index | None = None + stack_coords: dict[Hashable, Variable] = {} + + for name, index in self._indexes.items(): + var = self._variables[name] + if ( + var.ndim == 1 + and var.dims[0] == dim + and ( + # stack: must be a single coordinate index + not multi + and not self.xindexes.is_multi(name) + # unstack: must be an index that implements .unstack + or multi + and type(index).unstack is not Index.unstack + ) + ): + if stack_index is not None and index is not stack_index: + # more than one index found, stop + if create_index: + raise ValueError( + f"cannot stack dimension {dim!r} with `create_index=True` " + "and with more than one index found along that dimension" + ) + return None, {} + stack_index = index + stack_coords[name] = var + + if create_index and stack_index is None: + if dim in self._variables: + var = self._variables[dim] + else: + _, _, var = _get_virtual_variable(self._variables, dim, self.sizes) + # dummy index (only `stack_coords` will be used to construct the multi-index) + stack_index = PandasIndex([0], dim) + stack_coords = {dim: var} + + return stack_index, stack_coords + + def _stack_once( + self, + dims: Sequence[Hashable | ellipsis], + new_dim: Hashable, + index_cls: type[Index], + create_index: bool | None = True, + ) -> Self: + if dims == ...: + raise ValueError("Please use [...] for dims, rather than just ...") + if ... in dims: + dims = list(infix_dims(dims, self.dims)) + + new_variables: dict[Hashable, Variable] = {} + stacked_var_names: list[Hashable] = [] + drop_indexes: list[Hashable] = [] + + for name, var in self.variables.items(): + if any(d in var.dims for d in dims): + add_dims = [d for d in dims if d not in var.dims] + vdims = list(var.dims) + add_dims + shape = [self.sizes[d] for d in vdims] + exp_var = var.set_dims(vdims, shape) + stacked_var = exp_var.stack(**{new_dim: dims}) + new_variables[name] = stacked_var + stacked_var_names.append(name) + else: + new_variables[name] = var.copy(deep=False) + + # drop indexes of stacked coordinates (if any) + for name in stacked_var_names: + drop_indexes += list(self.xindexes.get_all_coords(name, errors="ignore")) + + new_indexes = {} + new_coord_names = set(self._coord_names) + if create_index or create_index is None: + product_vars: dict[Any, Variable] = {} + for dim in dims: + idx, idx_vars = self._get_stack_index(dim, create_index=create_index) + if idx is not None: + product_vars.update(idx_vars) + + if len(product_vars) == len(dims): + idx = index_cls.stack(product_vars, new_dim) + new_indexes[new_dim] = idx + new_indexes.update({k: idx for k in product_vars}) + idx_vars = idx.create_variables(product_vars) + # keep consistent multi-index coordinate order + for k in idx_vars: + new_variables.pop(k, None) + new_variables.update(idx_vars) + new_coord_names.update(idx_vars) + + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} + indexes.update(new_indexes) + + return self._replace_with_new_dims( + new_variables, coord_names=new_coord_names, indexes=indexes + ) + + @partial(deprecate_dims, old_name="dimensions") + def stack( + self, + dim: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None, + create_index: bool | None = True, + index_cls: type[Index] = PandasMultiIndex, + **dim_kwargs: Sequence[Hashable | ellipsis], + ) -> Self: + """ + Stack any number of existing dimensions into a single new dimension. + + New dimensions will be added at the end, and by default the corresponding + coordinate variables will be combined into a MultiIndex. + + Parameters + ---------- + dim : mapping of hashable to sequence of hashable + Mapping of the form `new_name=(dim1, dim2, ...)`. Names of new + dimensions, and the existing dimensions that they replace. An + ellipsis (`...`) will be replaced by all unlisted dimensions. + Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over + all dimensions. + create_index : bool or None, default: True + + - True: create a multi-index for each of the stacked dimensions. + - False: don't create any index. + - None. create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + + index_cls: Index-class, default: PandasMultiIndex + Can be used to pass a custom multi-index type (must be an Xarray index that + implements `.stack()`). By default, a pandas multi-index wrapper is used. + **dim_kwargs + The keyword arguments form of ``dim``. + One of dim or dim_kwargs must be provided. + + Returns + ------- + stacked : Dataset + Dataset with stacked data. + + See Also + -------- + Dataset.unstack + """ + dim = either_dict_or_kwargs(dim, dim_kwargs, "stack") + result = self + for new_dim, dims in dim.items(): + result = result._stack_once(dims, new_dim, index_cls, create_index) + return result + + def to_stacked_array( + self, + new_dim: Hashable, + sample_dims: Collection[Hashable], + variable_dim: Hashable = "variable", + name: Hashable | None = None, + ) -> DataArray: + """Combine variables of differing dimensionality into a DataArray + without broadcasting. + + This method is similar to Dataset.to_dataarray but does not broadcast the + variables. + + Parameters + ---------- + new_dim : hashable + Name of the new stacked coordinate + sample_dims : Collection of hashables + List of dimensions that **will not** be stacked. Each array in the + dataset must share these dimensions. For machine learning + applications, these define the dimensions over which samples are + drawn. + variable_dim : hashable, default: "variable" + Name of the level in the stacked coordinate which corresponds to + the variables. + name : hashable, optional + Name of the new data array. + + Returns + ------- + stacked : DataArray + DataArray with the specified dimensions and data variables + stacked together. The stacked coordinate is named ``new_dim`` + and represented by a MultiIndex object with a level containing the + data variable names. The name of this level is controlled using + the ``variable_dim`` argument. + + See Also + -------- + Dataset.to_dataarray + Dataset.stack + DataArray.to_unstacked_dataset + + Examples + -------- + >>> data = xr.Dataset( + ... data_vars={ + ... "a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), + ... "b": ("x", [6, 7]), + ... }, + ... coords={"y": ["u", "v", "w"]}, + ... ) + + >>> data + Size: 76B + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> data.to_stacked_array("z", sample_dims=["x"]) + Size: 64B + array([[0, 1, 2, 6], + [3, 4, 5, 7]]) + Coordinates: + * z (z) object 32B MultiIndex + * variable (z) Self: + index, index_vars = index_and_vars + variables: dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self._indexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + for name, idx in new_indexes.items(): + variables.update(idx.create_variables(index_vars)) + + for name, var in self.variables.items(): + if name not in index_vars: + if dim in var.dims: + if isinstance(fill_value, Mapping): + fill_value_ = fill_value[name] + else: + fill_value_ = fill_value + + variables[name] = var._unstack_once( + index=clean_index, + dim=dim, + fill_value=fill_value_, + sparse=sparse, + ) + else: + variables[name] = var + + coord_names = set(self._coord_names) - {dim} | set(new_indexes) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def _unstack_full_reindex( + self, + dim: Hashable, + index_and_vars: tuple[Index, dict[Hashable, Variable]], + fill_value, + sparse: bool, + ) -> Self: + index, index_vars = index_and_vars + variables: dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self._indexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + new_index_variables = {} + for name, idx in new_indexes.items(): + new_index_variables.update(idx.create_variables(index_vars)) + + new_dim_sizes = {k: v.size for k, v in new_index_variables.items()} + variables.update(new_index_variables) + + # take a shortcut in case the MultiIndex was not modified. + full_idx = pd.MultiIndex.from_product( + clean_index.levels, names=clean_index.names + ) + if clean_index.equals(full_idx): + obj = self + else: + # TODO: we may depreciate implicit re-indexing with a pandas.MultiIndex + xr_full_idx = PandasMultiIndex(full_idx, dim) + indexers = Indexes( + {k: xr_full_idx for k in index_vars}, + xr_full_idx.create_variables(index_vars), + ) + obj = self._reindex( + indexers, copy=False, fill_value=fill_value, sparse=sparse + ) + + for name, var in obj.variables.items(): + if name not in index_vars: + if dim in var.dims: + variables[name] = var.unstack({dim: new_dim_sizes}) + else: + variables[name] = var + + coord_names = set(self._coord_names) - {dim} | set(new_dim_sizes) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + @_deprecate_positional_args("v2023.10.0") + def unstack( + self, + dim: Dims = None, + *, + fill_value: Any = xrdtypes.NA, + sparse: bool = False, + ) -> Self: + """ + Unstack existing dimensions corresponding to MultiIndexes into + multiple new dimensions. + + New dimensions will be added at the end. + + Parameters + ---------- + dim : str, Iterable of Hashable or None, optional + Dimension(s) over which to unstack. By default unstacks all + MultiIndexes. + fill_value : scalar or dict-like, default: nan + value to be filled. If a dict-like, maps variable names to + fill values. If not provided or if the dict-like does not + contain all variables, the dtype's NA value will be used. + sparse : bool, default: False + use sparse-array if True + + Returns + ------- + unstacked : Dataset + Dataset with unstacked data. + + See Also + -------- + Dataset.stack + """ + + if dim is None: + dims = list(self.dims) + else: + if isinstance(dim, str) or not isinstance(dim, Iterable): + dims = [dim] + else: + dims = list(dim) + + missing_dims = set(dims) - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" + ) + + # each specified dimension must have exactly one multi-index + stacked_indexes: dict[Any, tuple[Index, dict[Hashable, Variable]]] = {} + for d in dims: + idx, idx_vars = self._get_stack_index(d, multi=True) + if idx is not None: + stacked_indexes[d] = idx, idx_vars + + if dim is None: + dims = list(stacked_indexes) + else: + non_multi_dims = set(dims) - set(stacked_indexes) + if non_multi_dims: + raise ValueError( + "cannot unstack dimensions that do not " + f"have exactly one multi-index: {tuple(non_multi_dims)}" + ) + + result = self.copy(deep=False) + + # we want to avoid allocating an object-dtype ndarray for a MultiIndex, + # so we can't just access self.variables[v].data for every variable. + # We only check the non-index variables. + # https://github.com/pydata/xarray/issues/5902 + nonindexes = [ + self.variables[k] for k in set(self.variables) - set(self._indexes) + ] + # Notes for each of these cases: + # 1. Dask arrays don't support assignment by index, which the fast unstack + # function requires. + # https://github.com/pydata/xarray/pull/4746#issuecomment-753282125 + # 2. Sparse doesn't currently support (though we could special-case it) + # https://github.com/pydata/sparse/issues/422 + # 3. pint requires checking if it's a NumPy array until + # https://github.com/pydata/xarray/pull/4751 is resolved, + # Once that is resolved, explicitly exclude pint arrays. + # pint doesn't implement `np.full_like` in a way that's + # currently compatible. + sparse_array_type = array_type("sparse") + needs_full_reindex = any( + is_duck_dask_array(v.data) + or isinstance(v.data, sparse_array_type) + or not isinstance(v.data, np.ndarray) + for v in nonindexes + ) + + for d in dims: + if needs_full_reindex: + result = result._unstack_full_reindex( + d, stacked_indexes[d], fill_value, sparse + ) + else: + result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse) + return result + + def update(self, other: CoercibleMapping) -> Self: + """Update this dataset's variables with those from another dataset. + + Just like :py:meth:`dict.update` this is a in-place operation. + For a non-inplace version, see :py:meth:`Dataset.merge`. + + Parameters + ---------- + other : Dataset or mapping + Variables with which to update this dataset. One of: + + - Dataset + - mapping {var name: DataArray} + - mapping {var name: Variable} + - mapping {var name: (dimension name, array-like)} + - mapping {var name: (tuple of dimension names, array-like)} + + Returns + ------- + updated : Dataset + Updated dataset. Note that since the update is in-place this is the input + dataset. + + It is deprecated since version 0.17 and scheduled to be removed in 0.21. + + Raises + ------ + ValueError + If any dimensions would have inconsistent sizes in the updated + dataset. + + See Also + -------- + Dataset.assign + Dataset.merge + """ + merge_result = dataset_update_method(self, other) + return self._replace(inplace=True, **merge_result._asdict()) + + def merge( + self, + other: CoercibleMapping | DataArray, + overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), + compat: CompatOptions = "no_conflicts", + join: JoinOptions = "outer", + fill_value: Any = xrdtypes.NA, + combine_attrs: CombineAttrsOptions = "override", + ) -> Self: + """Merge the arrays of two datasets into a single dataset. + + This method generally does not allow for overriding data, with the + exception of attributes, which are ignored on the second dataset. + Variables with the same name are checked for conflicts via the equals + or identical methods. + + Parameters + ---------- + other : Dataset or mapping + Dataset or variables to merge with this dataset. + overwrite_vars : hashable or iterable of hashable, optional + If provided, update variables of these name(s) without checking for + conflicts in this dataset. + compat : {"identical", "equals", "broadcast_equals", \ + "no_conflicts", "override", "minimal"}, default: "no_conflicts" + String indicating how to compare variables of the same name for + potential conflicts: + + - 'identical': all values, dimensions and attributes must be the + same. + - 'equals': all values and dimensions must be the same. + - 'broadcast_equals': all values must be equal when variables are + broadcast against each other to ensure common dimensions. + - 'no_conflicts': only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. + - 'override': skip comparing and pick variable from first dataset + - 'minimal': drop conflicting coordinates + + join : {"outer", "inner", "left", "right", "exact", "override"}, \ + default: "outer" + Method for joining ``self`` and ``other`` along shared dimensions: + + - 'outer': use the union of the indexes + - 'inner': use the intersection of the indexes + - 'left': use indexes from ``self`` + - 'right': use indexes from ``other`` + - 'exact': error instead of aligning non-equal indexes + - 'override': use indexes from ``self`` that are the same size + as those of ``other`` in that dimension + + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names (including coordinates) to fill values. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + + Returns + ------- + merged : Dataset + Merged dataset. + + Raises + ------ + MergeError + If any variables conflict (see ``compat``). + + See Also + -------- + Dataset.update + """ + from xarray.core.dataarray import DataArray + + other = other.to_dataset() if isinstance(other, DataArray) else other + merge_result = dataset_merge_method( + self, + other, + overwrite_vars=overwrite_vars, + compat=compat, + join=join, + fill_value=fill_value, + combine_attrs=combine_attrs, + ) + return self._replace(**merge_result._asdict()) + + def _assert_all_in_dataset( + self, names: Iterable[Hashable], virtual_okay: bool = False + ) -> None: + bad_names = set(names) - set(self._variables) + if virtual_okay: + bad_names -= self.virtual_variables + if bad_names: + ordered_bad_names = [name for name in names if name in bad_names] + raise ValueError( + f"These variables cannot be found in this dataset: {ordered_bad_names}" + ) + + def drop_vars( + self, + names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]], + *, + errors: ErrorOptions = "raise", + ) -> Self: + """Drop variables from this dataset. + + Parameters + ---------- + names : Hashable or iterable of Hashable or Callable + Name(s) of variables to drop. If a Callable, this object is passed as its + only argument and its result is used. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the variable + passed are not in the dataset. If 'ignore', any given names that are in the + dataset are dropped and no error is raised. + + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "latitude", "longitude"], + ... [[[25.5, 26.3], [27.1, 28.0]]], + ... ), + ... "humidity": ( + ... ["time", "latitude", "longitude"], + ... [[[65.0, 63.8], [58.2, 59.6]]], + ... ), + ... "wind_speed": ( + ... ["time", "latitude", "longitude"], + ... [[[10.2, 8.5], [12.1, 9.8]]], + ... ), + ... }, + ... coords={ + ... "time": pd.date_range("2023-07-01", periods=1), + ... "latitude": [40.0, 40.2], + ... "longitude": [-75.0, -74.8], + ... }, + ... ) + >>> dataset + Size: 136B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 32B 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Drop the 'humidity' variable + + >>> dataset.drop_vars(["humidity"]) + Size: 104B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Drop the 'humidity', 'temperature' variables + + >>> dataset.drop_vars(["humidity", "temperature"]) + Size: 72B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Drop all indexes + + >>> dataset.drop_vars(lambda x: x.indexes) + Size: 96B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Dimensions without coordinates: time, latitude, longitude + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 32B 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Attempt to drop non-existent variable with errors="ignore" + + >>> dataset.drop_vars(["pressure"], errors="ignore") + Size: 136B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 32B 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Attempt to drop non-existent variable with errors="raise" + + >>> dataset.drop_vars(["pressure"], errors="raise") + Traceback (most recent call last): + ValueError: These variables cannot be found in this dataset: ['pressure'] + + Raises + ------ + ValueError + Raised if you attempt to drop a variable which is not present, and the kwarg ``errors='raise'``. + + Returns + ------- + dropped : Dataset + + See Also + -------- + DataArray.drop_vars + + """ + if callable(names): + names = names(self) + # the Iterable check is required for mypy + if is_scalar(names) or not isinstance(names, Iterable): + names_set = {names} + else: + names_set = set(names) + if errors == "raise": + self._assert_all_in_dataset(names_set) + + # GH6505 + other_names = set() + for var in names_set: + maybe_midx = self._indexes.get(var, None) + if isinstance(maybe_midx, PandasMultiIndex): + idx_coord_names = set(list(maybe_midx.index.names) + [maybe_midx.dim]) + idx_other_names = idx_coord_names - set(names_set) + other_names.update(idx_other_names) + if other_names: + names_set |= set(other_names) + warnings.warn( + f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. " + f"Please also drop the following variables: {other_names!r} to avoid an error in the future.", + DeprecationWarning, + stacklevel=2, + ) + + assert_no_index_corrupted(self.xindexes, names_set) + + variables = {k: v for k, v in self._variables.items() if k not in names_set} + coord_names = {k for k in self._coord_names if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k not in names_set} + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def drop_indexes( + self, + coord_names: Hashable | Iterable[Hashable], + *, + errors: ErrorOptions = "raise", + ) -> Self: + """Drop the indexes assigned to the given coordinates. + + Parameters + ---------- + coord_names : hashable or iterable of hashable + Name(s) of the coordinate(s) for which to drop the index. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the coordinates + passed have no index or are not in the dataset. + If 'ignore', no error is raised. + + Returns + ------- + dropped : Dataset + A new dataset with dropped indexes. + + """ + # the Iterable check is required for mypy + if is_scalar(coord_names) or not isinstance(coord_names, Iterable): + coord_names = {coord_names} + else: + coord_names = set(coord_names) + + if errors == "raise": + invalid_coords = coord_names - self._coord_names + if invalid_coords: + raise ValueError( + f"The coordinates {tuple(invalid_coords)} are not found in the " + f"dataset coordinates {tuple(self.coords.keys())}" + ) + + unindexed_coords = set(coord_names) - set(self._indexes) + if unindexed_coords: + raise ValueError( + f"those coordinates do not have an index: {unindexed_coords}" + ) + + assert_no_index_corrupted(self.xindexes, coord_names, action="remove index(es)") + + variables = {} + for name, var in self._variables.items(): + if name in coord_names: + variables[name] = var.to_base_variable() + else: + variables[name] = var + + indexes = {k: v for k, v in self._indexes.items() if k not in coord_names} + + return self._replace(variables=variables, indexes=indexes) + + def drop( + self, + labels=None, + dim=None, + *, + errors: ErrorOptions = "raise", + **labels_kwargs, + ) -> Self: + """Backward compatible method based on `drop_vars` and `drop_sel` + + Using either `drop_vars` or `drop_sel` is encouraged + + See Also + -------- + Dataset.drop_vars + Dataset.drop_sel + """ + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + if is_dict_like(labels) and not isinstance(labels, dict): + emit_user_level_warning( + "dropping coordinates using `drop` is deprecated; use drop_vars.", + DeprecationWarning, + ) + return self.drop_vars(labels, errors=errors) + + if labels_kwargs or isinstance(labels, dict): + if dim is not None: + raise ValueError("cannot specify dim and dict-like arguments.") + labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") + + if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)): + emit_user_level_warning( + "dropping variables using `drop` is deprecated; use drop_vars.", + DeprecationWarning, + ) + # for mypy + if is_scalar(labels): + labels = [labels] + return self.drop_vars(labels, errors=errors) + if dim is not None: + warnings.warn( + "dropping labels using list-like labels is deprecated; using " + "dict-like arguments with `drop_sel`, e.g. `ds.drop_sel(dim=[labels]).", + DeprecationWarning, + stacklevel=2, + ) + return self.drop_sel({dim: labels}, errors=errors, **labels_kwargs) + + emit_user_level_warning( + "dropping labels using `drop` is deprecated; use `drop_sel` instead.", + DeprecationWarning, + ) + return self.drop_sel(labels, errors=errors) + + def drop_sel( + self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs + ) -> Self: + """Drop index labels from this dataset. + + Parameters + ---------- + labels : mapping of hashable to Any + Index labels to drop + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if + any of the index labels passed are not + in the dataset. If 'ignore', any given labels that are in the + dataset are dropped and no error is raised. + **labels_kwargs : {dim: label, ...}, optional + The keyword arguments form of ``dim`` and ``labels`` + + Returns + ------- + dropped : Dataset + + Examples + -------- + >>> data = np.arange(6).reshape(2, 3) + >>> labels = ["a", "b", "c"] + >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) + >>> ds + Size: 60B + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> ds.drop_sel(y=["a", "c"]) + Size: 20B + Dimensions: (x: 2, y: 1) + Coordinates: + * y (y) >> ds.drop_sel(y="b") + Size: 40B + Dimensions: (x: 2, y: 2) + Coordinates: + * y (y) Self: + """Drop index positions from this Dataset. + + Parameters + ---------- + indexers : mapping of hashable to Any + Index locations to drop + **indexers_kwargs : {dim: position, ...}, optional + The keyword arguments form of ``dim`` and ``positions`` + + Returns + ------- + dropped : Dataset + + Raises + ------ + IndexError + + Examples + -------- + >>> data = np.arange(6).reshape(2, 3) + >>> labels = ["a", "b", "c"] + >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) + >>> ds + Size: 60B + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> ds.drop_isel(y=[0, 2]) + Size: 20B + Dimensions: (x: 2, y: 1) + Coordinates: + * y (y) >> ds.drop_isel(y=1) + Size: 40B + Dimensions: (x: 2, y: 2) + Coordinates: + * y (y) Self: + """Drop dimensions and associated variables from this dataset. + + Parameters + ---------- + drop_dims : str or Iterable of Hashable + Dimension or dimensions to drop. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the + dimensions passed are not in the dataset. If 'ignore', any given + dimensions that are in the dataset are dropped and no error is raised. + + Returns + ------- + obj : Dataset + The dataset without the given dimensions (or any variables + containing those dimensions). + """ + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable): + drop_dims = {drop_dims} + else: + drop_dims = set(drop_dims) + + if errors == "raise": + missing_dims = drop_dims - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" + ) + + drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} + return self.drop_vars(drop_vars) + + @deprecate_dims + def transpose( + self, + *dim: Hashable, + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> Self: + """Return a new Dataset object with all array dimensions transposed. + + Although the order of dimensions on each array will change, the dataset + dimensions themselves will remain in fixed (sorted) order. + + Parameters + ---------- + *dim : hashable, optional + By default, reverse the dimensions on each array. Otherwise, + reorder the dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + transposed : Dataset + Each array in the dataset (including) coordinates will be + transposed to the given order. + + Notes + ----- + This operation returns a view of each array's data. It is + lazy for dask-backed DataArrays but not for numpy-backed DataArrays + -- the data will be fully loaded into memory. + + See Also + -------- + numpy.transpose + DataArray.transpose + """ + # Raise error if list is passed as dim + if (len(dim) > 0) and (isinstance(dim[0], list)): + list_fix = [f"{repr(x)}" if isinstance(x, str) else f"{x}" for x in dim[0]] + raise TypeError( + f'transpose requires dim to be passed as multiple arguments. Expected `{", ".join(list_fix)}`. Received `{dim[0]}` instead' + ) + + # Use infix_dims to check once for missing dimensions + if len(dim) != 0: + _ = list(infix_dims(dim, self.dims, missing_dims)) + + ds = self.copy() + for name, var in self._variables.items(): + var_dims = tuple(d for d in dim if d in (var.dims + (...,))) + ds._variables[name] = var.transpose(*var_dims) + return ds + + @_deprecate_positional_args("v2023.10.0") + def dropna( + self, + dim: Hashable, + *, + how: Literal["any", "all"] = "any", + thresh: int | None = None, + subset: Iterable[Hashable] | None = None, + ) -> Self: + """Returns a new dataset with dropped labels for missing values along + the provided dimension. + + Parameters + ---------- + dim : hashable + Dimension along which to drop missing values. Dropping along + multiple dimensions simultaneously is not yet supported. + how : {"any", "all"}, default: "any" + - any : if any NA values are present, drop that label + - all : if all values are NA, drop that label + + thresh : int or None, optional + If supplied, require this many non-NA values (summed over all the subset variables). + subset : iterable of hashable or None, optional + Which variables to check for missing values. By default, all + variables in the dataset are checked. + + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "location"], + ... [[23.4, 24.1], [np.nan, 22.1], [21.8, 24.2], [20.5, 25.3]], + ... ) + ... }, + ... coords={"time": [1, 2, 3, 4], "location": ["A", "B"]}, + ... ) + >>> dataset + Size: 104B + Dimensions: (time: 4, location: 2) + Coordinates: + * time (time) int64 32B 1 2 3 4 + * location (location) >> dataset.dropna(dim="time") + Size: 80B + Dimensions: (time: 3, location: 2) + Coordinates: + * time (time) int64 24B 1 3 4 + * location (location) >> dataset.dropna(dim="time", how="any") + Size: 80B + Dimensions: (time: 3, location: 2) + Coordinates: + * time (time) int64 24B 1 3 4 + * location (location) >> dataset.dropna(dim="time", how="all") + Size: 104B + Dimensions: (time: 4, location: 2) + Coordinates: + * time (time) int64 32B 1 2 3 4 + * location (location) >> dataset.dropna(dim="time", thresh=2) + Size: 80B + Dimensions: (time: 3, location: 2) + Coordinates: + * time (time) int64 24B 1 3 4 + * location (location) = thresh + elif how == "any": + mask = count == size + elif how == "all": + mask = count > 0 + elif how is not None: + raise ValueError(f"invalid how option: {how}") + else: + raise TypeError("must specify how or thresh") + + return self.isel({dim: mask}) + + def fillna(self, value: Any) -> Self: + """Fill missing values in this object. + + This operation follows the normal broadcasting and alignment rules that + xarray uses for binary arithmetic, except the result is aligned to this + object (``join='left'``) instead of aligned to the intersection of + index coordinates (``join='inner'``). + + Parameters + ---------- + value : scalar, ndarray, DataArray, dict or Dataset + Used to fill all matching missing values in this dataset's data + variables. Scalars, ndarrays or DataArrays arguments are used to + fill all data with aligned coordinates (for DataArrays). + Dictionaries or datasets match data variables and then align + coordinates if necessary. + + Returns + ------- + Dataset + + Examples + -------- + >>> ds = xr.Dataset( + ... { + ... "A": ("x", [np.nan, 2, np.nan, 0]), + ... "B": ("x", [3, 4, np.nan, 1]), + ... "C": ("x", [np.nan, np.nan, np.nan, 5]), + ... "D": ("x", [np.nan, 3, np.nan, 4]), + ... }, + ... coords={"x": [0, 1, 2, 3]}, + ... ) + >>> ds + Size: 160B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + Data variables: + A (x) float64 32B nan 2.0 nan 0.0 + B (x) float64 32B 3.0 4.0 nan 1.0 + C (x) float64 32B nan nan nan 5.0 + D (x) float64 32B nan 3.0 nan 4.0 + + Replace all `NaN` values with 0s. + + >>> ds.fillna(0) + Size: 160B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + Data variables: + A (x) float64 32B 0.0 2.0 0.0 0.0 + B (x) float64 32B 3.0 4.0 0.0 1.0 + C (x) float64 32B 0.0 0.0 0.0 5.0 + D (x) float64 32B 0.0 3.0 0.0 4.0 + + Replace all `NaN` elements in column ‘A’, ‘B’, ‘C’, and ‘D’, with 0, 1, 2, and 3 respectively. + + >>> values = {"A": 0, "B": 1, "C": 2, "D": 3} + >>> ds.fillna(value=values) + Size: 160B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + Data variables: + A (x) float64 32B 0.0 2.0 0.0 0.0 + B (x) float64 32B 3.0 4.0 1.0 1.0 + C (x) float64 32B 2.0 2.0 2.0 5.0 + D (x) float64 32B 3.0 3.0 3.0 4.0 + """ + if utils.is_dict_like(value): + value_keys = getattr(value, "data_vars", value).keys() + if not set(value_keys) <= set(self.data_vars.keys()): + raise ValueError( + "all variables in the argument to `fillna` " + "must be contained in the original dataset" + ) + out = ops.fillna(self, value) + return out + + def interpolate_na( + self, + dim: Hashable | None = None, + method: InterpOptions = "linear", + limit: int | None = None, + use_coordinate: bool | Hashable = True, + max_gap: ( + int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta + ) = None, + **kwargs: Any, + ) -> Self: + """Fill in NaNs by interpolating according to different methods. + + Parameters + ---------- + dim : Hashable or None, optional + Specifies the dimension along which to interpolate. + method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" + String indicating which method to use for interpolation: + + - 'linear': linear interpolation. Additional keyword + arguments are passed to :py:func:`numpy.interp` + - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial': + are passed to :py:func:`scipy.interpolate.interp1d`. If + ``method='polynomial'``, the ``order`` keyword argument must also be + provided. + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their + respective :py:class:`scipy.interpolate` classes. + + use_coordinate : bool or Hashable, default: True + Specifies which index to use as the x values in the interpolation + formulated as `y = f(x)`. If False, values are treated as if + equally-spaced along ``dim``. If True, the IndexVariable `dim` is + used. If ``use_coordinate`` is a string, it specifies the name of a + coordinate variable to use as the index. + limit : int, default: None + Maximum number of consecutive NaNs to fill. Must be greater than 0 + or None for no limit. This filling is done regardless of the size of + the gap in the data. To only interpolate over gaps less than a given length, + see ``max_gap``. + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None + Maximum size of gap, a continuous sequence of NaNs, that will be filled. + Use None for no limit. When interpolating along a datetime64 dimension + and ``use_coordinate=True``, ``max_gap`` can be one of the following: + + - a string that is valid input for pandas.to_timedelta + - a :py:class:`numpy.timedelta64` object + - a :py:class:`pandas.Timedelta` object + - a :py:class:`datetime.timedelta` object + + Otherwise, ``max_gap`` must be an int or a float. Use of ``max_gap`` with unlabeled + dimensions has not been implemented yet. Gap length is defined as the difference + between coordinate values at the first data point after a gap and the last value + before a gap. For gaps at the beginning (end), gap length is defined as the difference + between coordinate values at the first (last) valid data point and the first (last) NaN. + For example, consider:: + + + array([nan, nan, nan, 1., nan, nan, 4., nan, nan]) + Coordinates: + * x (x) int64 0 1 2 3 4 5 6 7 8 + + The gap lengths are 3-0 = 3; 6-3 = 3; and 8-6 = 2 respectively + **kwargs : dict, optional + parameters passed verbatim to the underlying interpolation function + + Returns + ------- + interpolated: Dataset + Filled in Dataset. + + Warning + -------- + When passing fill_value as a keyword argument with method="linear", it does not use + ``numpy.interp`` but it uses ``scipy.interpolate.interp1d``, which provides the fill_value parameter. + + See Also + -------- + numpy.interp + scipy.interpolate + + Examples + -------- + >>> ds = xr.Dataset( + ... { + ... "A": ("x", [np.nan, 2, 3, np.nan, 0]), + ... "B": ("x", [3, 4, np.nan, 1, 7]), + ... "C": ("x", [np.nan, np.nan, np.nan, 5, 0]), + ... "D": ("x", [np.nan, 3, np.nan, -1, 4]), + ... }, + ... coords={"x": [0, 1, 2, 3, 4]}, + ... ) + >>> ds + Size: 200B + Dimensions: (x: 5) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + Data variables: + A (x) float64 40B nan 2.0 3.0 nan 0.0 + B (x) float64 40B 3.0 4.0 nan 1.0 7.0 + C (x) float64 40B nan nan nan 5.0 0.0 + D (x) float64 40B nan 3.0 nan -1.0 4.0 + + >>> ds.interpolate_na(dim="x", method="linear") + Size: 200B + Dimensions: (x: 5) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + Data variables: + A (x) float64 40B nan 2.0 3.0 1.5 0.0 + B (x) float64 40B 3.0 4.0 2.5 1.0 7.0 + C (x) float64 40B nan nan nan 5.0 0.0 + D (x) float64 40B nan 3.0 1.0 -1.0 4.0 + + >>> ds.interpolate_na(dim="x", method="linear", fill_value="extrapolate") + Size: 200B + Dimensions: (x: 5) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + Data variables: + A (x) float64 40B 1.0 2.0 3.0 1.5 0.0 + B (x) float64 40B 3.0 4.0 2.5 1.0 7.0 + C (x) float64 40B 20.0 15.0 10.0 5.0 0.0 + D (x) float64 40B 5.0 3.0 1.0 -1.0 4.0 + """ + from xarray.core.missing import _apply_over_vars_with_dim, interp_na + + new = _apply_over_vars_with_dim( + interp_na, + self, + dim=dim, + method=method, + limit=limit, + use_coordinate=use_coordinate, + max_gap=max_gap, + **kwargs, + ) + return new + + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: + """Fill NaN values by propagating values forward + + *Requires bottleneck.* + + Parameters + ---------- + dim : Hashable + Specifies the dimension along which to propagate values when filling. + limit : int or None, optional + The maximum number of consecutive NaN values to forward fill. In + other words, if there is a gap with more than this number of + consecutive NaNs, it will only be partially filled. Must be greater + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). + + Examples + -------- + >>> time = pd.date_range("2023-01-01", periods=10, freq="D") + >>> data = np.array( + ... [1, np.nan, np.nan, np.nan, 5, np.nan, np.nan, 8, np.nan, 10] + ... ) + >>> dataset = xr.Dataset({"data": (("time",), data)}, coords={"time": time}) + >>> dataset + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 nan nan nan 5.0 nan nan 8.0 nan 10.0 + + # Perform forward fill (ffill) on the dataset + + >>> dataset.ffill(dim="time") + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 1.0 1.0 1.0 5.0 5.0 5.0 8.0 8.0 10.0 + + # Limit the forward filling to a maximum of 2 consecutive NaN values + + >>> dataset.ffill(dim="time", limit=2) + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 1.0 1.0 nan 5.0 5.0 5.0 8.0 8.0 10.0 + + Returns + ------- + Dataset + + See Also + -------- + Dataset.bfill + """ + from xarray.core.missing import _apply_over_vars_with_dim, ffill + + new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) + return new + + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: + """Fill NaN values by propagating values backward + + *Requires bottleneck.* + + Parameters + ---------- + dim : Hashable + Specifies the dimension along which to propagate values when + filling. + limit : int or None, optional + The maximum number of consecutive NaN values to backward fill. In + other words, if there is a gap with more than this number of + consecutive NaNs, it will only be partially filled. Must be greater + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). + + Examples + -------- + >>> time = pd.date_range("2023-01-01", periods=10, freq="D") + >>> data = np.array( + ... [1, np.nan, np.nan, np.nan, 5, np.nan, np.nan, 8, np.nan, 10] + ... ) + >>> dataset = xr.Dataset({"data": (("time",), data)}, coords={"time": time}) + >>> dataset + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 nan nan nan 5.0 nan nan 8.0 nan 10.0 + + # filled dataset, fills NaN values by propagating values backward + + >>> dataset.bfill(dim="time") + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 5.0 5.0 5.0 5.0 8.0 8.0 8.0 10.0 10.0 + + # Limit the backward filling to a maximum of 2 consecutive NaN values + + >>> dataset.bfill(dim="time", limit=2) + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 nan 5.0 5.0 5.0 8.0 8.0 8.0 10.0 10.0 + + Returns + ------- + Dataset + + See Also + -------- + Dataset.ffill + """ + from xarray.core.missing import _apply_over_vars_with_dim, bfill + + new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) + return new + + def combine_first(self, other: Self) -> Self: + """Combine two Datasets, default to data_vars of self. + + The new coordinates follow the normal broadcasting and alignment rules + of ``join='outer'``. Vacant cells in the expanded coordinates are + filled with np.nan. + + Parameters + ---------- + other : Dataset + Used to fill all matching missing values in this array. + + Returns + ------- + Dataset + """ + out = ops.fillna(self, other, join="outer", dataset_join="outer") + return out + + def reduce( + self, + func: Callable, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + keepdims: bool = False, + numeric_only: bool = False, + **kwargs: Any, + ) -> Self: + """Reduce this dataset by applying `func` along some dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `f(x, axis=axis, **kwargs)` to return the result of reducing an + np.ndarray over an integer valued axis. + dim : str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default `func` is + applied over all dimensions. + keep_attrs : bool or None, optional + If True, the dataset's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + keepdims : bool, default: False + If True, the dimensions which are reduced are left in the result + as dimensions of size one. Coordinates that use these dimensions + are removed. + numeric_only : bool, default: False + If True, only apply ``func`` to variables with a numeric dtype. + **kwargs : Any + Additional keyword arguments passed on to ``func``. + + Returns + ------- + reduced : Dataset + Dataset with this object's DataArrays replaced with new DataArrays + of summarized data and the indicated dimension(s) removed. + + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 92], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [93, 96, 91]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # Calculate the 75th percentile of math scores for each student using np.percentile + + >>> percentile_scores = dataset.reduce(np.percentile, q=75, dim="test") + >>> percentile_scores + Size: 132B + Dimensions: (student: 3) + Coordinates: + * student (student) Self: + """Apply a function to each data variable in this dataset + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, *args, **kwargs)` + to transform each DataArray `x` in this dataset into another + DataArray. + keep_attrs : bool or None, optional + If True, both the dataset's and variables' attributes (`attrs`) will be + copied from the original objects to the new ones. If False, the new dataset + and variables will be returned without copying the attributes. + args : iterable, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + + Returns + ------- + applied : Dataset + Resulting dataset from applying ``func`` to each data variable. + + Examples + -------- + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])}) + >>> ds + Size: 64B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Dimensions without coordinates: dim_0, dim_1, x + Data variables: + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 -0.9773 + bar (x) int64 16B -1 2 + >>> ds.map(np.fabs) + Size: 64B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Dimensions without coordinates: dim_0, dim_1, x + Data variables: + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773 + bar (x) float64 16B 1.0 2.0 + """ + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + variables = { + k: maybe_wrap_array(v, func(v, *args, **kwargs)) + for k, v in self.data_vars.items() + } + if keep_attrs: + for k, v in variables.items(): + v._copy_attrs_from(self.data_vars[k]) + attrs = self.attrs if keep_attrs else None + return type(self)(variables, attrs=attrs) + + def apply( + self, + func: Callable, + keep_attrs: bool | None = None, + args: Iterable[Any] = (), + **kwargs: Any, + ) -> Self: + """ + Backward compatible implementation of ``map`` + + See Also + -------- + Dataset.map + """ + warnings.warn( + "Dataset.apply may be deprecated in the future. Using Dataset.map is encouraged", + PendingDeprecationWarning, + stacklevel=2, + ) + return self.map(func, keep_attrs, args, **kwargs) + + def assign( + self, + variables: Mapping[Any, Any] | None = None, + **variables_kwargs: Any, + ) -> Self: + """Assign new data variables to a Dataset, returning a new object + with all the original variables in addition to the new ones. + + Parameters + ---------- + variables : mapping of hashable to Any + Mapping from variables names to the new values. If the new values + are callable, they are computed on the Dataset and assigned to new + data variables. If the values are not callable, (e.g. a DataArray, + scalar, or array), they are simply assigned. + **variables_kwargs + The keyword arguments form of ``variables``. + One of variables or variables_kwargs must be provided. + + Returns + ------- + ds : Dataset + A new Dataset with the new variables in addition to all the + existing variables. + + Notes + ----- + Since ``kwargs`` is a dictionary, the order of your arguments may not + be preserved, and so the order of the new variables is not well + defined. Assigning multiple variables within the same ``assign`` is + possible, but you cannot reference other variables created within the + same ``assign`` call. + + The new assigned variables that replace existing coordinates in the + original dataset are still listed as coordinates in the returned + Dataset. + + See Also + -------- + pandas.DataFrame.assign + + Examples + -------- + >>> x = xr.Dataset( + ... { + ... "temperature_c": ( + ... ("lat", "lon"), + ... 20 * np.random.rand(4).reshape(2, 2), + ... ), + ... "precipitation": (("lat", "lon"), np.random.rand(4).reshape(2, 2)), + ... }, + ... coords={"lat": [10, 20], "lon": [150, 160]}, + ... ) + >>> x + Size: 96B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + + Where the value is a callable, evaluated on dataset: + + >>> x.assign(temperature_f=lambda x: x.temperature_c * 9 / 5 + 32) + Size: 128B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + temperature_f (lat, lon) float64 32B 51.76 57.75 53.7 51.62 + + Alternatively, the same behavior can be achieved by directly referencing an existing dataarray: + + >>> x.assign(temperature_f=x["temperature_c"] * 9 / 5 + 32) + Size: 128B + Dimensions: (lat: 2, lon: 2) + Coordinates: + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 + Data variables: + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + temperature_f (lat, lon) float64 32B 51.76 57.75 53.7 51.62 + + """ + variables = either_dict_or_kwargs(variables, variables_kwargs, "assign") + data = self.copy() + + # do all calculations first... + results: CoercibleMapping = data._calc_assign_results(variables) + + # split data variables to add/replace vs. coordinates to replace + results_data_vars: dict[Hashable, CoercibleValue] = {} + results_coords: dict[Hashable, CoercibleValue] = {} + for k, v in results.items(): + if k in data._coord_names: + results_coords[k] = v + else: + results_data_vars[k] = v + + # ... and then assign + data.coords.update(results_coords) + data.update(results_data_vars) + + return data + + def to_dataarray( + self, dim: Hashable = "variable", name: Hashable | None = None + ) -> DataArray: + """Convert this dataset into an xarray.DataArray + + The data variables of this dataset will be broadcast against each other + and stacked along the first axis of the new array. All coordinates of + this dataset will remain coordinates. + + Parameters + ---------- + dim : Hashable, default: "variable" + Name of the new dimension. + name : Hashable or None, optional + Name of the new data array. + + Returns + ------- + array : xarray.DataArray + """ + from xarray.core.dataarray import DataArray + + data_vars = [self.variables[k] for k in self.data_vars] + broadcast_vars = broadcast_variables(*data_vars) + data = duck_array_ops.stack([b.data for b in broadcast_vars], axis=0) + + dims = (dim,) + broadcast_vars[0].dims + variable = Variable(dims, data, self.attrs, fastpath=True) + + coords = {k: v.variable for k, v in self.coords.items()} + indexes = filter_indexes_from_coords(self._indexes, set(coords)) + new_dim_index = PandasIndex(list(self.data_vars), dim) + indexes[dim] = new_dim_index + coords.update(new_dim_index.create_variables()) + + return DataArray._construct_direct(variable, coords, name, indexes) + + def to_array( + self, dim: Hashable = "variable", name: Hashable | None = None + ) -> DataArray: + """Deprecated version of to_dataarray""" + return self.to_dataarray(dim=dim, name=name) + + def _normalize_dim_order( + self, dim_order: Sequence[Hashable] | None = None + ) -> dict[Hashable, int]: + """ + Check the validity of the provided dimensions if any and return the mapping + between dimension name and their size. + + Parameters + ---------- + dim_order: Sequence of Hashable or None, optional + Dimension order to validate (default to the alphabetical order if None). + + Returns + ------- + result : dict[Hashable, int] + Validated dimensions mapping. + + """ + if dim_order is None: + dim_order = list(self.dims) + elif set(dim_order) != set(self.dims): + raise ValueError( + f"dim_order {dim_order} does not match the set of dimensions of this " + f"Dataset: {list(self.dims)}" + ) + + ordered_dims = {k: self.sizes[k] for k in dim_order} + + return ordered_dims + + def to_pandas(self) -> pd.Series | pd.DataFrame: + """Convert this dataset into a pandas object without changing the number of dimensions. + + The type of the returned object depends on the number of Dataset + dimensions: + + * 0D -> `pandas.Series` + * 1D -> `pandas.DataFrame` + + Only works for Datasets with 1 or fewer dimensions. + """ + if len(self.dims) == 0: + return pd.Series({k: v.item() for k, v in self.items()}) + if len(self.dims) == 1: + return self.to_dataframe() + raise ValueError( + f"cannot convert Datasets with {len(self.dims)} dimensions into " + "pandas objects without changing the number of dimensions. " + "Please use Dataset.to_dataframe() instead." + ) + + def _to_dataframe(self, ordered_dims: Mapping[Any, int]): + columns_in_order = [k for k in self.variables if k not in self.dims] + non_extension_array_columns = [ + k + for k in columns_in_order + if not is_extension_array_dtype(self.variables[k].data) + ] + extension_array_columns = [ + k + for k in columns_in_order + if is_extension_array_dtype(self.variables[k].data) + ] + data = [ + self._variables[k].set_dims(ordered_dims).values.reshape(-1) + for k in non_extension_array_columns + ] + index = self.coords.to_index([*ordered_dims]) + broadcasted_df = pd.DataFrame( + dict(zip(non_extension_array_columns, data)), index=index + ) + for extension_array_column in extension_array_columns: + extension_array = self.variables[extension_array_column].data.array + index = self[self.variables[extension_array_column].dims[0]].data + extension_array_df = pd.DataFrame( + {extension_array_column: extension_array}, + index=self[self.variables[extension_array_column].dims[0]].data, + ) + extension_array_df.index.name = self.variables[extension_array_column].dims[ + 0 + ] + broadcasted_df = broadcasted_df.join(extension_array_df) + return broadcasted_df[columns_in_order] + + def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame: + """Convert this dataset into a pandas.DataFrame. + + Non-index variables in this dataset form the columns of the + DataFrame. The DataFrame is indexed by the Cartesian product of + this dataset's indices. + + Parameters + ---------- + dim_order: Sequence of Hashable or None, optional + Hierarchical dimension order for the resulting dataframe. All + arrays are transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major + influence on which operations are efficient on the resulting + dataframe. + + If provided, must include all dimensions of this dataset. By + default, dimensions are sorted alphabetically. + + Returns + ------- + result : DataFrame + Dataset as a pandas DataFrame. + + """ + + ordered_dims = self._normalize_dim_order(dim_order=dim_order) + + return self._to_dataframe(ordered_dims=ordered_dims) + + def _set_sparse_data_from_dataframe( + self, idx: pd.Index, arrays: list[tuple[Hashable, np.ndarray]], dims: tuple + ) -> None: + from sparse import COO + + if isinstance(idx, pd.MultiIndex): + coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) + is_sorted = idx.is_monotonic_increasing + shape = tuple(lev.size for lev in idx.levels) + else: + coords = np.arange(idx.size).reshape(1, -1) + is_sorted = True + shape = (idx.size,) + + for name, values in arrays: + # In virtually all real use cases, the sparse array will now have + # missing values and needs a fill_value. For consistency, don't + # special case the rare exceptions (e.g., dtype=int without a + # MultiIndex). + dtype, fill_value = xrdtypes.maybe_promote(values.dtype) + values = np.asarray(values, dtype=dtype) + + data = COO( + coords, + values, + shape, + has_duplicates=False, + sorted=is_sorted, + fill_value=fill_value, + ) + self[name] = (dims, data) + + def _set_numpy_data_from_dataframe( + self, idx: pd.Index, arrays: list[tuple[Hashable, np.ndarray]], dims: tuple + ) -> None: + if not isinstance(idx, pd.MultiIndex): + for name, values in arrays: + self[name] = (dims, values) + return + + # NB: similar, more general logic, now exists in + # variable.unstack_once; we could consider combining them at some + # point. + + shape = tuple(lev.size for lev in idx.levels) + indexer = tuple(idx.codes) + + # We already verified that the MultiIndex has all unique values, so + # there are missing values if and only if the size of output arrays is + # larger that the index. + missing_values = math.prod(shape) > idx.shape[0] + + for name, values in arrays: + # NumPy indexing is much faster than using DataFrame.reindex() to + # fill in missing values: + # https://stackoverflow.com/a/35049899/809705 + if missing_values: + dtype, fill_value = xrdtypes.maybe_promote(values.dtype) + data = np.full(shape, fill_value, dtype) + else: + # If there are no missing values, keep the existing dtype + # instead of promoting to support NA, e.g., keep integer + # columns as integers. + # TODO: consider removing this special case, which doesn't + # exist for sparse=True. + data = np.zeros(shape, values.dtype) + data[indexer] = values + self[name] = (dims, data) + + @classmethod + def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: + """Convert a pandas.DataFrame into an xarray.Dataset + + Each column will be converted into an independent variable in the + Dataset. If the dataframe's index is a MultiIndex, it will be expanded + into a tensor product of one-dimensional indices (filling in missing + values with NaN). If you rather preserve the MultiIndex use + `xr.Dataset(df)`. This method will produce a Dataset very similar to + that on which the 'to_dataframe' method was called, except with + possibly redundant dimensions (since all dataset variables will have + the same dimensionality). + + Parameters + ---------- + dataframe : DataFrame + DataFrame from which to copy data and indices. + sparse : bool, default: False + If true, create a sparse arrays instead of dense numpy arrays. This + can potentially save a large amount of memory if the DataFrame has + a MultiIndex. Requires the sparse package (sparse.pydata.org). + + Returns + ------- + New Dataset. + + See Also + -------- + xarray.DataArray.from_series + pandas.DataFrame.to_xarray + """ + # TODO: Add an option to remove dimensions along which the variables + # are constant, to enable consistent serialization to/from a dataframe, + # even if some variables have different dimensionality. + + if not dataframe.columns.is_unique: + raise ValueError("cannot convert DataFrame with non-unique columns") + + idx = remove_unused_levels_categories(dataframe.index) + + if isinstance(idx, pd.MultiIndex) and not idx.is_unique: + raise ValueError( + "cannot convert a DataFrame with a non-unique MultiIndex into xarray" + ) + + arrays = [] + extension_arrays = [] + for k, v in dataframe.items(): + if not is_extension_array_dtype(v) or isinstance( + v.array, (pd.arrays.DatetimeArray, pd.arrays.TimedeltaArray) + ): + arrays.append((k, np.asarray(v))) + else: + extension_arrays.append((k, v)) + + indexes: dict[Hashable, Index] = {} + index_vars: dict[Hashable, Variable] = {} + + if isinstance(idx, pd.MultiIndex): + dims = tuple( + name if name is not None else "level_%i" % n + for n, name in enumerate(idx.names) + ) + for dim, lev in zip(dims, idx.levels): + xr_idx = PandasIndex(lev, dim) + indexes[dim] = xr_idx + index_vars.update(xr_idx.create_variables()) + arrays += [(k, np.asarray(v)) for k, v in extension_arrays] + extension_arrays = [] + else: + index_name = idx.name if idx.name is not None else "index" + dims = (index_name,) + xr_idx = PandasIndex(idx, index_name) + indexes[index_name] = xr_idx + index_vars.update(xr_idx.create_variables()) + + obj = cls._construct_direct(index_vars, set(index_vars), indexes=indexes) + + if sparse: + obj._set_sparse_data_from_dataframe(idx, arrays, dims) + else: + obj._set_numpy_data_from_dataframe(idx, arrays, dims) + for name, extension_array in extension_arrays: + obj[name] = (dims, extension_array) + return obj[dataframe.columns] if len(dataframe.columns) else obj + + def to_dask_dataframe( + self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False + ) -> DaskDataFrame: + """ + Convert this dataset into a dask.dataframe.DataFrame. + + The dimensions, coordinates and data variables in this dataset form + the columns of the DataFrame. + + Parameters + ---------- + dim_order : list, optional + Hierarchical dimension order for the resulting dataframe. All + arrays are transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major + influence on which operations are efficient on the resulting dask + dataframe. + + If provided, must include all dimensions of this dataset. By + default, dimensions are sorted alphabetically. + set_index : bool, default: False + If set_index=True, the dask DataFrame is indexed by this dataset's + coordinate. Since dask DataFrames do not support multi-indexes, + set_index only works if the dataset only contains one dimension. + + Returns + ------- + dask.dataframe.DataFrame + """ + + import dask.array as da + import dask.dataframe as dd + + ordered_dims = self._normalize_dim_order(dim_order=dim_order) + + columns = list(ordered_dims) + columns.extend(k for k in self.coords if k not in self.dims) + columns.extend(self.data_vars) + + ds_chunks = self.chunks + + series_list = [] + df_meta = pd.DataFrame() + for name in columns: + try: + var = self.variables[name] + except KeyError: + # dimension without a matching coordinate + size = self.sizes[name] + data = da.arange(size, chunks=size, dtype=np.int64) + var = Variable((name,), data) + + # IndexVariable objects have a dummy .chunk() method + if isinstance(var, IndexVariable): + var = var.to_base_variable() + + # Make sure var is a dask array, otherwise the array can become too large + # when it is broadcasted to several dimensions: + if not is_duck_dask_array(var._data): + var = var.chunk() + + # Broadcast then flatten the array: + var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks) + dask_array = var_new_dims._data.reshape(-1) + + series = dd.from_dask_array(dask_array, columns=name, meta=df_meta) + series_list.append(series) + + df = dd.concat(series_list, axis=1) + + if set_index: + dim_order = [*ordered_dims] + + if len(dim_order) == 1: + (dim,) = dim_order + df = df.set_index(dim) + else: + # triggers an error about multi-indexes, even if only one + # dimension is passed + df = df.set_index(dim_order) + + return df + + def to_dict( + self, data: bool | Literal["list", "array"] = "list", encoding: bool = False + ) -> dict[str, Any]: + """ + Convert this dataset to a dictionary following xarray naming + conventions. + + Converts all variables and attributes to native Python objects + Useful for converting to json. To avoid datetime incompatibility + use decode_times=False kwarg in xarrray.open_dataset. + + Parameters + ---------- + data : bool or {"list", "array"}, default: "list" + Whether to include the actual data in the dictionary. When set to + False, returns just the schema. If set to "array", returns data as + underlying array type. If set to "list" (or True for backwards + compatibility), returns data in lists of Python data types. Note + that for obtaining the "list" output efficiently, use + `ds.compute().to_dict(data="list")`. + + encoding : bool, default: False + Whether to include the Dataset's encoding in the dictionary. + + Returns + ------- + d : dict + Dict with keys: "coords", "attrs", "dims", "data_vars" and optionally + "encoding". + + See Also + -------- + Dataset.from_dict + DataArray.to_dict + """ + d: dict = { + "coords": {}, + "attrs": decode_numpy_dict_values(self.attrs), + "dims": dict(self.sizes), + "data_vars": {}, + } + for k in self.coords: + d["coords"].update( + {k: self[k].variable.to_dict(data=data, encoding=encoding)} + ) + for k in self.data_vars: + d["data_vars"].update( + {k: self[k].variable.to_dict(data=data, encoding=encoding)} + ) + if encoding: + d["encoding"] = dict(self.encoding) + return d + + @classmethod + def from_dict(cls, d: Mapping[Any, Any]) -> Self: + """Convert a dictionary into an xarray.Dataset. + + Parameters + ---------- + d : dict-like + Mapping with a minimum structure of + ``{"var_0": {"dims": [..], "data": [..]}, \ + ...}`` + + Returns + ------- + obj : Dataset + + See also + -------- + Dataset.to_dict + DataArray.from_dict + + Examples + -------- + >>> d = { + ... "t": {"dims": ("t"), "data": [0, 1, 2]}, + ... "a": {"dims": ("t"), "data": ["a", "b", "c"]}, + ... "b": {"dims": ("t"), "data": [10, 20, 30]}, + ... } + >>> ds = xr.Dataset.from_dict(d) + >>> ds + Size: 60B + Dimensions: (t: 3) + Coordinates: + * t (t) int64 24B 0 1 2 + Data variables: + a (t) >> d = { + ... "coords": { + ... "t": {"dims": "t", "data": [0, 1, 2], "attrs": {"units": "s"}} + ... }, + ... "attrs": {"title": "air temperature"}, + ... "dims": "t", + ... "data_vars": { + ... "a": {"dims": "t", "data": [10, 20, 30]}, + ... "b": {"dims": "t", "data": ["a", "b", "c"]}, + ... }, + ... } + >>> ds = xr.Dataset.from_dict(d) + >>> ds + Size: 60B + Dimensions: (t: 3) + Coordinates: + * t (t) int64 24B 0 1 2 + Data variables: + a (t) int64 24B 10 20 30 + b (t) Self: + variables = {} + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + for k, v in self._variables.items(): + if k in self._coord_names: + variables[k] = v + else: + variables[k] = f(v, *args, **kwargs) + if keep_attrs: + variables[k]._attrs = v._attrs + attrs = self._attrs if keep_attrs else None + return self._replace_with_new_dims(variables, attrs=attrs) + + def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: + from xarray.core.dataarray import DataArray + from xarray.core.groupby import GroupBy + + if isinstance(other, GroupBy): + return NotImplemented + align_type = OPTIONS["arithmetic_join"] if join is None else join + if isinstance(other, (DataArray, Dataset)): + self, other = align(self, other, join=align_type, copy=False) + g = f if not reflexive else lambda x, y: f(y, x) + ds = self._calculate_binary_op(g, other, join=align_type) + keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + ds.attrs = self.attrs + return ds + + def _inplace_binary_op(self, other, f) -> Self: + from xarray.core.dataarray import DataArray + from xarray.core.groupby import GroupBy + + if isinstance(other, GroupBy): + raise TypeError( + "in-place operations between a Dataset and " + "a grouped object are not permitted" + ) + # we don't actually modify arrays in-place with in-place Dataset + # arithmetic -- this lets us automatically align things + if isinstance(other, (DataArray, Dataset)): + other = other.reindex_like(self, copy=False) + g = ops.inplace_to_noninplace_op(f) + ds = self._calculate_binary_op(g, other, inplace=True) + self._replace_with_new_dims( + ds._variables, + ds._coord_names, + attrs=ds._attrs, + indexes=ds._indexes, + inplace=True, + ) + return self + + def _calculate_binary_op( + self, f, other, join="inner", inplace: bool = False + ) -> Dataset: + def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): + if inplace and set(lhs_data_vars) != set(rhs_data_vars): + raise ValueError( + "datasets must have the same data variables " + f"for in-place arithmetic operations: {list(lhs_data_vars)}, {list(rhs_data_vars)}" + ) + + dest_vars = {} + + for k in lhs_data_vars: + if k in rhs_data_vars: + dest_vars[k] = f(lhs_vars[k], rhs_vars[k]) + elif join in ["left", "outer"]: + dest_vars[k] = f(lhs_vars[k], np.nan) + for k in rhs_data_vars: + if k not in dest_vars and join in ["right", "outer"]: + dest_vars[k] = f(rhs_vars[k], np.nan) + return dest_vars + + if utils.is_dict_like(other) and not isinstance(other, Dataset): + # can't use our shortcut of doing the binary operation with + # Variable objects, so apply over our data vars instead. + new_data_vars = apply_over_both( + self.data_vars, other, self.data_vars, other + ) + return type(self)(new_data_vars) + + other_coords: Coordinates | None = getattr(other, "coords", None) + ds = self.coords.merge(other_coords) + + if isinstance(other, Dataset): + new_vars = apply_over_both( + self.data_vars, other.data_vars, self.variables, other.variables + ) + else: + other_variable = getattr(other, "variable", other) + new_vars = {k: f(self.variables[k], other_variable) for k in self.data_vars} + ds._variables.update(new_vars) + ds._dims = calculate_dimensions(ds._variables) + return ds + + def _copy_attrs_from(self, other): + self.attrs = other.attrs + for v in other.variables: + if v in self.variables: + self.variables[v].attrs = other.variables[v].attrs + + @_deprecate_positional_args("v2023.10.0") + def diff( + self, + dim: Hashable, + n: int = 1, + *, + label: Literal["upper", "lower"] = "upper", + ) -> Self: + """Calculate the n-th order discrete difference along given axis. + + Parameters + ---------- + dim : Hashable + Dimension over which to calculate the finite difference. + n : int, default: 1 + The number of times values are differenced. + label : {"upper", "lower"}, default: "upper" + The new coordinate in dimension ``dim`` will have the + values of either the minuend's or subtrahend's coordinate + for values 'upper' and 'lower', respectively. + + Returns + ------- + difference : Dataset + The n-th order finite difference of this object. + + Notes + ----- + `n` matches numpy's behavior and is different from pandas' first argument named + `periods`. + + Examples + -------- + >>> ds = xr.Dataset({"foo": ("x", [5, 5, 6, 6])}) + >>> ds.diff("x") + Size: 24B + Dimensions: (x: 3) + Dimensions without coordinates: x + Data variables: + foo (x) int64 24B 0 1 0 + >>> ds.diff("x", 2) + Size: 16B + Dimensions: (x: 2) + Dimensions without coordinates: x + Data variables: + foo (x) int64 16B 1 -1 + + See Also + -------- + Dataset.differentiate + """ + if n == 0: + return self + if n < 0: + raise ValueError(f"order `n` must be non-negative but got {n}") + + # prepare slices + slice_start = {dim: slice(None, -1)} + slice_end = {dim: slice(1, None)} + + # prepare new coordinate + if label == "upper": + slice_new = slice_end + elif label == "lower": + slice_new = slice_start + else: + raise ValueError("The 'label' argument has to be either 'upper' or 'lower'") + + indexes, index_vars = isel_indexes(self.xindexes, slice_new) + variables = {} + + for name, var in self.variables.items(): + if name in index_vars: + variables[name] = index_vars[name] + elif dim in var.dims: + if name in self.data_vars: + variables[name] = var.isel(slice_end) - var.isel(slice_start) + else: + variables[name] = var.isel(slice_new) + else: + variables[name] = var + + difference = self._replace_with_new_dims(variables, indexes=indexes) + + if n > 1: + return difference.diff(dim, n - 1) + else: + return difference + + def shift( + self, + shifts: Mapping[Any, int] | None = None, + fill_value: Any = xrdtypes.NA, + **shifts_kwargs: int, + ) -> Self: + """Shift this dataset by an offset along one or more dimensions. + + Only data variables are moved; coordinates stay in place. This is + consistent with the behavior of ``shift`` in pandas. + + Values shifted from beyond array bounds will appear at one end of + each dimension, which are filled according to `fill_value`. For periodic + offsets instead see `roll`. + + Parameters + ---------- + shifts : mapping of hashable to int + Integer offset to shift along each of the given dimensions. + Positive offsets shift to the right; negative offsets shift to the + left. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names (including coordinates) to fill values. + **shifts_kwargs + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. + + Returns + ------- + shifted : Dataset + Dataset with the same coordinates and attributes but shifted data + variables. + + See Also + -------- + roll + + Examples + -------- + >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}) + >>> ds.shift(x=2) + Size: 40B + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + foo (x) object 40B nan nan 'a' 'b' 'c' + """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") + invalid = tuple(k for k in shifts if k not in self.dims) + if invalid: + raise ValueError( + f"Dimensions {invalid} not found in data dimensions {tuple(self.dims)}" + ) + + variables = {} + for name, var in self.variables.items(): + if name in self.data_vars: + fill_value_ = ( + fill_value.get(name, xrdtypes.NA) + if isinstance(fill_value, dict) + else fill_value + ) + + var_shifts = {k: v for k, v in shifts.items() if k in var.dims} + variables[name] = var.shift(fill_value=fill_value_, shifts=var_shifts) + else: + variables[name] = var + + return self._replace(variables) + + def roll( + self, + shifts: Mapping[Any, int] | None = None, + roll_coords: bool = False, + **shifts_kwargs: int, + ) -> Self: + """Roll this dataset by an offset along one or more dimensions. + + Unlike shift, roll treats the given dimensions as periodic, so will not + create any missing values to be filled. + + Also unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. + + Parameters + ---------- + shifts : mapping of hashable to int, optional + A dict with keys matching dimensions and values given + by integers to rotate each of the given dimensions. Positive + offsets roll to the right; negative offsets roll to the left. + roll_coords : bool, default: False + Indicates whether to roll the coordinates by the offset too. + **shifts_kwargs : {dim: offset, ...}, optional + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. + + Returns + ------- + rolled : Dataset + Dataset with the same attributes but rolled data and coordinates. + + See Also + -------- + shift + + Examples + -------- + >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}, coords={"x": np.arange(5)}) + >>> ds.roll(x=2) + Size: 60B + Dimensions: (x: 5) + Coordinates: + * x (x) int64 40B 0 1 2 3 4 + Data variables: + foo (x) >> ds.roll(x=2, roll_coords=True) + Size: 60B + Dimensions: (x: 5) + Coordinates: + * x (x) int64 40B 3 4 0 1 2 + Data variables: + foo (x) Self: + """ + Sort object by labels or values (along an axis). + + Sorts the dataset, either along specified dimensions, + or according to values of 1-D dataarrays that share dimension + with calling object. + + If the input variables are dataarrays, then the dataarrays are aligned + (via left-join) to the calling object prior to sorting by cell values. + NaNs are sorted to the end, following Numpy convention. + + If multiple sorts along the same dimension is + given, numpy's lexsort is performed along that dimension: + https://numpy.org/doc/stable/reference/generated/numpy.lexsort.html + and the FIRST key in the sequence is used as the primary sort key, + followed by the 2nd key, etc. + + Parameters + ---------- + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. + ascending : bool, default: True + Whether to sort by ascending or descending order. + + Returns + ------- + sorted : Dataset + A new dataset where all the specified dims are sorted by dim + labels. + + See Also + -------- + DataArray.sortby + numpy.sort + pandas.sort_values + pandas.sort_index + + Examples + -------- + >>> ds = xr.Dataset( + ... { + ... "A": (("x", "y"), [[1, 2], [3, 4]]), + ... "B": (("x", "y"), [[5, 6], [7, 8]]), + ... }, + ... coords={"x": ["b", "a"], "y": [1, 0]}, + ... ) + >>> ds.sortby("x") + Size: 88B + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) >> ds.sortby(lambda x: -x["y"]) + Size: 88B + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) Self: + """Compute the qth quantile of the data along the specified dimension. + + Returns the qth quantiles(s) of the array elements for each variable + in the Dataset. + + Parameters + ---------- + q : float or array-like of float + Quantile to compute, which must be between 0 and 1 inclusive. + dim : str or Iterable of Hashable, optional + Dimension(s) over which to apply quantile. + method : str, default: "linear" + This optional parameter specifies the interpolation method to use when the + desired quantile lies between two data points. The options sorted by their R + type as summarized in the H&F paper [1]_ are: + + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" + 7. "linear" (default) + 8. "median_unbiased" + 9. "normal_unbiased" + + The first three methods are discontiuous. The following discontinuous + variations of the default "linear" (7.) option are also available: + + * "lower" + * "higher" + * "midpoint" + * "nearest" + + See :py:func:`numpy.quantile` or [1]_ for details. The "method" argument + was previously called "interpolation", renamed in accordance with numpy + version 1.22.0. + + keep_attrs : bool, optional + If True, the dataset's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + numeric_only : bool, optional + If True, only apply ``func`` to variables with a numeric dtype. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + quantiles : Dataset + If `q` is a single quantile, then the result is a scalar for each + variable in data_vars. If multiple percentiles are given, first + axis of the result corresponds to the quantile and a quantile + dimension is added to the return Dataset. The other dimensions are + the dimensions that remain after the reduction of the array. + + See Also + -------- + numpy.nanquantile, numpy.quantile, pandas.Series.quantile, DataArray.quantile + + Examples + -------- + >>> ds = xr.Dataset( + ... {"a": (("x", "y"), [[0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]])}, + ... coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]}, + ... ) + >>> ds.quantile(0) # or ds.quantile(0, dim=...) + Size: 16B + Dimensions: () + Coordinates: + quantile float64 8B 0.0 + Data variables: + a float64 8B 0.7 + >>> ds.quantile(0, dim="x") + Size: 72B + Dimensions: (y: 4) + Coordinates: + * y (y) float64 32B 1.0 1.5 2.0 2.5 + quantile float64 8B 0.0 + Data variables: + a (y) float64 32B 0.7 4.2 2.6 1.5 + >>> ds.quantile([0, 0.5, 1]) + Size: 48B + Dimensions: (quantile: 3) + Coordinates: + * quantile (quantile) float64 24B 0.0 0.5 1.0 + Data variables: + a (quantile) float64 24B 0.7 3.4 9.4 + >>> ds.quantile([0, 0.5, 1], dim="x") + Size: 152B + Dimensions: (quantile: 3, y: 4) + Coordinates: + * y (y) float64 32B 1.0 1.5 2.0 2.5 + * quantile (quantile) float64 24B 0.0 0.5 1.0 + Data variables: + a (quantile, y) float64 96B 0.7 4.2 2.6 1.5 3.6 ... 6.5 7.3 9.4 1.9 + + References + ---------- + .. [1] R. J. Hyndman and Y. Fan, + "Sample quantiles in statistical packages," + The American Statistician, 50(4), pp. 361-365, 1996 + """ + + # interpolation renamed to method in version 0.21.0 + # check here and in variable to avoid repeated warnings + if interpolation is not None: + warnings.warn( + "The `interpolation` argument to quantile was renamed to `method`.", + FutureWarning, + ) + + if method != "linear": + raise TypeError("Cannot pass interpolation and method keywords!") + + method = interpolation + + dims: set[Hashable] + if isinstance(dim, str): + dims = {dim} + elif dim is None or dim is ...: + dims = set(self.dims) + else: + dims = set(dim) + + invalid_dims = set(dims) - set(self.dims) + if invalid_dims: + raise ValueError( + f"Dimensions {tuple(invalid_dims)} not found in data dimensions {tuple(self.dims)}" + ) + + q = np.asarray(q, dtype=np.float64) + + variables = {} + for name, var in self.variables.items(): + reduce_dims = [d for d in var.dims if d in dims] + if reduce_dims or not var.dims: + if name not in self.coords: + if ( + not numeric_only + or np.issubdtype(var.dtype, np.number) + or var.dtype == np.bool_ + ): + variables[name] = var.quantile( + q, + dim=reduce_dims, + method=method, + keep_attrs=keep_attrs, + skipna=skipna, + ) + + else: + variables[name] = var + + # construct the new dataset + coord_names = {k for k in self.coords if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + attrs = self.attrs if keep_attrs else None + new = self._replace_with_new_dims( + variables, coord_names=coord_names, attrs=attrs, indexes=indexes + ) + return new.assign_coords(quantile=q) + + @_deprecate_positional_args("v2023.10.0") + def rank( + self, + dim: Hashable, + *, + pct: bool = False, + keep_attrs: bool | None = None, + ) -> Self: + """Ranks the data. + + Equal values are assigned a rank that is the average of the ranks that + would have been otherwise assigned to all of the values within + that set. + Ranks begin at 1, not 0. If pct is True, computes percentage ranks. + + NaNs in the input array are returned as NaNs. + + The `bottleneck` library is required. + + Parameters + ---------- + dim : Hashable + Dimension over which to compute rank. + pct : bool, default: False + If True, compute percentage ranks, otherwise compute integer ranks. + keep_attrs : bool or None, optional + If True, the dataset's attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + + Returns + ------- + ranked : Dataset + Variables that do not depend on `dim` are dropped. + """ + if not OPTIONS["use_bottleneck"]: + raise RuntimeError( + "rank requires bottleneck to be enabled." + " Call `xr.set_options(use_bottleneck=True)` to enable it." + ) + + if dim not in self.dims: + raise ValueError( + f"Dimension {dim!r} not found in data dimensions {tuple(self.dims)}" + ) + + variables = {} + for name, var in self.variables.items(): + if name in self.data_vars: + if dim in var.dims: + variables[name] = var.rank(dim, pct=pct) + else: + variables[name] = var + + coord_names = set(self.coords) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + attrs = self.attrs if keep_attrs else None + return self._replace(variables, coord_names, attrs=attrs) + + def differentiate( + self, + coord: Hashable, + edge_order: Literal[1, 2] = 1, + datetime_unit: DatetimeUnitOptions | None = None, + ) -> Self: + """Differentiate with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord : Hashable + The coordinate to be used to compute the gradient. + edge_order : {1, 2}, default: 1 + N-th order accurate differences at the boundaries. + datetime_unit : None or {"Y", "M", "W", "D", "h", "m", "s", "ms", \ + "us", "ns", "ps", "fs", "as", None}, default: None + Unit to compute gradient. Only valid for datetime coordinate. + + Returns + ------- + differentiated: Dataset + + See also + -------- + numpy.gradient: corresponding numpy function + """ + from xarray.core.variable import Variable + + if coord not in self.variables and coord not in self.dims: + variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) + raise ValueError( + f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}." + ) + + coord_var = self[coord].variable + if coord_var.ndim != 1: + raise ValueError( + f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}" + " dimensional" + ) + + dim = coord_var.dims[0] + if _contains_datetime_like_objects(coord_var): + if coord_var.dtype.kind in "mM" and datetime_unit is None: + datetime_unit = cast( + "DatetimeUnitOptions", np.datetime_data(coord_var.dtype)[0] + ) + elif datetime_unit is None: + datetime_unit = "s" # Default to seconds for cftime objects + coord_var = coord_var._to_numeric(datetime_unit=datetime_unit) + + variables = {} + for k, v in self.variables.items(): + if k in self.data_vars and dim in v.dims and k not in self.coords: + if _contains_datetime_like_objects(v): + v = v._to_numeric(datetime_unit=datetime_unit) + grad = duck_array_ops.gradient( + v.data, + coord_var.data, + edge_order=edge_order, + axis=v.get_axis_num(dim), + ) + variables[k] = Variable(v.dims, grad) + else: + variables[k] = v + return self._replace(variables) + + def integrate( + self, + coord: Hashable | Sequence[Hashable], + datetime_unit: DatetimeUnitOptions = None, + ) -> Self: + """Integrate along the given coordinate using the trapezoidal rule. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord : hashable, or sequence of hashable + Coordinate(s) used for the integration. + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as', None}, optional + Specify the unit if datetime coordinate is used. + + Returns + ------- + integrated : Dataset + + See also + -------- + DataArray.integrate + numpy.trapz : corresponding numpy function + + Examples + -------- + >>> ds = xr.Dataset( + ... data_vars={"a": ("x", [5, 5, 6, 6]), "b": ("x", [1, 2, 1, 0])}, + ... coords={"x": [0, 1, 2, 3], "y": ("x", [1, 7, 3, 5])}, + ... ) + >>> ds + Size: 128B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 + Data variables: + a (x) int64 32B 5 5 6 6 + b (x) int64 32B 1 2 1 0 + >>> ds.integrate("x") + Size: 16B + Dimensions: () + Data variables: + a float64 8B 16.5 + b float64 8B 3.5 + >>> ds.integrate("y") + Size: 16B + Dimensions: () + Data variables: + a float64 8B 20.0 + b float64 8B 4.0 + """ + if not isinstance(coord, (list, tuple)): + coord = (coord,) + result = self + for c in coord: + result = result._integrate_one(c, datetime_unit=datetime_unit) + return result + + def _integrate_one(self, coord, datetime_unit=None, cumulative=False): + from xarray.core.variable import Variable + + if coord not in self.variables and coord not in self.dims: + variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) + raise ValueError( + f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}." + ) + + coord_var = self[coord].variable + if coord_var.ndim != 1: + raise ValueError( + f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}" + " dimensional" + ) + + dim = coord_var.dims[0] + if _contains_datetime_like_objects(coord_var): + if coord_var.dtype.kind in "mM" and datetime_unit is None: + datetime_unit, _ = np.datetime_data(coord_var.dtype) + elif datetime_unit is None: + datetime_unit = "s" # Default to seconds for cftime objects + coord_var = coord_var._replace( + data=datetime_to_numeric(coord_var.data, datetime_unit=datetime_unit) + ) + + variables = {} + coord_names = set() + for k, v in self.variables.items(): + if k in self.coords: + if dim not in v.dims or cumulative: + variables[k] = v + coord_names.add(k) + else: + if k in self.data_vars and dim in v.dims: + if _contains_datetime_like_objects(v): + v = datetime_to_numeric(v, datetime_unit=datetime_unit) + if cumulative: + integ = duck_array_ops.cumulative_trapezoid( + v.data, coord_var.data, axis=v.get_axis_num(dim) + ) + v_dims = v.dims + else: + integ = duck_array_ops.trapz( + v.data, coord_var.data, axis=v.get_axis_num(dim) + ) + v_dims = list(v.dims) + v_dims.remove(dim) + variables[k] = Variable(v_dims, integ) + else: + variables[k] = v + indexes = {k: v for k, v in self._indexes.items() if k in variables} + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def cumulative_integrate( + self, + coord: Hashable | Sequence[Hashable], + datetime_unit: DatetimeUnitOptions = None, + ) -> Self: + """Integrate along the given coordinate using the trapezoidal rule. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + The first entry of the cumulative integral of each variable is always 0, in + order to keep the length of the dimension unchanged between input and + output. + + Parameters + ---------- + coord : hashable, or sequence of hashable + Coordinate(s) used for the integration. + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as', None}, optional + Specify the unit if datetime coordinate is used. + + Returns + ------- + integrated : Dataset + + See also + -------- + DataArray.cumulative_integrate + scipy.integrate.cumulative_trapezoid : corresponding scipy function + + Examples + -------- + >>> ds = xr.Dataset( + ... data_vars={"a": ("x", [5, 5, 6, 6]), "b": ("x", [1, 2, 1, 0])}, + ... coords={"x": [0, 1, 2, 3], "y": ("x", [1, 7, 3, 5])}, + ... ) + >>> ds + Size: 128B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 + Data variables: + a (x) int64 32B 5 5 6 6 + b (x) int64 32B 1 2 1 0 + >>> ds.cumulative_integrate("x") + Size: 128B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 + Data variables: + a (x) float64 32B 0.0 5.0 10.5 16.5 + b (x) float64 32B 0.0 1.5 3.0 3.5 + >>> ds.cumulative_integrate("y") + Size: 128B + Dimensions: (x: 4) + Coordinates: + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 + Data variables: + a (x) float64 32B 0.0 30.0 8.0 20.0 + b (x) float64 32B 0.0 9.0 3.0 4.0 + """ + if not isinstance(coord, (list, tuple)): + coord = (coord,) + result = self + for c in coord: + result = result._integrate_one( + c, datetime_unit=datetime_unit, cumulative=True + ) + return result + + @property + def real(self) -> Self: + """ + The real part of each data variable. + + See Also + -------- + numpy.ndarray.real + """ + return self.map(lambda x: x.real, keep_attrs=True) + + @property + def imag(self) -> Self: + """ + The imaginary part of each data variable. + + See Also + -------- + numpy.ndarray.imag + """ + return self.map(lambda x: x.imag, keep_attrs=True) + + plot = utils.UncachedAccessor(DatasetPlotAccessor) + + def filter_by_attrs(self, **kwargs) -> Self: + """Returns a ``Dataset`` with variables that match specific conditions. + + Can pass in ``key=value`` or ``key=callable``. A Dataset is returned + containing only the variables for which all the filter tests pass. + These tests are either ``key=value`` for which the attribute ``key`` + has the exact value ``value`` or the callable passed into + ``key=callable`` returns True. The callable will be passed a single + value, either the value of the attribute ``key`` or ``None`` if the + DataArray does not have an attribute with the name ``key``. + + Parameters + ---------- + **kwargs + key : str + Attribute name. + value : callable or obj + If value is a callable, it should return a boolean in the form + of bool = func(attr) where attr is da.attrs[key]. + Otherwise, value will be compared to the each + DataArray's attrs[key]. + + Returns + ------- + new : Dataset + New dataset with variables filtered by attribute. + + Examples + -------- + >>> temp = 15 + 8 * np.random.randn(2, 2, 3) + >>> precip = 10 * np.random.rand(2, 2, 3) + >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] + >>> lat = [[42.25, 42.21], [42.63, 42.59]] + >>> dims = ["x", "y", "time"] + >>> temp_attr = dict(standard_name="air_potential_temperature") + >>> precip_attr = dict(standard_name="convective_precipitation_flux") + + >>> ds = xr.Dataset( + ... dict( + ... temperature=(dims, temp, temp_attr), + ... precipitation=(dims, precip, precip_attr), + ... ), + ... coords=dict( + ... lon=(["x", "y"], lon), + ... lat=(["x", "y"], lat), + ... time=pd.date_range("2014-09-06", periods=3), + ... reference_time=pd.Timestamp("2014-09-05"), + ... ), + ... ) + + Get variables matching a specific standard_name: + + >>> ds.filter_by_attrs(standard_name="convective_precipitation_flux") + Size: 192B + Dimensions: (x: 2, y: 2, time: 3) + Coordinates: + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + precipitation (x, y, time) float64 96B 5.68 9.256 0.7104 ... 4.615 7.805 + + Get all variables that have a standard_name attribute: + + >>> standard_name = lambda v: v is not None + >>> ds.filter_by_attrs(standard_name=standard_name) + Size: 288B + Dimensions: (x: 2, y: 2, time: 3) + Coordinates: + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 96B 29.11 18.2 22.83 ... 16.15 26.63 + precipitation (x, y, time) float64 96B 5.68 9.256 0.7104 ... 4.615 7.805 + + """ + selection = [] + for var_name, variable in self.variables.items(): + has_value_flag = False + for attr_name, pattern in kwargs.items(): + attr_value = variable.attrs.get(attr_name) + if (callable(pattern) and pattern(attr_value)) or attr_value == pattern: + has_value_flag = True + else: + has_value_flag = False + break + if has_value_flag is True: + selection.append(var_name) + return self[selection] + + def unify_chunks(self) -> Self: + """Unify chunk size along all chunked dimensions of this Dataset. + + Returns + ------- + Dataset with consistent chunk sizes for all dask-array variables + + See Also + -------- + dask.array.core.unify_chunks + """ + + return unify_chunks(self)[0] + + def map_blocks( + self, + func: Callable[..., T_Xarray], + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] | None = None, + template: DataArray | Dataset | None = None, + ) -> T_Xarray: + """ + Apply a function to each block of this Dataset. + + .. warning:: + This method is experimental and its signature may change. + + Parameters + ---------- + func : callable + User-provided function that accepts a Dataset as its first + parameter. The function will receive a subset or 'block' of this Dataset (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_dataset, *subset_args, **kwargs)``. + + This function must return either a single DataArray or a single Dataset. + + This function cannot add a new chunked dimension. + args : sequence + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. + kwargs : Mapping or None + Passed verbatim to func after unpacking. xarray objects, if any, will not be + subset to blocks. Passing dask collections in kwargs is not allowed. + template : DataArray, Dataset or None, optional + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like this object but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + + Returns + ------- + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. + + Notes + ----- + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. Each block is loaded into memory. In the more common case where + ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. + + If none of the variables in this object is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks + xarray.DataArray.map_blocks + + :doc:`xarray-tutorial:advanced/map_blocks/map_blocks` + Advanced Tutorial on map_blocks with dask + + + Examples + -------- + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + ... + >>> time = xr.cftime_range("1990-01", "1992-01", freq="ME") + >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), + ... dims=["time"], + ... coords={"time": time, "month": month}, + ... ).chunk() + >>> ds = xr.Dataset({"a": array}) + >>> ds.map_blocks(calculate_anomaly, template=ds).compute() + Size: 576B + Dimensions: (time: 24) + Coordinates: + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B 1 2 3 4 5 6 7 8 9 10 ... 3 4 5 6 7 8 9 10 11 12 + Data variables: + a (time) float64 192B 0.1289 0.1132 -0.0856 ... 0.1906 -0.05901 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> ds.map_blocks( + ... calculate_anomaly, + ... kwargs={"groupby_type": "time.year"}, + ... template=ds, + ... ) + Size: 576B + Dimensions: (time: 24) + Coordinates: + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B dask.array + Data variables: + a (time) float64 192B dask.array + """ + from xarray.core.parallel import map_blocks + + return map_blocks(func, self, args, kwargs, template) + + def polyfit( + self, + dim: Hashable, + deg: int, + skipna: bool | None = None, + rcond: float | None = None, + w: Hashable | Any = None, + full: bool = False, + cov: bool | Literal["unscaled"] = False, + ) -> Self: + """ + Least squares polynomial fit. + + This replicates the behaviour of `numpy.polyfit` but differs by skipping + invalid values when `skipna = True`. + + Parameters + ---------- + dim : hashable + Coordinate along which to fit the polynomials. + deg : int + Degree of the fitting polynomial. + skipna : bool or None, optional + If True, removes all invalid values before fitting each 1D slices of the array. + Default is True if data is stored in a dask.array or if there is any + invalid values, False otherwise. + rcond : float or None, optional + Relative condition number to the fit. + w : hashable or Any, optional + Weights to apply to the y-coordinate of the sample points. + Can be an array-like object or the name of a coordinate in the dataset. + full : bool, default: False + Whether to return the residuals, matrix rank and singular values in addition + to the coefficients. + cov : bool or "unscaled", default: False + Whether to return to the covariance matrix in addition to the coefficients. + The matrix is not scaled if `cov='unscaled'`. + + Returns + ------- + polyfit_results : Dataset + A single dataset which contains (for each "var" in the input dataset): + + [var]_polyfit_coefficients + The coefficients of the best fit for each variable in this dataset. + [var]_polyfit_residuals + The residuals of the least-square computation for each variable (only included if `full=True`) + When the matrix rank is deficient, np.nan is returned. + [dim]_matrix_rank + The effective rank of the scaled Vandermonde coefficient matrix (only included if `full=True`) + The rank is computed ignoring the NaN values that might be skipped. + [dim]_singular_values + The singular values of the scaled Vandermonde coefficient matrix (only included if `full=True`) + [var]_polyfit_covariance + The covariance matrix of the polynomial coefficient estimates (only included if `full=False` and `cov=True`) + + Warns + ----- + RankWarning + The rank of the coefficient matrix in the least-squares fit is deficient. + The warning is not raised with in-memory (not dask) data and `full=True`. + + See Also + -------- + numpy.polyfit + numpy.polyval + xarray.polyval + """ + from xarray.core.dataarray import DataArray + + variables = {} + skipna_da = skipna + + x = get_clean_interp_index(self, dim, strict=False) + xname = f"{self[dim].name}_" + order = int(deg) + 1 + lhs = np.vander(x, order) + + if rcond is None: + rcond = x.shape[0] * np.finfo(x.dtype).eps + + # Weights: + if w is not None: + if isinstance(w, Hashable): + w = self.coords[w] + w = np.asarray(w) + if w.ndim != 1: + raise TypeError("Expected a 1-d array for weights.") + if w.shape[0] != lhs.shape[0]: + raise TypeError(f"Expected w and {dim} to have the same length") + lhs *= w[:, np.newaxis] + + # Scaling + scale = np.sqrt((lhs * lhs).sum(axis=0)) + lhs /= scale + + degree_dim = utils.get_temp_dimname(self.dims, "degree") + + rank = np.linalg.matrix_rank(lhs) + + if full: + rank = DataArray(rank, name=xname + "matrix_rank") + variables[rank.name] = rank + _sing = np.linalg.svd(lhs, compute_uv=False) + sing = DataArray( + _sing, + dims=(degree_dim,), + coords={degree_dim: np.arange(rank - 1, -1, -1)}, + name=xname + "singular_values", + ) + variables[sing.name] = sing + + for name, da in self.data_vars.items(): + if dim not in da.dims: + continue + + if is_duck_dask_array(da.data) and ( + rank != order or full or skipna is None + ): + # Current algorithm with dask and skipna=False neither supports + # deficient ranks nor does it output the "full" info (issue dask/dask#6516) + skipna_da = True + elif skipna is None: + skipna_da = bool(np.any(da.isnull())) + + dims_to_stack = [dimname for dimname in da.dims if dimname != dim] + stacked_coords: dict[Hashable, DataArray] = {} + if dims_to_stack: + stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked") + rhs = da.transpose(dim, *dims_to_stack).stack( + {stacked_dim: dims_to_stack} + ) + stacked_coords = {stacked_dim: rhs[stacked_dim]} + scale_da = scale[:, np.newaxis] + else: + rhs = da + scale_da = scale + + if w is not None: + rhs = rhs * w[:, np.newaxis] + + with warnings.catch_warnings(): + if full: # Copy np.polyfit behavior + warnings.simplefilter("ignore", RankWarning) + else: # Raise only once per variable + warnings.simplefilter("once", RankWarning) + + coeffs, residuals = duck_array_ops.least_squares( + lhs, rhs.data, rcond=rcond, skipna=skipna_da + ) + + if isinstance(name, str): + name = f"{name}_" + else: + # Thus a ReprObject => polyfit was called on a DataArray + name = "" + + coeffs = DataArray( + coeffs / scale_da, + dims=[degree_dim] + list(stacked_coords.keys()), + coords={degree_dim: np.arange(order)[::-1], **stacked_coords}, + name=name + "polyfit_coefficients", + ) + if dims_to_stack: + coeffs = coeffs.unstack(stacked_dim) + variables[coeffs.name] = coeffs + + if full or (cov is True): + residuals = DataArray( + residuals if dims_to_stack else residuals.squeeze(), + dims=list(stacked_coords.keys()), + coords=stacked_coords, + name=name + "polyfit_residuals", + ) + if dims_to_stack: + residuals = residuals.unstack(stacked_dim) + variables[residuals.name] = residuals + + if cov: + Vbase = np.linalg.inv(np.dot(lhs.T, lhs)) + Vbase /= np.outer(scale, scale) + if cov == "unscaled": + fac = 1 + else: + if x.shape[0] <= order: + raise ValueError( + "The number of data points must exceed order to scale the covariance matrix." + ) + fac = residuals / (x.shape[0] - order) + covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac + variables[name + "polyfit_covariance"] = covariance + + return type(self)(data_vars=variables, attrs=self.attrs.copy()) + + def pad( + self, + pad_width: Mapping[Any, int | tuple[int, int]] | None = None, + mode: PadModeOptions = "constant", + stat_length: ( + int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None + ) = None, + constant_values: ( + float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None + ) = None, + end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, + reflect_type: PadReflectOptions = None, + keep_attrs: bool | None = None, + **pad_width_kwargs: Any, + ) -> Self: + """Pad this dataset along one or more dimensions. + + .. warning:: + This function is experimental and its behaviour is likely to change + especially regarding padding of dimension coordinates (or IndexVariables). + + When using one of the modes ("edge", "reflect", "symmetric", "wrap"), + coordinates will be padded with the same mode, otherwise coordinates + are padded using the "constant" mode with fill_value dtypes.NA. + + Parameters + ---------- + pad_width : mapping of hashable to tuple of int + Mapping with the form of {dim: (pad_before, pad_after)} + describing the number of values padded along each dimension. + {dim: pad} is a shortcut for pad_before = pad_after = pad + mode : {"constant", "edge", "linear_ramp", "maximum", "mean", "median", \ + "minimum", "reflect", "symmetric", "wrap"}, default: "constant" + How to pad the DataArray (taken from numpy docs): + + - "constant": Pads with a constant value. + - "edge": Pads with the edge values of array. + - "linear_ramp": Pads with the linear ramp between end_value and the + array edge value. + - "maximum": Pads with the maximum value of all or part of the + vector along each axis. + - "mean": Pads with the mean value of all or part of the + vector along each axis. + - "median": Pads with the median value of all or part of the + vector along each axis. + - "minimum": Pads with the minimum value of all or part of the + vector along each axis. + - "reflect": Pads with the reflection of the vector mirrored on + the first and last values of the vector along each axis. + - "symmetric": Pads with the reflection of the vector mirrored + along the edge of the array. + - "wrap": Pads with the wrap of the vector along the axis. + The first values are used to pad the end and the + end values are used to pad the beginning. + + stat_length : int, tuple or mapping of hashable to tuple, default: None + Used in 'maximum', 'mean', 'median', and 'minimum'. Number of + values at edge of each axis used to calculate the statistic value. + {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)} unique + statistic lengths along each dimension. + ((before, after),) yields same before and after statistic lengths + for each dimension. + (stat_length,) or int is a shortcut for before = after = statistic + length for all axes. + Default is ``None``, to use the entire axis. + constant_values : scalar, tuple or mapping of hashable to tuple, default: 0 + Used in 'constant'. The values to set the padded values for each + axis. + ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique + pad constants along each dimension. + ``((before, after),)`` yields same before and after constants for each + dimension. + ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for + all dimensions. + Default is 0. + end_values : scalar, tuple or mapping of hashable to tuple, default: 0 + Used in 'linear_ramp'. The values used for the ending value of the + linear_ramp and that will form the edge of the padded array. + ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique + end values along each dimension. + ``((before, after),)`` yields same before and after end values for each + axis. + ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for + all axes. + Default is 0. + reflect_type : {"even", "odd", None}, optional + Used in "reflect", and "symmetric". The "even" style is the + default with an unaltered reflection around the edge value. For + the "odd" style, the extended part of the array is created by + subtracting the reflected values from two times the edge value. + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the + original object to the new one. If False, the new object + will be returned without attributes. + **pad_width_kwargs + The keyword arguments form of ``pad_width``. + One of ``pad_width`` or ``pad_width_kwargs`` must be provided. + + Returns + ------- + padded : Dataset + Dataset with the padded coordinates and data. + + See Also + -------- + Dataset.shift, Dataset.roll, Dataset.bfill, Dataset.ffill, numpy.pad, dask.array.pad + + Notes + ----- + By default when ``mode="constant"`` and ``constant_values=None``, integer types will be + promoted to ``float`` and padded with ``np.nan``. To avoid type promotion + specify ``constant_values=np.nan`` + + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + + Examples + -------- + >>> ds = xr.Dataset({"foo": ("x", range(5))}) + >>> ds.pad(x=(1, 2)) + Size: 64B + Dimensions: (x: 8) + Dimensions without coordinates: x + Data variables: + foo (x) float64 64B nan 0.0 1.0 2.0 3.0 4.0 nan nan + """ + pad_width = either_dict_or_kwargs(pad_width, pad_width_kwargs, "pad") + + if mode in ("edge", "reflect", "symmetric", "wrap"): + coord_pad_mode = mode + coord_pad_options = { + "stat_length": stat_length, + "constant_values": constant_values, + "end_values": end_values, + "reflect_type": reflect_type, + } + else: + coord_pad_mode = "constant" + coord_pad_options = {} + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + variables = {} + + # keep indexes that won't be affected by pad and drop all other indexes + xindexes = self.xindexes + pad_dims = set(pad_width) + indexes = {} + for k, idx in xindexes.items(): + if not pad_dims.intersection(xindexes.get_all_dims(k)): + indexes[k] = idx + + for name, var in self.variables.items(): + var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} + if not var_pad_width: + variables[name] = var + elif name in self.data_vars: + variables[name] = var.pad( + pad_width=var_pad_width, + mode=mode, + stat_length=stat_length, + constant_values=constant_values, + end_values=end_values, + reflect_type=reflect_type, + keep_attrs=keep_attrs, + ) + else: + variables[name] = var.pad( + pad_width=var_pad_width, + mode=coord_pad_mode, + keep_attrs=keep_attrs, + **coord_pad_options, # type: ignore[arg-type] + ) + # reset default index of dimension coordinates + if (name,) == var.dims: + dim_var = {name: variables[name]} + index = PandasIndex.from_variables(dim_var, options={}) + index_vars = index.create_variables(dim_var) + indexes[name] = index + variables[name] = index_vars[name] + + attrs = self._attrs if keep_attrs else None + return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) + + @_deprecate_positional_args("v2023.10.0") + def idxmin( + self, + dim: Hashable | None = None, + *, + skipna: bool | None = None, + fill_value: Any = xrdtypes.NA, + keep_attrs: bool | None = None, + ) -> Self: + """Return the coordinate label of the minimum value along a dimension. + + Returns a new `Dataset` named after the dimension with the values of + the coordinate labels along that dimension corresponding to minimum + values along that dimension. + + In comparison to :py:meth:`~Dataset.argmin`, this returns the + coordinate label while :py:meth:`~Dataset.argmin` returns the index. + + Parameters + ---------- + dim : Hashable, optional + Dimension over which to apply `idxmin`. This is optional for 1D + variables, but required for variables with 2 or more dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for ``float``, ``complex``, and ``object`` + dtypes; other dtypes either do not have a sentinel missing value + (``int``) or ``skipna=True`` has not been implemented + (``datetime64`` or ``timedelta64``). + fill_value : Any, default: NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if ``skipna`` is False. + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the + original object to the new one. If False, the new object + will be returned without attributes. + + Returns + ------- + reduced : Dataset + New `Dataset` object with `idxmin` applied to its data and the + indicated dimension removed. + + See Also + -------- + DataArray.idxmin, Dataset.idxmax, Dataset.min, Dataset.argmin + + Examples + -------- + >>> array1 = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) + >>> array2 = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, + ... ) + >>> ds = xr.Dataset({"int": array1, "float": array2}) + >>> ds.min(dim="x") + Size: 56B + Dimensions: (y: 3) + Coordinates: + * y (y) int64 24B -1 0 1 + Data variables: + int int64 8B -2 + float (y) float64 24B -2.0 -4.0 1.0 + >>> ds.argmin(dim="x") + Size: 56B + Dimensions: (y: 3) + Coordinates: + * y (y) int64 24B -1 0 1 + Data variables: + int int64 8B 4 + float (y) int64 24B 4 0 2 + >>> ds.idxmin(dim="x") + Size: 52B + Dimensions: (y: 3) + Coordinates: + * y (y) int64 24B -1 0 1 + Data variables: + int Self: + """Return the coordinate label of the maximum value along a dimension. + + Returns a new `Dataset` named after the dimension with the values of + the coordinate labels along that dimension corresponding to maximum + values along that dimension. + + In comparison to :py:meth:`~Dataset.argmax`, this returns the + coordinate label while :py:meth:`~Dataset.argmax` returns the index. + + Parameters + ---------- + dim : str, optional + Dimension over which to apply `idxmax`. This is optional for 1D + variables, but required for variables with 2 or more dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for ``float``, ``complex``, and ``object`` + dtypes; other dtypes either do not have a sentinel missing value + (``int``) or ``skipna=True`` has not been implemented + (``datetime64`` or ``timedelta64``). + fill_value : Any, default: NaN + Value to be filled in case all of the values along a dimension are + null. By default this is NaN. The fill value and result are + automatically converted to a compatible dtype if possible. + Ignored if ``skipna`` is False. + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the + original object to the new one. If False, the new object + will be returned without attributes. + + Returns + ------- + reduced : Dataset + New `Dataset` object with `idxmax` applied to its data and the + indicated dimension removed. + + See Also + -------- + DataArray.idxmax, Dataset.idxmin, Dataset.max, Dataset.argmax + + Examples + -------- + >>> array1 = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) + >>> array2 = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, + ... ) + >>> ds = xr.Dataset({"int": array1, "float": array2}) + >>> ds.max(dim="x") + Size: 56B + Dimensions: (y: 3) + Coordinates: + * y (y) int64 24B -1 0 1 + Data variables: + int int64 8B 2 + float (y) float64 24B 2.0 2.0 1.0 + >>> ds.argmax(dim="x") + Size: 56B + Dimensions: (y: 3) + Coordinates: + * y (y) int64 24B -1 0 1 + Data variables: + int int64 8B 1 + float (y) int64 24B 0 2 2 + >>> ds.idxmax(dim="x") + Size: 52B + Dimensions: (y: 3) + Coordinates: + * y (y) int64 24B -1 0 1 + Data variables: + int Self: + """Indices of the minima of the member variables. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : Hashable, optional + The dimension over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmin will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 79], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [39, 96, 78]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # Indices of the minimum values along the 'student' dimension are calculated + + >>> argmin_indices = dataset.argmin(dim="student") + + >>> min_score_in_math = dataset["student"].isel( + ... student=argmin_indices["math_scores"] + ... ) + >>> min_score_in_math + Size: 84B + array(['Bob', 'Bob', 'Alice'], dtype='>> min_score_in_english = dataset["student"].isel( + ... student=argmin_indices["english_scores"] + ... ) + >>> min_score_in_english + Size: 84B + array(['Charlie', 'Bob', 'Charlie'], dtype=' Self: + """Indices of the maxima of the member variables. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmax will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 92], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [93, 96, 91]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # Indices of the maximum values along the 'student' dimension are calculated + + >>> argmax_indices = dataset.argmax(dim="test") + + >>> argmax_indices + Size: 132B + Dimensions: (student: 3) + Coordinates: + * student (student) Self | T_DataArray: + """ + Calculate an expression supplied as a string in the context of the dataset. + + This is currently experimental; the API may change particularly around + assignments, which currently returnn a ``Dataset`` with the additional variable. + Currently only the ``python`` engine is supported, which has the same + performance as executing in python. + + Parameters + ---------- + statement : str + String containing the Python-like expression to evaluate. + + Returns + ------- + result : Dataset or DataArray, depending on whether ``statement`` contains an + assignment. + + Examples + -------- + >>> ds = xr.Dataset( + ... {"a": ("x", np.arange(0, 5, 1)), "b": ("x", np.linspace(0, 1, 5))} + ... ) + >>> ds + Size: 80B + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + a (x) int64 40B 0 1 2 3 4 + b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 + + >>> ds.eval("a + b") + Size: 40B + array([0. , 1.25, 2.5 , 3.75, 5. ]) + Dimensions without coordinates: x + + >>> ds.eval("c = a + b") + Size: 120B + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + a (x) int64 40B 0 1 2 3 4 + b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 + c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 + """ + + return pd.eval( + statement, + resolvers=[self], + target=self, + parser=parser, + # Because numexpr returns a numpy array, using that engine results in + # different behavior. We'd be very open to a contribution handling this. + engine="python", + ) + + def query( + self, + queries: Mapping[Any, Any] | None = None, + parser: QueryParserOptions = "pandas", + engine: QueryEngineOptions = None, + missing_dims: ErrorOptionsWithWarn = "raise", + **queries_kwargs: Any, + ) -> Self: + """Return a new dataset with each array indexed along the specified + dimension(s), where the indexers are given as strings containing + Python expressions to be evaluated against the data variables in the + dataset. + + Parameters + ---------- + queries : dict-like, optional + A dict-like with keys matching dimensions and values given by strings + containing Python expressions to be evaluated against the data variables + in the dataset. The expressions will be evaluated using the pandas + eval() function, and can contain any valid Python expressions but cannot + contain any Python statements. + parser : {"pandas", "python"}, default: "pandas" + The parser to use to construct the syntax tree from the expression. + The default of 'pandas' parses code slightly different than standard + Python. Alternatively, you can parse an expression using the 'python' + parser to retain strict Python semantics. + engine : {"python", "numexpr", None}, default: None + The engine used to evaluate the expression. Supported engines are: + + - None: tries to use numexpr, falls back to python + - "numexpr": evaluates expressions using numexpr + - "python": performs operations as if you had eval’d in top level python + + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + **queries_kwargs : {dim: query, ...}, optional + The keyword arguments form of ``queries``. + One of queries or queries_kwargs must be provided. + + Returns + ------- + obj : Dataset + A new Dataset with the same contents as this dataset, except each + array and dimension is indexed by the results of the appropriate + queries. + + See Also + -------- + Dataset.isel + pandas.eval + + Examples + -------- + >>> a = np.arange(0, 5, 1) + >>> b = np.linspace(0, 1, 5) + >>> ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + >>> ds + Size: 80B + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + a (x) int64 40B 0 1 2 3 4 + b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 + >>> ds.query(x="a > 2") + Size: 32B + Dimensions: (x: 2) + Dimensions without coordinates: x + Data variables: + a (x) int64 16B 3 4 + b (x) float64 16B 0.75 1.0 + """ + + # allow queries to be given either as a dict or as kwargs + queries = either_dict_or_kwargs(queries, queries_kwargs, "query") + + # check queries + for dim, expr in queries.items(): + if not isinstance(expr, str): + msg = f"expr for dim {dim} must be a string to be evaluated, {type(expr)} given" + raise ValueError(msg) + + # evaluate the queries to create the indexers + indexers = { + dim: pd.eval(expr, resolvers=[self], parser=parser, engine=engine) + for dim, expr in queries.items() + } + + # apply the selection + return self.isel(indexers, missing_dims=missing_dims) + + def curvefit( + self, + coords: str | DataArray | Iterable[str | DataArray], + func: Callable[..., Any], + reduce_dims: Dims = None, + skipna: bool = True, + p0: Mapping[str, float | DataArray] | None = None, + bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None, + param_names: Sequence[str] | None = None, + errors: ErrorOptions = "raise", + kwargs: dict[str, Any] | None = None, + ) -> Self: + """ + Curve fitting optimization for arbitrary functions. + + Wraps `scipy.optimize.curve_fit` with `apply_ufunc`. + + Parameters + ---------- + coords : hashable, DataArray, or sequence of hashable or DataArray + Independent coordinate(s) over which to perform the curve fitting. Must share + at least one dimension with the calling object. When fitting multi-dimensional + functions, supply `coords` as a sequence in the same order as arguments in + `func`. To fit along existing dimensions of the calling object, `coords` can + also be specified as a str or sequence of strs. + func : callable + User specified function in the form `f(x, *params)` which returns a numpy + array of length `len(x)`. `params` are the fittable parameters which are optimized + by scipy curve_fit. `x` can also be specified as a sequence containing multiple + coordinates, e.g. `f((x0, x1), *params)`. + reduce_dims : str, Iterable of Hashable or None, optional + Additional dimension(s) over which to aggregate while fitting. For example, + calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will + aggregate all lat and lon points and fit the specified function along the + time dimension. + skipna : bool, default: True + Whether to skip missing values when fitting. Default is True. + p0 : dict-like, optional + Optional dictionary of parameter names to initial guesses passed to the + `curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately + broadcast to the coordinates of the array. If none or only some parameters are + passed, the rest will be assigned initial values following the default scipy + behavior. + bounds : dict-like, optional + Optional dictionary of parameter names to tuples of bounding values passed to the + `curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be + appropriately broadcast to the coordinates of the array. If none or only some + parameters are passed, the rest will be unbounded following the default scipy + behavior. + param_names : sequence of hashable, optional + Sequence of names for the fittable parameters of `func`. If not supplied, + this will be automatically determined by arguments of `func`. `param_names` + should be manually supplied when fitting a function that takes a variable + number of parameters. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. + **kwargs : optional + Additional keyword arguments to passed to scipy curve_fit. + + Returns + ------- + curvefit_results : Dataset + A single dataset which contains: + + [var]_curvefit_coefficients + The coefficients of the best fit. + [var]_curvefit_covariance + The covariance matrix of the coefficient estimates. + + See Also + -------- + Dataset.polyfit + scipy.optimize.curve_fit + """ + from scipy.optimize import curve_fit + + from xarray.core.alignment import broadcast + from xarray.core.computation import apply_ufunc + from xarray.core.dataarray import _THIS_ARRAY, DataArray + + if p0 is None: + p0 = {} + if bounds is None: + bounds = {} + if kwargs is None: + kwargs = {} + + reduce_dims_: list[Hashable] + if not reduce_dims: + reduce_dims_ = [] + elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable): + reduce_dims_ = [reduce_dims] + else: + reduce_dims_ = list(reduce_dims) + + if ( + isinstance(coords, str) + or isinstance(coords, DataArray) + or not isinstance(coords, Iterable) + ): + coords = [coords] + coords_: Sequence[DataArray] = [ + self[coord] if isinstance(coord, str) else coord for coord in coords + ] + + # Determine whether any coords are dims on self + for coord in coords_: + reduce_dims_ += [c for c in self.dims if coord.equals(self[c])] + reduce_dims_ = list(set(reduce_dims_)) + preserved_dims = list(set(self.dims) - set(reduce_dims_)) + if not reduce_dims_: + raise ValueError( + "No arguments to `coords` were identified as a dimension on the calling " + "object, and no dims were supplied to `reduce_dims`. This would result " + "in fitting on scalar data." + ) + + # Check that initial guess and bounds only contain coordinates that are in preserved_dims + for param, guess in p0.items(): + if isinstance(guess, DataArray): + unexpected = set(guess.dims) - set(preserved_dims) + if unexpected: + raise ValueError( + f"Initial guess for '{param}' has unexpected dimensions " + f"{tuple(unexpected)}. It should only have dimensions that are in data " + f"dimensions {preserved_dims}." + ) + for param, (lb, ub) in bounds.items(): + for label, bound in zip(("Lower", "Upper"), (lb, ub)): + if isinstance(bound, DataArray): + unexpected = set(bound.dims) - set(preserved_dims) + if unexpected: + raise ValueError( + f"{label} bound for '{param}' has unexpected dimensions " + f"{tuple(unexpected)}. It should only have dimensions that are in data " + f"dimensions {preserved_dims}." + ) + + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + # Broadcast all coords with each other + coords_ = broadcast(*coords_) + coords_ = [ + coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_ + ] + n_coords = len(coords_) + + params, func_args = _get_func_args(func, param_names) + param_defaults, bounds_defaults = _initialize_curvefit_params( + params, p0, bounds, func_args + ) + n_params = len(params) + + def _wrapper(Y, *args, **kwargs): + # Wrap curve_fit with raveled coordinates and pointwise NaN handling + # *args contains: + # - the coordinates + # - initial guess + # - lower bounds + # - upper bounds + coords__ = args[:n_coords] + p0_ = args[n_coords + 0 * n_params : n_coords + 1 * n_params] + lb = args[n_coords + 1 * n_params : n_coords + 2 * n_params] + ub = args[n_coords + 2 * n_params :] + + x = np.vstack([c.ravel() for c in coords__]) + y = Y.ravel() + if skipna: + mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0) + x = x[:, mask] + y = y[mask] + if not len(y): + popt = np.full([n_params], np.nan) + pcov = np.full([n_params, n_params], np.nan) + return popt, pcov + x = np.squeeze(x) + + try: + popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs) + except RuntimeError: + if errors == "raise": + raise + popt = np.full([n_params], np.nan) + pcov = np.full([n_params, n_params], np.nan) + + return popt, pcov + + result = type(self)() + for name, da in self.data_vars.items(): + if name is _THIS_ARRAY: + name = "" + else: + name = f"{str(name)}_" + + input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)] + input_core_dims.extend( + [[] for _ in range(3 * n_params)] + ) # core_dims for p0 and bounds + + popt, pcov = apply_ufunc( + _wrapper, + da, + *coords_, + *param_defaults.values(), + *[b[0] for b in bounds_defaults.values()], + *[b[1] for b in bounds_defaults.values()], + vectorize=True, + dask="parallelized", + input_core_dims=input_core_dims, + output_core_dims=[["param"], ["cov_i", "cov_j"]], + dask_gufunc_kwargs={ + "output_sizes": { + "param": n_params, + "cov_i": n_params, + "cov_j": n_params, + }, + }, + output_dtypes=(np.float64, np.float64), + exclude_dims=set(reduce_dims_), + kwargs=kwargs, + ) + result[name + "curvefit_coefficients"] = popt + result[name + "curvefit_covariance"] = pcov + + result = result.assign_coords( + {"param": params, "cov_i": params, "cov_j": params} + ) + result.attrs = self.attrs.copy() + + return result + + @_deprecate_positional_args("v2023.10.0") + def drop_duplicates( + self, + dim: Hashable | Iterable[Hashable], + *, + keep: Literal["first", "last", False] = "first", + ) -> Self: + """Returns a new Dataset with duplicate dimension values removed. + + Parameters + ---------- + dim : dimension label or labels + Pass `...` to drop duplicates along all dimensions. + keep : {"first", "last", False}, default: "first" + Determines which duplicates (if any) to keep. + - ``"first"`` : Drop duplicates except for the first occurrence. + - ``"last"`` : Drop duplicates except for the last occurrence. + - False : Drop all duplicates. + + Returns + ------- + Dataset + + See Also + -------- + DataArray.drop_duplicates + """ + if isinstance(dim, str): + dims: Iterable = (dim,) + elif dim is ...: + dims = self.dims + elif not isinstance(dim, Iterable): + dims = [dim] + else: + dims = dim + + missing_dims = set(dims) - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" + ) + + indexes = {dim: ~self.get_index(dim).duplicated(keep=keep) for dim in dims} + return self.isel(indexes) + + def convert_calendar( + self, + calendar: CFCalendar, + dim: Hashable = "time", + align_on: Literal["date", "year", None] = None, + missing: Any | None = None, + use_cftime: bool | None = None, + ) -> Self: + """Convert the Dataset to another calendar. + + Only converts the individual timestamps, does not modify any data except + in dropping invalid/surplus dates or inserting missing dates. + + If the source and target calendars are either no_leap, all_leap or a + standard type, only the type of the time array is modified. + When converting to a leap year from a non-leap year, the 29th of February + is removed from the array. In the other direction the 29th of February + will be missing in the output, unless `missing` is specified, + in which case that value is inserted. + + For conversions involving `360_day` calendars, see Notes. + + This method is safe to use with sub-daily data as it doesn't touch the + time part of the timestamps. + + Parameters + --------- + calendar : str + The target calendar name. + dim : Hashable, default: "time" + Name of the time coordinate. + align_on : {None, 'date', 'year'}, optional + Must be specified when either source or target is a `360_day` calendar, + ignored otherwise. See Notes. + missing : Any or None, optional + By default, i.e. if the value is None, this method will simply attempt + to convert the dates in the source calendar to the same dates in the + target calendar, and drop any of those that are not possible to + represent. If a value is provided, a new time coordinate will be + created in the target calendar with the same frequency as the original + time coordinate; for any dates that are not present in the source, the + data will be filled with this value. Note that using this mode requires + that the source data have an inferable frequency; for more information + see :py:func:`xarray.infer_freq`. For certain frequency, source, and + target calendar combinations, this could result in many missing values, see notes. + use_cftime : bool or None, optional + Whether to use cftime objects in the output, only used if `calendar` + is one of {"proleptic_gregorian", "gregorian" or "standard"}. + If True, the new time axis uses cftime objects. + If None (default), it uses :py:class:`numpy.datetime64` values if the + date range permits it, and :py:class:`cftime.datetime` objects if not. + If False, it uses :py:class:`numpy.datetime64` or fails. + + Returns + ------- + Dataset + Copy of the dataarray with the time coordinate converted to the + target calendar. If 'missing' was None (default), invalid dates in + the new calendar are dropped, but missing dates are not inserted. + If `missing` was given, the new data is reindexed to have a time axis + with the same frequency as the source, but in the new calendar; any + missing datapoints are filled with `missing`. + + Notes + ----- + Passing a value to `missing` is only usable if the source's time coordinate as an + inferable frequencies (see :py:func:`~xarray.infer_freq`) and is only appropriate + if the target coordinate, generated from this frequency, has dates equivalent to the + source. It is usually **not** appropriate to use this mode with: + + - Period-end frequencies : 'A', 'Y', 'Q' or 'M', in opposition to 'AS' 'YS', 'QS' and 'MS' + - Sub-monthly frequencies that do not divide a day evenly : 'W', 'nD' where `N != 1` + or 'mH' where 24 % m != 0). + + If one of the source or target calendars is `"360_day"`, `align_on` must + be specified and two options are offered. + + - "year" + The dates are translated according to their relative position in the year, + ignoring their original month and day information, meaning that the + missing/surplus days are added/removed at regular intervals. + + From a `360_day` to a standard calendar, the output will be missing the + following dates (day of year in parentheses): + + To a leap year: + January 31st (31), March 31st (91), June 1st (153), July 31st (213), + September 31st (275) and November 30th (335). + To a non-leap year: + February 6th (36), April 19th (109), July 2nd (183), + September 12th (255), November 25th (329). + + From a standard calendar to a `"360_day"`, the following dates in the + source array will be dropped: + + From a leap year: + January 31st (31), April 1st (92), June 1st (153), August 1st (214), + September 31st (275), December 1st (336) + From a non-leap year: + February 6th (37), April 20th (110), July 2nd (183), + September 13th (256), November 25th (329) + + This option is best used on daily and subdaily data. + + - "date" + The month/day information is conserved and invalid dates are dropped + from the output. This means that when converting from a `"360_day"` to a + standard calendar, all 31st (Jan, March, May, July, August, October and + December) will be missing as there is no equivalent dates in the + `"360_day"` calendar and the 29th (on non-leap years) and 30th of February + will be dropped as there are no equivalent dates in a standard calendar. + + This option is best used with data on a frequency coarser than daily. + """ + return convert_calendar( + self, + calendar, + dim=dim, + align_on=align_on, + missing=missing, + use_cftime=use_cftime, + ) + + def interp_calendar( + self, + target: pd.DatetimeIndex | CFTimeIndex | DataArray, + dim: Hashable = "time", + ) -> Self: + """Interpolates the Dataset to another calendar based on decimal year measure. + + Each timestamp in `source` and `target` are first converted to their decimal + year equivalent then `source` is interpolated on the target coordinate. + The decimal year of a timestamp is its year plus its sub-year component + converted to the fraction of its year. For example "2000-03-01 12:00" is + 2000.1653 in a standard calendar or 2000.16301 in a `"noleap"` calendar. + + This method should only be used when the time (HH:MM:SS) information of + time coordinate is not important. + + Parameters + ---------- + target: DataArray or DatetimeIndex or CFTimeIndex + The target time coordinate of a valid dtype + (np.datetime64 or cftime objects) + dim : Hashable, default: "time" + The time coordinate name. + + Return + ------ + DataArray + The source interpolated on the decimal years of target, + """ + return interp_calendar(self, target, dim=dim) + + def groupby( + self, + group: Hashable | DataArray | IndexVariable, + squeeze: bool | None = None, + restore_coord_dims: bool = False, + ) -> DatasetGroupBy: + """Returns a DatasetGroupBy object for performing grouped operations. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DatasetGroupBy + A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + + See Also + -------- + :ref:`groupby` + Users guide explanation of how to group and bin data. + + :doc:`xarray-tutorial:intermediate/01-high-level-computation-patterns` + Tutorial on :py:func:`~xarray.Dataset.Groupby` for windowed computation. + + :doc:`xarray-tutorial:fundamentals/03.2_groupby_with_xarray` + Tutorial on :py:func:`~xarray.Dataset.Groupby` demonstrating reductions, transformation and comparison with :py:func:`~xarray.Dataset.resample`. + + Dataset.groupby_bins + DataArray.groupby + core.groupby.DatasetGroupBy + pandas.DataFrame.groupby + Dataset.coarsen + Dataset.resample + DataArray.resample + """ + from xarray.core.groupby import ( + DatasetGroupBy, + ResolvedGrouper, + UniqueGrouper, + _validate_groupby_squeeze, + ) + + _validate_groupby_squeeze(squeeze) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) + + return DatasetGroupBy( + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def groupby_bins( + self, + group: Hashable | DataArray | IndexVariable, + bins: ArrayLike, + right: bool = True, + labels: ArrayLike | None = None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool | None = None, + restore_coord_dims: bool = False, + ) -> DatasetGroupBy: + """Returns a DatasetGroupBy object for performing grouped operations. + + Rather than using all unique values of `group`, the values are discretized + first by applying `pandas.cut` [1]_ to `group`. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose binned values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + bins : int or array-like + If bins is an int, it defines the number of equal-width bins in the + range of x. However, in this case, the range of x is extended by .1% + on each side to include the min or max values of x. If bins is a + sequence it defines the bin edges allowing for non-uniform bin + width. No extension of the range of x is done in this case. + right : bool, default: True + Indicates whether the bins include the rightmost edge or not. If + right == True (the default), then the bins [1,2,3,4] indicate + (1,2], (2,3], (3,4]. + labels : array-like or bool, default: None + Used as labels for the resulting bins. Must be of the same length as + the resulting bins. If False, string bin labels are assigned by + `pandas.cut`. + precision : int, default: 3 + The precision at which to store and display the bins labels. + include_lowest : bool, default: False + Whether the first interval should be left-inclusive or not. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DatasetGroupBy + A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + The name of the group has the added suffix `_bins` in order to + distinguish it from the original variable. + + See Also + -------- + :ref:`groupby` + Users guide explanation of how to group and bin data. + Dataset.groupby + DataArray.groupby_bins + core.groupby.DatasetGroupBy + pandas.DataFrame.groupby + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html + """ + from xarray.core.groupby import ( + BinGrouper, + DatasetGroupBy, + ResolvedGrouper, + _validate_groupby_squeeze, + ) + + _validate_groupby_squeeze(squeeze) + grouper = BinGrouper( + bins=bins, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + rgrouper = ResolvedGrouper(grouper, group, self) + + return DatasetGroupBy( + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def weighted(self, weights: DataArray) -> DatasetWeighted: + """ + Weighted Dataset operations. + + Parameters + ---------- + weights : DataArray + An array of weights associated with the values in this Dataset. + Each value in the data contributes to the reduction operation + according to its associated weight. + + Notes + ----- + ``weights`` must be a DataArray and cannot contain missing values. + Missing values can be replaced by ``weights.fillna(0)``. + + Returns + ------- + core.weighted.DatasetWeighted + + See Also + -------- + DataArray.weighted + + :ref:`comput.weighted` + User guide on weighted array reduction using :py:func:`~xarray.Dataset.weighted` + + :doc:`xarray-tutorial:fundamentals/03.4_weighted` + Tutorial on Weighted Reduction using :py:func:`~xarray.Dataset.weighted` + + """ + from xarray.core.weighted import DatasetWeighted + + return DatasetWeighted(self, weights) + + def rolling( + self, + dim: Mapping[Any, int] | None = None, + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + **window_kwargs: int, + ) -> DatasetRolling: + """ + Rolling window object for Datasets. + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int or None, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or Mapping to int, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + core.rolling.DatasetRolling + + See Also + -------- + Dataset.cumulative + DataArray.rolling + core.rolling.DatasetRolling + """ + from xarray.core.rolling import DatasetRolling + + dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") + return DatasetRolling(self, dim, min_periods=min_periods, center=center) + + def cumulative( + self, + dim: str | Iterable[Hashable], + min_periods: int = 1, + ) -> DatasetRolling: + """ + Accumulating object for Datasets + + Parameters + ---------- + dims : iterable of hashable + The name(s) of the dimensions to create the cumulative window along + min_periods : int, default: 1 + Minimum number of observations in window required to have a value + (otherwise result is NA). The default is 1 (note this is different + from ``Rolling``, whose default is the size of the window). + + Returns + ------- + core.rolling.DatasetRolling + + See Also + -------- + Dataset.rolling + DataArray.cumulative + core.rolling.DatasetRolling + """ + from xarray.core.rolling import DatasetRolling + + if isinstance(dim, str): + if dim not in self.dims: + raise ValueError( + f"Dimension {dim} not found in data dimensions: {self.dims}" + ) + dim = {dim: self.sizes[dim]} + else: + missing_dims = set(dim) - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {missing_dims} not found in data dimensions: {self.dims}" + ) + dim = {d: self.sizes[d] for d in dim} + + return DatasetRolling(self, dim, min_periods=min_periods, center=False) + + def coarsen( + self, + dim: Mapping[Any, int] | None = None, + boundary: CoarsenBoundaryOptions = "exact", + side: SideOptions | Mapping[Any, SideOptions] = "left", + coord_func: str | Callable | Mapping[Any, str | Callable] = "mean", + **window_kwargs: int, + ) -> DatasetCoarsen: + """ + Coarsen object for Datasets. + + Parameters + ---------- + dim : mapping of hashable to int, optional + Mapping from the dimension name to the window size. + boundary : {"exact", "trim", "pad"}, default: "exact" + If 'exact', a ValueError will be raised if dimension size is not a + multiple of the window size. If 'trim', the excess entries are + dropped. If 'pad', NA will be padded. + side : {"left", "right"} or mapping of str to {"left", "right"}, default: "left" + coord_func : str or mapping of hashable to str, default: "mean" + function (name) that is applied to the coordinates, + or a mapping from coordinate name to function (name). + + Returns + ------- + core.rolling.DatasetCoarsen + + See Also + -------- + core.rolling.DatasetCoarsen + DataArray.coarsen + + :ref:`reshape.coarsen` + User guide describing :py:func:`~xarray.Dataset.coarsen` + + :ref:`compute.coarsen` + User guide on block arrgragation :py:func:`~xarray.Dataset.coarsen` + + :doc:`xarray-tutorial:fundamentals/03.3_windowed` + Tutorial on windowed computation using :py:func:`~xarray.Dataset.coarsen` + + """ + from xarray.core.rolling import DatasetCoarsen + + dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") + return DatasetCoarsen( + self, + dim, + boundary=boundary, + side=side, + coord_func=coord_func, + ) + + def resample( + self, + indexer: Mapping[Any, str] | None = None, + skipna: bool | None = None, + closed: SideOptions | None = None, + label: SideOptions | None = None, + base: int | None = None, + offset: pd.Timedelta | datetime.timedelta | str | None = None, + origin: str | DatetimeLike = "start_day", + loffset: datetime.timedelta | str | None = None, + restore_coord_dims: bool | None = None, + **indexer_kwargs: str, + ) -> DatasetResample: + """Returns a Resample object for performing resampling operations. + + Handles both downsampling and upsampling. The resampled + dimension must be a datetime-like coordinate. If any intervals + contain no values from the original object, they will be given + the value ``NaN``. + + Parameters + ---------- + indexer : Mapping of Hashable to str, optional + Mapping from the dimension name to resample frequency [1]_. The + dimension must be datetime-like. + skipna : bool, optional + Whether to skip missing values when aggregating in downsampling. + closed : {"left", "right"}, optional + Side of each interval to treat as closed. + label : {"left", "right"}, optional + Side of each interval to use for labeling. + base : int, optional + For frequencies that evenly subdivide 1 day, the "origin" of the + aggregated intervals. For example, for "24H" frequency, base could + range from 0 through 23. + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'}, pd.Timestamp, datetime.datetime, np.datetime64, or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : pd.Timedelta, datetime.timedelta, or str, default is None + An offset timedelta added to the origin. + loffset : timedelta or str, optional + Offset used to adjust the resampled time labels. Some pandas date + offset strings are supported. + + .. deprecated:: 2023.03.0 + Following pandas, the ``loffset`` parameter is deprecated in favor + of using time offset arithmetic, and will be removed in a future + version of xarray. + + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + **indexer_kwargs : str + The keyword arguments form of ``indexer``. + One of indexer or indexer_kwargs must be provided. + + Returns + ------- + resampled : core.resample.DataArrayResample + This object resampled. + + See Also + -------- + DataArray.resample + pandas.Series.resample + pandas.DataFrame.resample + Dataset.groupby + DataArray.groupby + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases + """ + from xarray.core.resample import DatasetResample + + return self._resample( + resample_cls=DatasetResample, + indexer=indexer, + skipna=skipna, + closed=closed, + label=label, + base=base, + offset=offset, + origin=origin, + loffset=loffset, + restore_coord_dims=restore_coord_dims, + **indexer_kwargs, + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/datatree.py b/test/fixtures/whole_applications/xarray/xarray/core/datatree.py new file mode 100644 index 0000000..4e4d308 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/datatree.py @@ -0,0 +1,1600 @@ +from __future__ import annotations + +import copy +import itertools +from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping +from html import escape +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NoReturn, + Union, + overload, +) + +from xarray.core import utils +from xarray.core.common import TreeAttrAccessMixin +from xarray.core.coordinates import DatasetCoordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset, DataVariables +from xarray.core.datatree_mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) +from xarray.core.datatree_ops import ( + DataTreeArithmeticMixin, + MappedDatasetMethodsMixin, + MappedDataWithCoords, +) +from xarray.core.datatree_render import RenderDataTree +from xarray.core.formatting import datatree_repr +from xarray.core.formatting_html import ( + datatree_repr as datatree_repr_html, +) +from xarray.core.indexes import Index, Indexes +from xarray.core.merge import dataset_update_method +from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.treenode import NamedNode, NodePath, Tree +from xarray.core.utils import ( + Default, + Frozen, + HybridMappingProxy, + _default, + either_dict_or_kwargs, + maybe_wrap_array, +) +from xarray.core.variable import Variable + +try: + from xarray.core.variable import calculate_dimensions +except ImportError: + # for xarray versions 2022.03.0 and earlier + from xarray.core.dataset import calculate_dimensions + +if TYPE_CHECKING: + import pandas as pd + + from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes + from xarray.core.merge import CoercibleValue + from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes + +# """ +# DEVELOPERS' NOTE +# ---------------- +# The idea of this module is to create a `DataTree` class which inherits the tree structure from TreeNode, and also copies +# the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every +# node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin +# classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API. +# +# Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered +# (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new +# tree) and some will get overridden by the class definition of DataTree. +# """ + + +T_Path = Union[str, NodePath] + + +def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: + if isinstance(data, DataArray): + ds = data.to_dataset() + elif isinstance(data, Dataset): + ds = data + elif data is None: + ds = Dataset() + else: + raise TypeError( + f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" + ) + return ds + + +def _check_for_name_collisions( + children: Iterable[str], variables: Iterable[Hashable] +) -> None: + colliding_names = set(children).intersection(set(variables)) + if colliding_names: + raise KeyError( + f"Some names would collide between variables and children: {list(colliding_names)}" + ) + + +class DatasetView(Dataset): + """ + An immutable Dataset-like view onto the data in a single DataTree node. + + In-place operations modifying this object should raise an AttributeError. + This requires overriding all inherited constructors. + + Operations returning a new result will return a new xarray.Dataset object. + This includes all API on Dataset, which will be inherited. + """ + + # TODO what happens if user alters (in-place) a DataArray they extracted from this object? + + __slots__ = ( + "_attrs", + "_cache", + "_coord_names", + "_dims", + "_encoding", + "_close", + "_indexes", + "_variables", + ) + + def __init__( + self, + data_vars: Mapping[Any, Any] | None = None, + coords: Mapping[Any, Any] | None = None, + attrs: Mapping[Any, Any] | None = None, + ): + raise AttributeError("DatasetView objects are not to be initialized directly") + + @classmethod + def _from_node( + cls, + wrapping_node: DataTree, + ) -> DatasetView: + """Constructor, using dataset attributes from wrapping node""" + + obj: DatasetView = object.__new__(cls) + obj._variables = wrapping_node._variables + obj._coord_names = wrapping_node._coord_names + obj._dims = wrapping_node._dims + obj._indexes = wrapping_node._indexes + obj._attrs = wrapping_node._attrs + obj._close = wrapping_node._close + obj._encoding = wrapping_node._encoding + + return obj + + def __setitem__(self, key, val) -> None: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, " + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "use `.copy()` first to get a mutable version of the input dataset." + ) + + def update(self, other) -> NoReturn: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, " + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "use `.copy()` first to get a mutable version of the input dataset." + ) + + # FIXME https://github.com/python/mypy/issues/7328 + @overload # type: ignore[override] + def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[overload-overlap] + ... + + @overload + def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[overload-overlap] + ... + + # See: https://github.com/pydata/xarray/issues/8855 + @overload + def __getitem__(self, key: Any) -> Dataset: ... + + def __getitem__(self, key) -> DataArray | Dataset: + # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes + # For now just call Dataset.__getitem__ + return Dataset.__getitem__(self, key) + + @classmethod + def _construct_direct( # type: ignore[override] + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, + close: Callable[[], None] | None = None, + ) -> Dataset: + """ + Overriding this method (along with ._replace) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + obj = object.__new__(Dataset) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + return obj + + def _replace( # type: ignore[override] + self, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] | None = None, + encoding: dict | None | Default = _default, + inplace: bool = False, + ) -> Dataset: + """ + Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + + if inplace: + raise AttributeError("In-place mutation of the DatasetView is not allowed") + + return Dataset._replace( + self, + variables=variables, + coord_names=coord_names, + dims=dims, + attrs=attrs, + indexes=indexes, + encoding=encoding, + inplace=inplace, + ) + + def map( # type: ignore[override] + self, + func: Callable, + keep_attrs: bool | None = None, + args: Iterable[Any] = (), + **kwargs: Any, + ) -> Dataset: + """Apply a function to each data variable in this dataset + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, *args, **kwargs)` + to transform each DataArray `x` in this dataset into another + DataArray. + keep_attrs : bool | None, optional + If True, both the dataset's and variables' attributes (`attrs`) will be + copied from the original objects to the new ones. If False, the new dataset + and variables will be returned without copying the attributes. + args : iterable, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + + Returns + ------- + applied : Dataset + Resulting dataset from applying ``func`` to each data variable. + + Examples + -------- + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])}) + >>> ds + Size: 64B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Dimensions without coordinates: dim_0, dim_1, x + Data variables: + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 -0.9773 + bar (x) int64 16B -1 2 + >>> ds.map(np.fabs) + Size: 64B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Dimensions without coordinates: dim_0, dim_1, x + Data variables: + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773 + bar (x) float64 16B 1.0 2.0 + """ + + # Copied from xarray.Dataset so as not to call type(self), which causes problems (see https://github.com/xarray-contrib/datatree/issues/188). + # TODO Refactor xarray upstream to avoid needing to overwrite this. + # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated + variables = { + k: maybe_wrap_array(v, func(v, *args, **kwargs)) + for k, v in self.data_vars.items() + } + # return type(self)(variables, attrs=attrs) + return Dataset(variables) + + +class DataTree( + NamedNode, + MappedDatasetMethodsMixin, + MappedDataWithCoords, + DataTreeArithmeticMixin, + TreeAttrAccessMixin, + Generic[Tree], + Mapping, +): + """ + A tree-like hierarchical collection of xarray objects. + + Attempts to present an API like that of xarray.Dataset, but methods are wrapped to also update all the tree's child nodes. + """ + + # TODO Some way of sorting children by depth + + # TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes? + + # TODO dataset methods which should not or cannot act over the whole tree, such as .to_array + + # TODO .loc method + + # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from + + # TODO all groupby classes + + # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from + + # TODO all groupby classes + + _name: str | None + _parent: DataTree | None + _children: dict[str, DataTree] + _attrs: dict[Hashable, Any] | None + _cache: dict[str, Any] + _coord_names: set[Hashable] + _dims: dict[Hashable, int] + _encoding: dict[Hashable, Any] | None + _close: Callable[[], None] | None + _indexes: dict[Hashable, Index] + _variables: dict[Hashable, Variable] + + __slots__ = ( + "_name", + "_parent", + "_children", + "_attrs", + "_cache", + "_coord_names", + "_dims", + "_encoding", + "_close", + "_indexes", + "_variables", + ) + + def __init__( + self, + data: Dataset | DataArray | None = None, + parent: DataTree | None = None, + children: Mapping[str, DataTree] | None = None, + name: str | None = None, + ): + """ + Create a single node of a DataTree. + + The node may optionally contain data in the form of data and coordinate variables, stored in the same way as + data is stored in an xarray.Dataset. + + Parameters + ---------- + data : Dataset, DataArray, or None, optional + Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. + Default is None. + parent : DataTree, optional + Parent node to this node. Default is None. + children : Mapping[str, DataTree], optional + Any child nodes of this node. Default is None. + name : str, optional + Name for this node of the tree. Default is None. + + Returns + ------- + DataTree + + See Also + -------- + DataTree.from_dict + """ + + # validate input + if children is None: + children = {} + ds = _coerce_to_dataset(data) + _check_for_name_collisions(children, ds.variables) + + super().__init__(name=name) + + # set data attributes + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close + + # set tree attributes (must happen after variables set to avoid initialization errors) + self.children = children + self.parent = parent + + @property + def parent(self: DataTree) -> DataTree | None: + """Parent of this node.""" + return self._parent + + @parent.setter + def parent(self: DataTree, new_parent: DataTree) -> None: + if new_parent and self.name is None: + raise ValueError("Cannot set an unnamed node as a child of another node") + self._set_parent(new_parent, self.name) + + @property + def ds(self) -> DatasetView: + """ + An immutable Dataset-like view onto the data in this node. + + For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead. + + See Also + -------- + DataTree.to_dataset + """ + return DatasetView._from_node(self) + + @ds.setter + def ds(self, data: Dataset | DataArray | None = None) -> None: + # Known mypy issue for setters with different type to property: + # https://github.com/python/mypy/issues/3004 + ds = _coerce_to_dataset(data) + + _check_for_name_collisions(self.children, ds.variables) + + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close + + def _pre_attach(self: DataTree, parent: DataTree) -> None: + """ + Method which superclass calls before setting parent, here used to prevent having two + children with duplicate names (or a data variable with the same name as a child). + """ + super()._pre_attach(parent) + if self.name in list(parent.ds.variables): + raise KeyError( + f"parent {parent.name} already contains a data variable named {self.name}" + ) + + def to_dataset(self) -> Dataset: + """ + Return the data in this node as a new xarray.Dataset object. + + See Also + -------- + DataTree.ds + """ + return Dataset._construct_direct( + self._variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) + + @property + def has_data(self): + """Whether or not there are any data variables in this node.""" + return len(self._variables) > 0 + + @property + def has_attrs(self) -> bool: + """Whether or not there are any metadata attributes in this node.""" + return len(self.attrs.keys()) > 0 + + @property + def is_empty(self) -> bool: + """False if node contains any data or attrs. Does not look at children.""" + return not (self.has_data or self.has_attrs) + + @property + def is_hollow(self) -> bool: + """True if only leaf nodes contain data.""" + return not any(node.has_data for node in self.subtree if not node.is_leaf) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to node contents as dict of Variable objects. + + This dictionary is frozen to prevent mutation that could violate + Dataset invariants. It contains all variable objects constituting this + DataTree node, including both data variables and coordinates. + """ + return Frozen(self._variables) + + @property + def attrs(self) -> dict[Hashable, Any]: + """Dictionary of global attributes on this node object.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + + @property + def encoding(self) -> dict: + """Dictionary of global encoding attributes on this node object.""" + if self._encoding is None: + self._encoding = {} + return self._encoding + + @encoding.setter + def encoding(self, value: Mapping) -> None: + self._encoding = dict(value) + + @property + def dims(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + Note that type of this object differs from `DataArray.dims`. + See `DataTree.sizes`, `Dataset.sizes`, and `DataArray.sizes` for consistently named + properties. + """ + return Frozen(self._dims) + + @property + def sizes(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + This is an alias for `DataTree.dims` provided for the benefit of + consistency with `DataArray.sizes`. + + See Also + -------- + DataArray.sizes + """ + return self.dims + + @property + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from self._item_sources + yield self.attrs + + @property + def _item_sources(self) -> Iterable[Mapping[Any, Any]]: + """Places to look-up items for key-completion""" + yield self.data_vars + yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + + # virtual coordinates + yield HybridMappingProxy(keys=self.dims, mapping=self) + + # immediate child nodes + yield self.children + + def _ipython_key_completions_(self) -> list[str]: + """Provide method for the key-autocompletions in IPython. + See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion + For the details. + """ + + # TODO allow auto-completing relative string paths, e.g. `dt['path/to/../ node'` + # Would require changes to ipython's autocompleter, see https://github.com/ipython/ipython/issues/12420 + # Instead for now we only list direct paths to all node in subtree explicitly + + items_on_this_node = self._item_sources + full_file_like_paths_to_all_nodes_in_subtree = { + node.path[1:]: node for node in self.subtree + } + + all_item_sources = itertools.chain( + items_on_this_node, [full_file_like_paths_to_all_nodes_in_subtree] + ) + + items = { + item + for source in all_item_sources + for item in source + if isinstance(item, str) + } + return list(items) + + def __contains__(self, key: object) -> bool: + """The 'in' operator will return true or false depending on whether + 'key' is either an array stored in the datatree or a child node, or neither. + """ + return key in self.variables or key in self.children + + def __bool__(self) -> bool: + return bool(self.ds.data_vars) or bool(self.children) + + def __iter__(self) -> Iterator[Hashable]: + return itertools.chain(self.ds.data_vars, self.children) + + def __array__(self, dtype=None, copy=None): + raise TypeError( + "cannot directly convert a DataTree into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the DataTree or by " + "invoking the `to_array()` method." + ) + + def __repr__(self) -> str: # type: ignore[override] + return datatree_repr(self) + + def __str__(self) -> str: + return datatree_repr(self) + + def _repr_html_(self): + """Make html representation of datatree object""" + if XR_OPTS["display_style"] == "text": + return f"
{escape(repr(self))}
" + return datatree_repr_html(self) + + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, + name: str | None = None, + parent: DataTree | None = None, + children: dict[str, DataTree] | None = None, + close: Callable[[], None] | None = None, + ) -> DataTree: + """Shortcut around __init__ for internal use when we want to skip costly validation.""" + + # data attributes + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + if children is None: + children = dict() + + obj: DataTree = object.__new__(cls) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + + # tree attributes + obj._name = name + obj._children = children + obj._parent = parent + + return obj + + def _replace( + self: DataTree, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] | None = None, + encoding: dict | None | Default = _default, + name: str | None | Default = _default, + parent: DataTree | None | Default = _default, + children: dict[str, DataTree] | None = None, + inplace: bool = False, + ) -> DataTree: + """ + Fastpath constructor for internal use. + + Returns an object with optionally replaced attributes. + + Explicitly passed arguments are *not* copied when placed on the new + datatree. It is up to the caller to ensure that they have the right type + and are not used elsewhere. + """ + # TODO Adding new children inplace using this method will cause bugs. + # You will end up with an inconsistency between the name of the child node and the key the child is stored under. + # Use ._set() instead for now + if inplace: + if variables is not None: + self._variables = variables + if coord_names is not None: + self._coord_names = coord_names + if dims is not None: + self._dims = dims + if attrs is not _default: + self._attrs = attrs + if indexes is not None: + self._indexes = indexes + if encoding is not _default: + self._encoding = encoding + if name is not _default: + self._name = name + if parent is not _default: + self._parent = parent + if children is not None: + self._children = children + obj = self + else: + if variables is None: + variables = self._variables.copy() + if coord_names is None: + coord_names = self._coord_names.copy() + if dims is None: + dims = self._dims.copy() + if attrs is _default: + attrs = copy.copy(self._attrs) + if indexes is None: + indexes = self._indexes.copy() + if encoding is _default: + encoding = copy.copy(self._encoding) + if name is _default: + name = self._name # no need to copy str objects or None + if parent is _default: + parent = copy.copy(self._parent) + if children is _default: + children = copy.copy(self._children) + obj = self._construct_direct( + variables, + coord_names, + dims, + attrs, + indexes, + encoding, + name, + parent, + children, + ) + return obj + + def copy( + self: DataTree, + deep: bool = False, + ) -> DataTree: + """ + Returns a copy of this subtree. + + Copies this node and all child nodes. + + If `deep=True`, a deep copy is made of each of the component variables. + Otherwise, a shallow copy of each of the component variable is made, so + that the underlying memory region of the new datatree is the same as in + the original datatree. + + Parameters + ---------- + deep : bool, default: False + Whether each component variable is loaded into memory and copied onto + the new object. Default is False. + + Returns + ------- + object : DataTree + New object with dimensions, attributes, coordinates, name, encoding, + and data of this node and all child nodes copied from original. + + See Also + -------- + xarray.Dataset.copy + pandas.DataFrame.copy + """ + return self._copy_subtree(deep=deep) + + def _copy_subtree( + self: DataTree, + deep: bool = False, + memo: dict[int, Any] | None = None, + ) -> DataTree: + """Copy entire subtree""" + new_tree = self._copy_node(deep=deep) + for node in self.descendants: + path = node.relative_to(self) + new_tree[path] = node._copy_node(deep=deep) + return new_tree + + def _copy_node( + self: DataTree, + deep: bool = False, + ) -> DataTree: + """Copy just one node of a tree""" + new_node: DataTree = DataTree() + new_node.name = self.name + new_node.ds = self.to_dataset().copy(deep=deep) # type: ignore[assignment] + return new_node + + def __copy__(self: DataTree) -> DataTree: + return self._copy_subtree(deep=False) + + def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree: + return self._copy_subtree(deep=True, memo=memo) + + def get( # type: ignore[override] + self: DataTree, key: str, default: DataTree | DataArray | None = None + ) -> DataTree | DataArray | None: + """ + Access child nodes, variables, or coordinates stored in this node. + + Returned object will be either a DataTree or DataArray object depending on whether the key given points to a + child or variable. + + Parameters + ---------- + key : str + Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree). + default : DataTree | DataArray | None, optional + A value to return if the specified key does not exist. Default return value is None. + """ + if key in self.children: + return self.children[key] + elif key in self.ds: + return self.ds[key] + else: + return default + + def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: + """ + Access child nodes, variables, or coordinates stored anywhere in this tree. + + Returned object will be either a DataTree or DataArray object depending on whether the key given points to a + child or variable. + + Parameters + ---------- + key : str + Name of variable / child within this node, or unix-like path to variable / child within another node. + + Returns + ------- + DataTree | DataArray + """ + + # Either: + if utils.is_dict_like(key): + # dict-like indexing + raise NotImplementedError("Should this index over whole tree?") + elif isinstance(key, str): + # TODO should possibly deal with hashables in general? + # path-like: a name of a node/variable, or path to a node/variable + path = NodePath(key) + return self._get_item(path) + elif utils.is_list_like(key): + # iterable of variable names + raise NotImplementedError( + "Selecting via tags is deprecated, and selecting multiple items should be " + "implemented via .subset" + ) + else: + raise ValueError(f"Invalid format for key: {key}") + + def _set(self, key: str, val: DataTree | CoercibleValue) -> None: + """ + Set the child node or variable with the specified key to value. + + Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree. + """ + if isinstance(val, DataTree): + # create and assign a shallow copy here so as not to alter original name of node in grafted tree + new_node = val.copy(deep=False) + new_node.name = key + new_node.parent = self + else: + if not isinstance(val, (DataArray, Variable)): + # accommodate other types that can be coerced into Variables + val = DataArray(val) + + self.update({key: val}) + + def __setitem__( + self, + key: str, + value: Any, + ) -> None: + """ + Add either a child node or an array to the tree, at any position. + + Data can be added anywhere, and new nodes will be created to cross the path to the new location if necessary. + + If there is already a node at the given location, then if value is a Node class or Dataset it will overwrite the + data already present at that node, and if value is a single array, it will be merged with it. + """ + # TODO xarray.Dataset accepts other possibilities, how do we exactly replicate all the behaviour? + if utils.is_dict_like(key): + raise NotImplementedError + elif isinstance(key, str): + # TODO should possibly deal with hashables in general? + # path-like: a name of a node/variable, or path to a node/variable + path = NodePath(key) + return self._set_item(path, value, new_nodes_along_path=True) + else: + raise ValueError("Invalid format for key") + + @overload + def update(self, other: Dataset) -> None: ... + + @overload + def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ... + + @overload + def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ... + + def update( + self, + other: ( + Dataset + | Mapping[Hashable, DataArray | Variable] + | Mapping[str, DataTree | DataArray | Variable] + ), + ) -> None: + """ + Update this node's children and / or variables. + + Just like `dict.update` this is an in-place operation. + """ + # TODO separate by type + new_children: dict[str, DataTree] = {} + new_variables = {} + for k, v in other.items(): + if isinstance(v, DataTree): + # avoid named node being stored under inconsistent key + new_child: DataTree = v.copy() + # Datatree's name is always a string until we fix that (#8836) + new_child.name = str(k) + new_children[str(k)] = new_child + elif isinstance(v, (DataArray, Variable)): + # TODO this should also accommodate other types that can be coerced into Variables + new_variables[k] = v + else: + raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") + + vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + # TODO are there any subtleties with preserving order of children like this? + merged_children = {**self.children, **new_children} + self._replace( + inplace=True, children=merged_children, **vars_merge_result._asdict() + ) + + def assign( + self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any + ) -> DataTree: + """ + Assign new data variables or child nodes to a DataTree, returning a new object + with all the original items in addition to the new ones. + + Parameters + ---------- + items : mapping of hashable to Any + Mapping from variable or child node names to the new values. If the new values + are callable, they are computed on the Dataset and assigned to new + data variables. If the values are not callable, (e.g. a DataTree, DataArray, + scalar, or array), they are simply assigned. + **items_kwargs + The keyword arguments form of ``variables``. + One of variables or variables_kwargs must be provided. + + Returns + ------- + dt : DataTree + A new DataTree with the new variables or children in addition to all the + existing items. + + Notes + ----- + Since ``kwargs`` is a dictionary, the order of your arguments may not + be preserved, and so the order of the new variables is not well-defined. + Assigning multiple items within the same ``assign`` is + possible, but you cannot reference other variables created within the + same ``assign`` call. + + See Also + -------- + xarray.Dataset.assign + pandas.DataFrame.assign + """ + items = either_dict_or_kwargs(items, items_kwargs, "assign") + dt = self.copy() + dt.update(items) + return dt + + def drop_nodes( + self: DataTree, names: str | Iterable[str], *, errors: ErrorOptions = "raise" + ) -> DataTree: + """ + Drop child nodes from this node. + + Parameters + ---------- + names : str or iterable of str + Name(s) of nodes to drop. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a KeyError if any of the node names + passed are not present as children of this node. If 'ignore', + any given names that are present are dropped and no error is raised. + + Returns + ------- + dropped : DataTree + A copy of the node with the specified children dropped. + """ + # the Iterable check is required for mypy + if isinstance(names, str) or not isinstance(names, Iterable): + names = {names} + else: + names = set(names) + + if errors == "raise": + extra = names - set(self.children) + if extra: + raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") + + children_to_keep = { + name: child for name, child in self.children.items() if name not in names + } + return self._replace(children=children_to_keep) + + @classmethod + def from_dict( + cls, + d: MutableMapping[str, Dataset | DataArray | DataTree | None], + name: str | None = None, + ) -> DataTree: + """ + Create a datatree from a dictionary of data objects, organised by paths into the tree. + + Parameters + ---------- + d : dict-like + A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects. + + Path names are to be given as unix-like path. If path names containing more than one part are given, new + tree nodes will be constructed as necessary. + + To assign data to the root node of the tree use "/" as the path. + name : Hashable | None, optional + Name for the root node of the tree. Default is None. + + Returns + ------- + DataTree + + Notes + ----- + If your dictionary is nested you will need to flatten it before using this method. + """ + + # First create the root node + root_data = d.pop("/", None) + if isinstance(root_data, DataTree): + obj = root_data.copy() + obj.orphan() + else: + obj = cls(name=name, data=root_data, parent=None, children=None) + + if d: + # Populate tree with children determined from data_objects mapping + for path, data in d.items(): + # Create and set new node + node_name = NodePath(path).name + if isinstance(data, DataTree): + new_node = data.copy() + new_node.orphan() + else: + new_node = cls(name=node_name, data=data) + obj._set_item( + path, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) + + return obj + + def to_dict(self) -> dict[str, Dataset]: + """ + Create a dictionary mapping of absolute node paths to the data contained in those nodes. + + Returns + ------- + dict[str, Dataset] + """ + return {node.path: node.to_dataset() for node in self.subtree} + + @property + def nbytes(self) -> int: + return sum(node.to_dataset().nbytes for node in self.subtree) + + def __len__(self) -> int: + return len(self.children) + len(self.data_vars) + + @property + def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this DataTree node has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + DataTree.xindexes + """ + return self.xindexes.to_pandas_indexes() + + @property + def xindexes(self) -> Indexes[Index]: + """Mapping of xarray Index objects used for label based indexing.""" + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + + @property + def coords(self) -> DatasetCoordinates: + """Dictionary of xarray.DataArray objects corresponding to coordinate + variables + """ + return DatasetCoordinates(self.to_dataset()) + + @property + def data_vars(self) -> DataVariables: + """Dictionary of DataArray objects corresponding to data variables""" + return DataVariables(self.to_dataset()) + + def isomorphic( + self, + other: DataTree, + from_root: bool = False, + strict_names: bool = False, + ) -> bool: + """ + Two DataTrees are considered isomorphic if every node has the same number of children. + + Nothing about the data in each node is checked. + + Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, + such as ``tree1 + tree2``. + + By default this method does not check any part of the tree above the given node. + Therefore this method can be used as default to check that two subtrees are isomorphic. + + Parameters + ---------- + other : DataTree + The other tree object to compare to. + from_root : bool, optional, default is False + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. + strict_names : bool, optional, default is False + Whether or not to also check that every node in the tree has the same name as its counterpart in the other + tree. + + See Also + -------- + DataTree.equals + DataTree.identical + """ + try: + check_isomorphic( + self, + other, + require_names_equal=strict_names, + check_from_root=from_root, + ) + return True + except (TypeError, TreeIsomorphismError): + return False + + def equals(self, other: DataTree, from_root: bool = True) -> bool: + """ + Two DataTrees are equal if they have isomorphic node structures, with matching node names, + and if they have matching variables and coordinates, all of which are equal. + + By default this method will check the whole tree above the given node. + + Parameters + ---------- + other : DataTree + The other tree object to compare to. + from_root : bool, optional, default is True + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. + + See Also + -------- + Dataset.equals + DataTree.isomorphic + DataTree.identical + """ + if not self.isomorphic(other, from_root=from_root, strict_names=True): + return False + + return all( + [ + node.ds.equals(other_node.ds) + for node, other_node in zip(self.subtree, other.subtree) + ] + ) + + def identical(self, other: DataTree, from_root=True) -> bool: + """ + Like equals, but will also check all dataset attributes and the attributes on + all variables and coordinates. + + By default this method will check the whole tree above the given node. + + Parameters + ---------- + other : DataTree + The other tree object to compare to. + from_root : bool, optional, default is True + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. + + See Also + -------- + Dataset.identical + DataTree.isomorphic + DataTree.equals + """ + if not self.isomorphic(other, from_root=from_root, strict_names=True): + return False + + return all( + node.ds.identical(other_node.ds) + for node, other_node in zip(self.subtree, other.subtree) + ) + + def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: + """ + Filter nodes according to a specified condition. + + Returns a new tree containing only the nodes in the original tree for which `fitlerfunc(node)` is True. + Will also contain empty nodes at intermediate positions if required to support leaves. + + Parameters + ---------- + filterfunc: function + A function which accepts only one DataTree - the node on which filterfunc will be called. + + Returns + ------- + DataTree + + See Also + -------- + match + pipe + map_over_subtree + """ + filtered_nodes = { + node.path: node.ds for node in self.subtree if filterfunc(node) + } + return DataTree.from_dict(filtered_nodes, name=self.root.name) + + def match(self, pattern: str) -> DataTree: + """ + Return nodes with paths matching pattern. + + Uses unix glob-like syntax for pattern-matching. + + Parameters + ---------- + pattern: str + A pattern to match each node path against. + + Returns + ------- + DataTree + + See Also + -------- + filter + pipe + map_over_subtree + + Examples + -------- + >>> dt = DataTree.from_dict( + ... { + ... "/a/A": None, + ... "/a/B": None, + ... "/b/A": None, + ... "/b/B": None, + ... } + ... ) + >>> dt.match("*/B") + DataTree('None', parent=None) + ├── DataTree('a') + │ └── DataTree('B') + └── DataTree('b') + └── DataTree('B') + """ + matching_nodes = { + node.path: node.ds + for node in self.subtree + if NodePath(node.path).match(pattern) + } + return DataTree.from_dict(matching_nodes, name=self.root.name) + + def map_over_subtree( + self, + func: Callable, + *args: Iterable[Any], + **kwargs: Any, + ) -> DataTree | tuple[DataTree]: + """ + Apply a function to every dataset in this subtree, returning a new tree which stores the results. + + The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the + descendant nodes. The returned tree will have the same structure as the original subtree. + + func needs to return a Dataset in order to rebuild the subtree. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + `func(node.ds, *args, **kwargs) -> Dataset`. + + Function will not be applied to any nodes without datasets. + *args : tuple, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + + Returns + ------- + subtrees : DataTree, tuple of DataTrees + One or more subtrees containing results from applying ``func`` to the data at each node. + """ + # TODO this signature means that func has no way to know which node it is being called upon - change? + + # TODO fix this typing error + return map_over_subtree(func)(self, *args, **kwargs) + + def map_over_subtree_inplace( + self, + func: Callable, + *args: Iterable[Any], + **kwargs: Any, + ) -> None: + """ + Apply a function to every dataset in this subtree, updating data in place. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + `func(node.ds, *args, **kwargs) -> Dataset`. + + Function will not be applied to any nodes without datasets, + *args : tuple, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + """ + + # TODO if func fails on some node then the previous nodes will still have been updated... + + for node in self.subtree: + if node.has_data: + node.ds = func(node.ds, *args, **kwargs) + + def pipe( + self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any + ) -> Any: + """Apply ``func(self, *args, **kwargs)`` + + This method replicates the pandas method of the same name. + + Parameters + ---------- + func : callable + function to apply to this xarray object (Dataset/DataArray). + ``args``, and ``kwargs`` are passed into ``func``. + Alternatively a ``(callable, data_keyword)`` tuple where + ``data_keyword`` is a string indicating the keyword of + ``callable`` that expects the xarray object. + *args + positional arguments passed into ``func``. + **kwargs + a dictionary of keyword arguments passed into ``func``. + + Returns + ------- + object : Any + the return type of ``func``. + + Notes + ----- + Use ``.pipe`` when chaining together functions that expect + xarray or pandas objects, e.g., instead of writing + + .. code:: python + + f(g(h(dt), arg1=a), arg2=b, arg3=c) + + You can write + + .. code:: python + + (dt.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c)) + + If you have a function that takes the data as (say) the second + argument, pass a tuple indicating which keyword expects the + data. For example, suppose ``f`` takes its data as ``arg2``: + + .. code:: python + + (dt.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c)) + + """ + if isinstance(func, tuple): + func, target = func + if target in kwargs: + raise ValueError( + f"{target} is both the pipe target and a keyword argument" + ) + kwargs[target] = self + else: + args = (self,) + args + return func(*args, **kwargs) + + def render(self): + """Print tree structure, including any data stored at each node.""" + for pre, fill, node in RenderDataTree(self): + print(f"{pre}DataTree('{self.name}')") + for ds_line in repr(node.ds)[1:]: + print(f"{fill}{ds_line}") + + def merge(self, datatree: DataTree) -> DataTree: + """Merge all the leaves of a second DataTree into this one.""" + raise NotImplementedError + + def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree: + """Merge a set of child nodes into a single new node.""" + raise NotImplementedError + + # TODO some kind of .collapse() or .flatten() method to merge a subtree + + def to_dataarray(self) -> DataArray: + return self.ds.to_dataarray() + + @property + def groups(self): + """Return all netCDF4 groups in the tree, given as a tuple of path-like strings.""" + return tuple(node.path for node in self.subtree) + + def to_netcdf( + self, + filepath, + mode: NetcdfWriteModes = "w", + encoding=None, + unlimited_dims=None, + format: T_DataTreeNetcdfTypes | None = None, + engine: T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + compute: bool = True, + **kwargs, + ): + """ + Write datatree contents to a netCDF file. + + Parameters + ---------- + filepath : str or Path + Path to which to save this datatree. + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. If mode='a', existing variables + will be overwritten. Only appies to the root group. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available + options. + unlimited_dims : dict, optional + Mapping of unlimited dimensions per group that that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. + format : {"NETCDF4", }, optional + File format for the resulting netCDF file: + + * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API features. + engine : {"netcdf4", "h5netcdf"}, optional + Engine to use when writing netCDF files. If not provided, the + default engine is chosen based on available dependencies, with a + preference for "netcdf4" if writing to a file on disk. + group : str, optional + Path to the netCDF4 group in the given file to open as the root group + of the ``DataTree``. Currently, specifying a group is not supported. + compute : bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. + Currently, ``compute=False`` is not supported. + kwargs : + Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` + """ + from xarray.core.datatree_io import _datatree_to_netcdf + + _datatree_to_netcdf( + self, + filepath, + mode=mode, + encoding=encoding, + unlimited_dims=unlimited_dims, + format=format, + engine=engine, + group=group, + compute=compute, + **kwargs, + ) + + def to_zarr( + self, + store, + mode: ZarrWriteModes = "w-", + encoding=None, + consolidated: bool = True, + group: str | None = None, + compute: Literal[True] = True, + **kwargs, + ): + """ + Write datatree contents to a Zarr store. + + Parameters + ---------- + store : MutableMapping, str or Path, optional + Store or path to directory in file system + mode : {{"w", "w-", "a", "r+", None}, default: "w-" + Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); + “a” means override existing variables (create if does not exist); “r+” means modify existing + array values only (raise an error if any metadata or shapes would change). The default mode + is “w-”. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``. + See ``xarray.Dataset.to_zarr`` for available options. + consolidated : bool + If True, apply zarr's `consolidate_metadata` function to the store + after writing metadata for all groups. + group : str, optional + Group path. (a.k.a. `path` in zarr terminology.) + compute : bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. Metadata + is always updated eagerly. Currently, ``compute=False`` is not + supported. + kwargs : + Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` + """ + from xarray.core.datatree_io import _datatree_to_zarr + + _datatree_to_zarr( + self, + store, + mode=mode, + encoding=encoding, + consolidated=consolidated, + group=group, + compute=compute, + **kwargs, + ) + + def plot(self): + raise NotImplementedError diff --git a/test/fixtures/whole_applications/xarray/xarray/core/datatree_io.py b/test/fixtures/whole_applications/xarray/xarray/core/datatree_io.py new file mode 100644 index 0000000..1473e62 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/datatree_io.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from collections.abc import Mapping, MutableMapping +from os import PathLike +from typing import Any, Literal, get_args + +from xarray.core.datatree import DataTree +from xarray.core.types import NetcdfWriteModes, ZarrWriteModes + +T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"] +T_DataTreeNetcdfTypes = Literal["NETCDF4"] + + +def _get_nc_dataset_class(engine: T_DataTreeNetcdfEngine | None): + if engine == "netcdf4": + from netCDF4 import Dataset + elif engine == "h5netcdf": + from h5netcdf.legacyapi import Dataset + elif engine is None: + try: + from netCDF4 import Dataset + except ImportError: + from h5netcdf.legacyapi import Dataset + else: + raise ValueError(f"unsupported engine: {engine}") + return Dataset + + +def _create_empty_netcdf_group( + filename: str | PathLike, + group: str, + mode: NetcdfWriteModes, + engine: T_DataTreeNetcdfEngine | None, +): + ncDataset = _get_nc_dataset_class(engine) + + with ncDataset(filename, mode=mode) as rootgrp: + rootgrp.createGroup(group) + + +def _datatree_to_netcdf( + dt: DataTree, + filepath: str | PathLike, + mode: NetcdfWriteModes = "w", + encoding: Mapping[str, Any] | None = None, + unlimited_dims: Mapping | None = None, + format: T_DataTreeNetcdfTypes | None = None, + engine: T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + compute: bool = True, + **kwargs, +): + """This function creates an appropriate datastore for writing a datatree to + disk as a netCDF file. + + See `DataTree.to_netcdf` for full API docs. + """ + + if format not in [None, *get_args(T_DataTreeNetcdfTypes)]: + raise ValueError("to_netcdf only supports the NETCDF4 format") + + if engine not in [None, *get_args(T_DataTreeNetcdfEngine)]: + raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") + + if group is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + + if not compute: + raise NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + # In the future, we may want to expand this check to insure all the provided encoding + # options are valid. For now, this simply checks that all provided encoding keys are + # groups in the datatree. + if set(encoding) - set(dt.groups): + raise ValueError( + f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" + ) + + if unlimited_dims is None: + unlimited_dims = {} + + for node in dt.subtree: + ds = node.ds + group_path = node.path + if ds is None: + _create_empty_netcdf_group(filepath, group_path, mode, engine) + else: + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + unlimited_dims=unlimited_dims.get(node.path), + engine=engine, + format=format, + compute=compute, + **kwargs, + ) + mode = "a" + + +def _create_empty_zarr_group( + store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes +): + import zarr + + root = zarr.open_group(store, mode=mode) + root.create_group(group, overwrite=True) + + +def _datatree_to_zarr( + dt: DataTree, + store: MutableMapping | str | PathLike[str], + mode: ZarrWriteModes = "w-", + encoding: Mapping[str, Any] | None = None, + consolidated: bool = True, + group: str | None = None, + compute: Literal[True] = True, + **kwargs, +): + """This function creates an appropriate datastore for writing a datatree + to a zarr store. + + See `DataTree.to_zarr` for full API docs. + """ + + from zarr.convenience import consolidate_metadata + + if group is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + + if not compute: + raise NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + # In the future, we may want to expand this check to insure all the provided encoding + # options are valid. For now, this simply checks that all provided encoding keys are + # groups in the datatree. + if set(encoding) - set(dt.groups): + raise ValueError( + f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" + ) + + for node in dt.subtree: + ds = node.ds + group_path = node.path + if ds is None: + _create_empty_zarr_group(store, group_path, mode) + else: + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + consolidated=False, + **kwargs, + ) + if "w" in mode: + mode = "a" + + if consolidated: + consolidate_metadata(store) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/datatree_mapping.py b/test/fixtures/whole_applications/xarray/xarray/core/datatree_mapping.py new file mode 100644 index 0000000..6e5aae1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/datatree_mapping.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +import functools +import sys +from itertools import repeat +from typing import TYPE_CHECKING, Callable + +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.formatting import diff_treestructure +from xarray.core.treenode import NodePath, TreeNode + +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + + +class TreeIsomorphismError(ValueError): + """Error raised if two tree objects do not share the same node structure.""" + + pass + + +def check_isomorphic( + a: DataTree, + b: DataTree, + require_names_equal: bool = False, + check_from_root: bool = True, +): + """ + Check that two trees have the same structure, raising an error if not. + + Does not compare the actual data in the nodes. + + By default this function only checks that subtrees are isomorphic, not the entire tree above (if it exists). + Can instead optionally check the entire trees starting from the root, which will ensure all + + Can optionally check if corresponding nodes should have the same name. + + Parameters + ---------- + a : DataTree + b : DataTree + require_names_equal : Bool + Whether or not to also check that each node has the same name as its counterpart. + check_from_root : Bool + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + Raises + ------ + TypeError + If either a or b are not tree objects. + TreeIsomorphismError + If a and b are tree objects, but are not isomorphic to one another. + Also optionally raised if their structure is isomorphic, but the names of any two + respective nodes are not equal. + """ + + if not isinstance(a, TreeNode): + raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}") + if not isinstance(b, TreeNode): + raise TypeError(f"Argument `b` is not a tree, it is of type {type(b)}") + + if check_from_root: + a = a.root + b = b.root + + diff = diff_treestructure(a, b, require_names_equal=require_names_equal) + + if diff: + raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) + + +def map_over_subtree(func: Callable) -> Callable: + """ + Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. + + Applies a function to every dataset in one or more subtrees, returning new trees which store the results. + + The function will be applied to any data-containing dataset stored in any of the nodes in the trees. The returned + trees will have the same structure as the supplied trees. + + `func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after + mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any + returned value that is one of these types will be stacked into a separate tree before returning all of them. + + The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named + similarly, but all the output trees will have nodes named in the same way as the first tree passed. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + + `func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`. + + (i.e. func must accept at least one Dataset and return at least one Dataset.) + Function will not be applied to any nodes without datasets. + *args : tuple, optional + Positional arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets + via `.ds`. + **kwargs : Any + Keyword arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets + via `.ds`. + + Returns + ------- + mapped : callable + Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at + each node. + + See also + -------- + DataTree.map_over_subtree + DataTree.map_over_subtree_inplace + DataTree.subtree + """ + + # TODO examples in the docstring + + # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? + + @functools.wraps(func) + def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: + """Internal function which maps func over every node in tree, returning a tree of the results.""" + from xarray.core.datatree import DataTree + + all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ + a for a in kwargs.values() if isinstance(a, DataTree) + ] + + if len(all_tree_inputs) > 0: + first_tree, *other_trees = all_tree_inputs + else: + raise TypeError("Must pass at least one tree object") + + for other_tree in other_trees: + # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic + check_isomorphic( + first_tree, other_tree, require_names_equal=False, check_from_root=False + ) + + # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees + # We don't know which arguments are DataTrees so we zip all arguments together as iterables + # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return + out_data_objects = {} + args_as_tree_length_iterables = [ + a.subtree if isinstance(a, DataTree) else repeat(a) for a in args + ] + n_args = len(args_as_tree_length_iterables) + kwargs_as_tree_length_iterables = { + k: v.subtree if isinstance(v, DataTree) else repeat(v) + for k, v in kwargs.items() + } + for node_of_first_tree, *all_node_args in zip( + first_tree.subtree, + *args_as_tree_length_iterables, + *list(kwargs_as_tree_length_iterables.values()), + ): + node_args_as_datasetviews = [ + a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args] + ] + node_kwargs_as_datasetviews = dict( + zip( + [k for k in kwargs_as_tree_length_iterables.keys()], + [ + v.ds if isinstance(v, DataTree) else v + for v in all_node_args[n_args:] + ], + ) + ) + func_with_error_context = _handle_errors_with_path_context( + node_of_first_tree.path + )(func) + + if node_of_first_tree.has_data: + # call func on the data in this particular set of corresponding nodes + results = func_with_error_context( + *node_args_as_datasetviews, **node_kwargs_as_datasetviews + ) + elif node_of_first_tree.has_attrs: + # propagate attrs + results = node_of_first_tree.ds + else: + # nothing to propagate so use fastpath to create empty node in new tree + results = None + + # TODO implement mapping over multiple trees in-place using if conditions from here on? + out_data_objects[node_of_first_tree.path] = results + + # Find out how many return values we received + num_return_values = _check_all_return_values(out_data_objects) + + # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees + original_root_path = first_tree.path + result_trees = [] + for i in range(num_return_values): + out_tree_contents = {} + for n in first_tree.subtree: + p = n.path + if p in out_data_objects.keys(): + if isinstance(out_data_objects[p], tuple): + output_node_data = out_data_objects[p][i] + else: + output_node_data = out_data_objects[p] + else: + output_node_data = None + + # Discard parentage so that new trees don't include parents of input nodes + relative_path = str(NodePath(p).relative_to(original_root_path)) + relative_path = "/" if relative_path == "." else relative_path + out_tree_contents[relative_path] = output_node_data + + new_tree = DataTree.from_dict( + out_tree_contents, + name=first_tree.name, + ) + result_trees.append(new_tree) + + # If only one result then don't wrap it in a tuple + if len(result_trees) == 1: + return result_trees[0] + else: + return tuple(result_trees) + + return _map_over_subtree + + +def _handle_errors_with_path_context(path: str): + """Wraps given function so that if it fails it also raises path to node on which it failed.""" + + def decorator(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + # Add the context information to the error message + add_note( + e, f"Raised whilst mapping function over node with path {path}" + ) + raise + + return wrapper + + return decorator + + +def add_note(err: BaseException, msg: str) -> None: + # TODO: remove once python 3.10 can be dropped + if sys.version_info < (3, 11): + err.__notes__ = getattr(err, "__notes__", []) + [msg] # type: ignore[attr-defined] + else: + err.add_note(msg) + + +def _check_single_set_return_values( + path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray] +): + """Check types returned from single evaluation of func, and return number of return values received from func.""" + if isinstance(obj, (Dataset, DataArray)): + return 1 + elif isinstance(obj, tuple): + for r in obj: + if not isinstance(r, (Dataset, DataArray)): + raise TypeError( + f"One of the results of calling func on datasets on the nodes at position {path_to_node} is " + f"of type {type(r)}, not Dataset or DataArray." + ) + return len(obj) + else: + raise TypeError( + f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not " + f"Dataset or DataArray, nor a tuple of such types." + ) + + +def _check_all_return_values(returned_objects): + """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" + + if all(r is None for r in returned_objects.values()): + raise TypeError( + "Called supplied function on all nodes but found a return value of None for" + "all of them." + ) + + result_data_objects = [ + (path_to_node, r) + for path_to_node, r in returned_objects.items() + if r is not None + ] + + if len(result_data_objects) == 1: + # Only one node in the tree: no need to check consistency of results between nodes + path_to_node, result = result_data_objects[0] + num_return_values = _check_single_set_return_values(path_to_node, result) + else: + prev_path, _ = result_data_objects[0] + prev_num_return_values, num_return_values = None, None + for path_to_node, obj in result_data_objects[1:]: + num_return_values = _check_single_set_return_values(path_to_node, obj) + + if ( + num_return_values != prev_num_return_values + and prev_num_return_values is not None + ): + raise TypeError( + f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return " + f"values, whereas calling func on the nodes at position {prev_path} instead returns " + f"{prev_num_return_values} separate return values." + ) + + prev_path, prev_num_return_values = path_to_node, num_return_values + + return num_return_values diff --git a/test/fixtures/whole_applications/xarray/xarray/core/datatree_ops.py b/test/fixtures/whole_applications/xarray/xarray/core/datatree_ops.py new file mode 100644 index 0000000..bc64b44 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/datatree_ops.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import re +import textwrap + +from xarray.core.dataset import Dataset +from xarray.core.datatree_mapping import map_over_subtree + +""" +Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree. + +Structured to mirror the way xarray defines Dataset's various operations internally, but does not actually import from +xarray's internals directly, only the public-facing xarray.Dataset class. +""" + + +_MAPPED_DOCSTRING_ADDENDUM = ( + "This method was copied from xarray.Dataset, but has been altered to " + "call the method on the Datasets stored in every node of the subtree. " + "See the `map_over_subtree` function for more details." +) + +# TODO equals, broadcast_equals etc. +# TODO do dask-related private methods need to be exposed? +_DATASET_DASK_METHODS_TO_MAP = [ + "load", + "compute", + "persist", + "unify_chunks", + "chunk", + "map_blocks", +] +_DATASET_METHODS_TO_MAP = [ + "as_numpy", + "set_coords", + "reset_coords", + "info", + "isel", + "sel", + "head", + "tail", + "thin", + "broadcast_like", + "reindex_like", + "reindex", + "interp", + "interp_like", + "rename", + "rename_dims", + "rename_vars", + "swap_dims", + "expand_dims", + "set_index", + "reset_index", + "reorder_levels", + "stack", + "unstack", + "merge", + "drop_vars", + "drop_sel", + "drop_isel", + "drop_dims", + "transpose", + "dropna", + "fillna", + "interpolate_na", + "ffill", + "bfill", + "combine_first", + "reduce", + "map", + "diff", + "shift", + "roll", + "sortby", + "quantile", + "rank", + "differentiate", + "integrate", + "cumulative_integrate", + "filter_by_attrs", + "polyfit", + "pad", + "idxmin", + "idxmax", + "argmin", + "argmax", + "query", + "curvefit", +] +_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP + +_DATA_WITH_COORDS_METHODS_TO_MAP = [ + "squeeze", + "clip", + "assign_coords", + "where", + "close", + "isnull", + "notnull", + "isin", + "astype", +] + +REDUCE_METHODS = ["all", "any"] +NAN_REDUCE_METHODS = [ + "max", + "min", + "mean", + "prod", + "sum", + "std", + "var", + "median", +] +NAN_CUM_METHODS = ["cumsum", "cumprod"] +_TYPED_DATASET_OPS_TO_MAP = [ + "__add__", + "__sub__", + "__mul__", + "__pow__", + "__truediv__", + "__floordiv__", + "__mod__", + "__and__", + "__xor__", + "__or__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__eq__", + "__ne__", + "__radd__", + "__rsub__", + "__rmul__", + "__rpow__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rand__", + "__rxor__", + "__ror__", + "__iadd__", + "__isub__", + "__imul__", + "__ipow__", + "__itruediv__", + "__ifloordiv__", + "__imod__", + "__iand__", + "__ixor__", + "__ior__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + "round", + "argsort", + "conj", + "conjugate", +] +# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere... +_ARITHMETIC_METHODS_TO_MAP = ( + REDUCE_METHODS + + NAN_REDUCE_METHODS + + NAN_CUM_METHODS + + _TYPED_DATASET_OPS_TO_MAP + + ["__array_ufunc__"] +) + + +def _wrap_then_attach_to_cls( + target_cls_dict, source_cls, methods_to_set, wrap_func=None +): + """ + Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree). + + Result is like having written this in the classes' definition: + ``` + @wrap_func + def method_name(self, *args, **kwargs): + return self.method(*args, **kwargs) + ``` + + Every method attached here needs to have a return value of Dataset or DataArray in order to construct a new tree. + + Parameters + ---------- + target_cls_dict : MappingProxy + The __dict__ attribute of the class which we want the methods to be added to. (The __dict__ attribute can also + be accessed by calling vars() from within that classes' definition.) This will be updated by this function. + source_cls : class + Class object from which we want to copy methods (and optionally wrap them). Should be the actual class object + (or instance), not just the __dict__. + methods_to_set : Iterable[Tuple[str, callable]] + The method names and definitions supplied as a list of (method_name_string, method) pairs. + This format matches the output of inspect.getmembers(). + wrap_func : callable, optional + Function to decorate each method with. Must have the same return type as the method. + """ + for method_name in methods_to_set: + orig_method = getattr(source_cls, method_name) + wrapped_method = ( + wrap_func(orig_method) if wrap_func is not None else orig_method + ) + target_cls_dict[method_name] = wrapped_method + + if wrap_func is map_over_subtree: + # Add a paragraph to the method's docstring explaining how it's been mapped + orig_method_docstring = orig_method.__doc__ + + if orig_method_docstring is not None: + new_method_docstring = insert_doc_addendum( + orig_method_docstring, _MAPPED_DOCSTRING_ADDENDUM + ) + setattr(target_cls_dict[method_name], "__doc__", new_method_docstring) + + +def insert_doc_addendum(docstring: str | None, addendum: str) -> str | None: + """Insert addendum after first paragraph or at the end of the docstring. + + There are a number of Dataset's functions that are wrapped. These come from + Dataset directly as well as the mixins: DataWithCoords, DatasetAggregations, and DatasetOpsMixin. + + The majority of the docstrings fall into a parseable pattern. Those that + don't, just have the addendum appeneded after. None values are returned. + + """ + if docstring is None: + return None + + pattern = re.compile( + r"^(?P(\S+)?(.*?))(?P\n\s*\n)(?P[ ]*)(?P.*)", + re.DOTALL, + ) + capture = re.match(pattern, docstring) + if capture is None: + ### single line docstring. + return ( + docstring + + "\n\n" + + textwrap.fill( + addendum, + subsequent_indent=" ", + width=79, + ) + ) + + if len(capture.groups()) == 6: + return ( + capture["start"] + + capture["paragraph_break"] + + capture["whitespace"] + + ".. note::\n" + + textwrap.fill( + addendum, + initial_indent=capture["whitespace"] + " ", + subsequent_indent=capture["whitespace"] + " ", + width=79, + ) + + capture["paragraph_break"] + + capture["whitespace"] + + capture["rest"] + ) + else: + return docstring + + +class MappedDatasetMethodsMixin: + """ + Mixin to add methods defined specifically on the Dataset class such as .query(), but wrapped to map over all nodes + in the subtree. + """ + + _wrap_then_attach_to_cls( + target_cls_dict=vars(), + source_cls=Dataset, + methods_to_set=_ALL_DATASET_METHODS_TO_MAP, + wrap_func=map_over_subtree, + ) + + +class MappedDataWithCoords: + """ + Mixin to add coordinate-aware Dataset methods such as .where(), but wrapped to map over all nodes in the subtree. + """ + + # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample + _wrap_then_attach_to_cls( + target_cls_dict=vars(), + source_cls=Dataset, + methods_to_set=_DATA_WITH_COORDS_METHODS_TO_MAP, + wrap_func=map_over_subtree, + ) + + +class DataTreeArithmeticMixin: + """ + Mixin to add Dataset arithmetic operations such as __add__, reduction methods such as .mean(), and enable numpy + ufuncs such as np.sin(), but wrapped to map over all nodes in the subtree. + """ + + _wrap_then_attach_to_cls( + target_cls_dict=vars(), + source_cls=Dataset, + methods_to_set=_ARITHMETIC_METHODS_TO_MAP, + wrap_func=map_over_subtree, + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/datatree_render.py b/test/fixtures/whole_applications/xarray/xarray/core/datatree_render.py new file mode 100644 index 0000000..d069071 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/datatree_render.py @@ -0,0 +1,266 @@ +""" +String Tree Rendering. Copied from anytree. + +Minor changes to `RenderDataTree` include accessing `children.values()`, and +type hints. + +""" + +from __future__ import annotations + +from collections import namedtuple +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + +Row = namedtuple("Row", ("pre", "fill", "node")) + + +class AbstractStyle: + def __init__(self, vertical: str, cont: str, end: str): + """ + Tree Render Style. + Args: + vertical: Sign for vertical line. + cont: Chars for a continued branch. + end: Chars for the last branch. + """ + super().__init__() + self.vertical = vertical + self.cont = cont + self.end = end + assert ( + len(cont) == len(vertical) == len(end) + ), f"'{vertical}', '{cont}' and '{end}' need to have equal length" + + @property + def empty(self) -> str: + """Empty string as placeholder.""" + return " " * len(self.end) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class ContStyle(AbstractStyle): + def __init__(self): + """ + Continued style, without gaps. + + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root") + >>> s0 = DataTree(name="sub0", parent=root) + >>> s0b = DataTree(name="sub0B", parent=s0) + >>> s0a = DataTree(name="sub0A", parent=s0) + >>> s1 = DataTree(name="sub1", parent=root) + >>> print(RenderDataTree(root)) + DataTree('root', parent=None) + ├── DataTree('sub0') + │ ├── DataTree('sub0B') + │ └── DataTree('sub0A') + └── DataTree('sub1') + """ + super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ") + + +class RenderDataTree: + def __init__( + self, + node: DataTree, + style=ContStyle(), + childiter: type = list, + maxlevel: int | None = None, + ): + """ + Render tree starting at `node`. + Keyword Args: + style (AbstractStyle): Render Style. + childiter: Child iterator. Note, due to the use of node.children.values(), + Iterables that change the order of children cannot be used + (e.g., `reversed`). + maxlevel: Limit rendering to this depth. + :any:`RenderDataTree` is an iterator, returning a tuple with 3 items: + `pre` + tree prefix. + `fill` + filling for multiline entries. + `node` + :any:`NodeMixin` object. + It is up to the user to assemble these parts to a whole. + + Examples + -------- + + >>> from xarray import Dataset + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1})) + >>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3})) + >>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4})) + >>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6})) + >>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7})) + + # Simple one line: + + >>> for pre, _, node in RenderDataTree(root): + ... print(f"{pre}{node.name}") + ... + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + + # Multiline: + + >>> for pre, fill, node in RenderDataTree(root): + ... print(f"{pre}{node.name}") + ... for variable in node.variables: + ... print(f"{fill}{variable}") + ... + root + a + b + ├── sub0 + │ c + │ d + │ ├── sub0B + │ │ e + │ └── sub0A + │ f + │ g + └── sub1 + h + + :any:`by_attr` simplifies attribute rendering and supports multiline: + >>> print(RenderDataTree(root).by_attr()) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + + # `maxlevel` limits the depth of the tree: + + >>> print(RenderDataTree(root, maxlevel=2).by_attr("name")) + root + ├── sub0 + └── sub1 + """ + if not isinstance(style, AbstractStyle): + style = style() + self.node = node + self.style = style + self.childiter = childiter + self.maxlevel = maxlevel + + def __iter__(self) -> Iterator[Row]: + return self.__next(self.node, tuple()) + + def __next( + self, node: DataTree, continues: tuple[bool, ...], level: int = 0 + ) -> Iterator[Row]: + yield RenderDataTree.__item(node, continues, self.style) + children = node.children.values() + level += 1 + if children and (self.maxlevel is None or level < self.maxlevel): + children = self.childiter(children) + for child, is_last in _is_last(children): + yield from self.__next(child, continues + (not is_last,), level=level) + + @staticmethod + def __item( + node: DataTree, continues: tuple[bool, ...], style: AbstractStyle + ) -> Row: + if not continues: + return Row("", "", node) + else: + items = [style.vertical if cont else style.empty for cont in continues] + indent = "".join(items[:-1]) + branch = style.cont if continues[-1] else style.end + pre = indent + branch + fill = "".join(items) + return Row(pre, fill, node) + + def __str__(self) -> str: + return str(self.node) + + def __repr__(self) -> str: + classname = self.__class__.__name__ + args = [ + repr(self.node), + f"style={repr(self.style)}", + f"childiter={repr(self.childiter)}", + ] + return f"{classname}({', '.join(args)})" + + def by_attr(self, attrname: str = "name") -> str: + """ + Return rendered tree with node attribute `attrname`. + + Examples + -------- + + >>> from xarray import Dataset + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.datatree_render import RenderDataTree + >>> root = DataTree(name="root") + >>> s0 = DataTree(name="sub0", parent=root) + >>> s0b = DataTree( + ... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109}) + ... ) + >>> s0a = DataTree(name="sub0A", parent=s0) + >>> s1 = DataTree(name="sub1", parent=root) + >>> s1a = DataTree(name="sub1A", parent=s1) + >>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8})) + >>> s1c = DataTree(name="sub1C", parent=s1) + >>> s1ca = DataTree(name="sub1Ca", parent=s1c) + >>> print(RenderDataTree(root).by_attr("name")) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + ├── sub1A + ├── sub1B + └── sub1C + └── sub1Ca + """ + + def get() -> Iterator[str]: + for pre, fill, node in self: + attr = ( + attrname(node) + if callable(attrname) + else getattr(node, attrname, "") + ) + if isinstance(attr, (list, tuple)): + lines = attr + else: + lines = str(attr).split("\n") + yield f"{pre}{lines[0]}" + for line in lines[1:]: + yield f"{fill}{line}" + + return "\n".join(get()) + + +def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]: + iter_ = iter(iterable) + try: + nextitem = next(iter_) + except StopIteration: + pass + else: + item = nextitem + while True: + try: + nextitem = next(iter_) + yield item, False + except StopIteration: + yield nextitem, True + break + item = nextitem diff --git a/test/fixtures/whole_applications/xarray/xarray/core/dtypes.py b/test/fixtures/whole_applications/xarray/xarray/core/dtypes.py new file mode 100644 index 0000000..2c3a43e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/dtypes.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +import functools +from typing import Any + +import numpy as np +from pandas.api.types import is_extension_array_dtype + +from xarray.core import array_api_compat, npcompat, utils + +# Use as a sentinel value to indicate a dtype appropriate NA value. +NA = utils.ReprObject("") + + +@functools.total_ordering +class AlwaysGreaterThan: + def __gt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan: + def __lt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any + if isdtype(dtype, "real floating"): + dtype_ = dtype + fill_value = np.nan + elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64("NaT") + dtype_ = dtype + elif isdtype(dtype, "integral"): + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 + fill_value = np.nan + elif isdtype(dtype, "complex floating"): + dtype_ = dtype + fill_value = np.nan + np.nan * 1j + elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64): + dtype_ = dtype + fill_value = np.datetime64("NaT") + else: + dtype_ = object + fill_value = np.nan + + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value + + +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} + + +def get_fill_value(dtype): + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def get_pos_infinity(dtype, max_for_int=False): + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if isdtype(dtype, "real floating"): + return np.inf + + if isdtype(dtype, "integral"): + if max_for_int: + return np.iinfo(dtype).max + else: + return np.inf + + if isdtype(dtype, "complex floating"): + return np.inf + 1j * np.inf + + if isdtype(dtype, "bool"): + return True + + return np.array(INF, dtype=object) + + +def get_neg_infinity(dtype, min_for_int=False): + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if isdtype(dtype, "real floating"): + return -np.inf + + if isdtype(dtype, "integral"): + if min_for_int: + return np.iinfo(dtype).min + else: + return -np.inf + + if isdtype(dtype, "complex floating"): + return -np.inf - 1j * np.inf + + if isdtype(dtype, "bool"): + return False + + return np.array(NINF, dtype=object) + + +def is_datetime_like(dtype) -> bool: + """Check if a dtype is a subclass of the numpy datetime types""" + return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64)) + + +def is_object(dtype) -> bool: + """Check if a dtype is object""" + return _is_numpy_subdtype(dtype, object) + + +def is_string(dtype) -> bool: + """Check if a dtype is a string dtype""" + return _is_numpy_subdtype(dtype, (np.str_, np.character)) + + +def _is_numpy_subdtype(dtype, kind) -> bool: + if not isinstance(dtype, np.dtype): + return False + + kinds = kind if isinstance(kind, tuple) else (kind,) + return any(np.issubdtype(dtype, kind) for kind in kinds) + + +def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: + """Compatibility wrapper for isdtype() from the array API standard. + + Unlike xp.isdtype(), kind must be a string. + """ + # TODO(shoyer): remove this wrapper when Xarray requires + # numpy>=2 and pandas extensions arrays are implemented in + # Xarray via the array API + if not isinstance(kind, str) and not ( + isinstance(kind, tuple) and all(isinstance(k, str) for k in kind) + ): + raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}") + + if isinstance(dtype, np.dtype): + return npcompat.isdtype(dtype, kind) + elif is_extension_array_dtype(dtype): + # we never want to match pandas extension array dtypes + return False + else: + if xp is None: + xp = np + return xp.isdtype(dtype, kind) + + +def preprocess_scalar_types(t): + if isinstance(t, (str, bytes)): + return type(t) + else: + return t + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, + xp=None, +) -> np.dtype: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + # TODO (keewis): replace `array_api_compat.result_type` with `xp.result_type` once we + # can require a version of the Array API that supports passing scalars to it. + from xarray.core.duck_array_ops import get_array_namespace + + if xp is None: + xp = get_array_namespace(arrays_and_dtypes) + + types = { + array_api_compat.result_type(preprocess_scalar_types(t), xp=xp) + for t in arrays_and_dtypes + } + if any(isinstance(t, np.dtype) for t in types): + # only check if there's numpy dtypes – the array API does not + # define the types we're checking for + for left, right in PROMOTE_TO_OBJECT: + if any(np.issubdtype(t, left) for t in types) and any( + np.issubdtype(t, right) for t in types + ): + return np.dtype(object) + + return array_api_compat.result_type( + *map(preprocess_scalar_types, arrays_and_dtypes), xp=xp + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/duck_array_ops.py b/test/fixtures/whole_applications/xarray/xarray/core/duck_array_ops.py new file mode 100644 index 0000000..8993c13 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/duck_array_ops.py @@ -0,0 +1,833 @@ +"""Compatibility module defining operations on duck numpy-arrays. + +Currently, this means Dask or NumPy arrays. None of these functions should +accept or return xarray objects. +""" + +from __future__ import annotations + +import contextlib +import datetime +import inspect +import warnings +from functools import partial +from importlib import import_module + +import numpy as np +import pandas as pd +from numpy import all as array_all # noqa +from numpy import any as array_any # noqa +from numpy import ( # noqa + around, # noqa + full_like, + gradient, + isclose, + isin, + isnat, + take, + tensordot, + transpose, + unravel_index, +) +from numpy import concatenate as _concatenate +from numpy.lib.stride_tricks import sliding_window_view # noqa +from packaging.version import Version +from pandas.api.types import is_extension_array_dtype + +from xarray.core import dask_array_ops, dtypes, nputils +from xarray.core.options import OPTIONS +from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available +from xarray.namedarray import pycompat +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import array_type, is_chunked_array + +# remove once numpy 2.0 is the oldest supported version +if module_available("numpy", minversion="2.0.0.dev0"): + from numpy.lib.array_utils import ( # type: ignore[import-not-found,unused-ignore] + normalize_axis_index, + ) +else: + from numpy.core.multiarray import ( # type: ignore[attr-defined,no-redef,unused-ignore] + normalize_axis_index, + ) + + +dask_available = module_available("dask") + + +def get_array_namespace(*values): + def _get_array_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + else: + return np + + namespaces = {_get_array_namespace(t) for t in values} + non_numpy = namespaces - {np} + + if len(non_numpy) > 1: + raise TypeError( + "cannot deal with more than one type supporting the array API at the same time" + ) + elif non_numpy: + [xp] = non_numpy + else: + xp = np + + return xp + + +def einsum(*args, **kwargs): + from xarray.core.options import OPTIONS + + if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"): + import opt_einsum + + return opt_einsum.contract(*args, **kwargs) + else: + return np.einsum(*args, **kwargs) + + +def _dask_or_eager_func( + name, + eager_module=np, + dask_module="dask.array", +): + """Create a function that dispatches to dask for dask array inputs.""" + + def f(*args, **kwargs): + if any(is_duck_dask_array(a) for a in args): + mod = ( + import_module(dask_module) + if isinstance(dask_module, str) + else dask_module + ) + wrapped = getattr(mod, name) + else: + wrapped = getattr(eager_module, name) + return wrapped(*args, **kwargs) + + return f + + +def fail_on_dask_array_input(values, msg=None, func_name=None): + if is_duck_dask_array(values): + if msg is None: + msg = "%r is not yet a valid method on dask arrays" + if func_name is None: + func_name = inspect.stack()[1][3] + raise NotImplementedError(msg % func_name) + + +# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18 +pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array") + +# np.around has failing doctests, overwrite it so they pass: +# https://github.com/numpy/numpy/issues/19759 +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0., 2.])", + "array([0., 2.])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0., 2.])", + "array([0., 2.])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0.4, 1.6])", + "array([0.4, 1.6])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0., 2., 2., 4., 4.])", + "array([0., 2., 2., 4., 4.])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + ( + ' .. [2] "How Futile are Mindless Assessments of\n' + ' Roundoff in Floating-Point Computation?", William Kahan,\n' + " https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n" + ), + "", +) + + +def isnull(data): + data = asarray(data) + + xp = get_array_namespace(data) + scalar_type = data.dtype + if dtypes.is_datetime_like(scalar_type): + # datetime types use NaT for null + # note: must check timedelta64 before integers, because currently + # timedelta64 inherits from np.integer + return isnat(data) + elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp): + # float types use NaN for null + xp = get_array_namespace(data) + return xp.isnan(data) + elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or ( + isinstance(scalar_type, np.dtype) + and ( + np.issubdtype(scalar_type, np.character) + or np.issubdtype(scalar_type, np.void) + ) + ): + # these types cannot represent missing values + return full_like(data, dtype=bool, fill_value=False) + else: + # at this point, array should have dtype=object + if isinstance(data, np.ndarray) or is_extension_array_dtype(data): + return pandas_isnull(data) + else: + # Not reachable yet, but intended for use with other duck array + # types. For full consistency with pandas, we should accept None as + # a null value as well as NaN, but it isn't clear how to do this + # with duck typing. + return data != data + + +def notnull(data): + return ~isnull(data) + + +# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed +masked_invalid = _dask_or_eager_func( + "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" +) + + +def trapz(y, x, axis): + if axis < 0: + axis = y.ndim + axis + x_sl1 = (slice(1, None),) + (None,) * (y.ndim - axis - 1) + x_sl2 = (slice(None, -1),) + (None,) * (y.ndim - axis - 1) + slice1 = (slice(None),) * axis + (slice(1, None),) + slice2 = (slice(None),) * axis + (slice(None, -1),) + dx = x[x_sl1] - x[x_sl2] + integrand = dx * 0.5 * (y[tuple(slice1)] + y[tuple(slice2)]) + return sum(integrand, axis=axis, skipna=False) + + +def cumulative_trapezoid(y, x, axis): + if axis < 0: + axis = y.ndim + axis + x_sl1 = (slice(1, None),) + (None,) * (y.ndim - axis - 1) + x_sl2 = (slice(None, -1),) + (None,) * (y.ndim - axis - 1) + slice1 = (slice(None),) * axis + (slice(1, None),) + slice2 = (slice(None),) * axis + (slice(None, -1),) + dx = x[x_sl1] - x[x_sl2] + integrand = dx * 0.5 * (y[tuple(slice1)] + y[tuple(slice2)]) + + # Pad so that 'axis' has same length in result as it did in y + pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] + integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) + + return cumsum(integrand, axis=axis, skipna=False) + + +def astype(data, dtype, **kwargs): + if hasattr(data, "__array_namespace__"): + xp = get_array_namespace(data) + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) + return xp.astype(data, dtype, **kwargs) + return data.astype(dtype, **kwargs) + + +def asarray(data, xp=np, dtype=None): + converted = data if is_duck_array(data) else xp.asarray(data) + + if dtype is None or converted.dtype == dtype: + return converted + + if xp is np or not hasattr(xp, "astype"): + return converted.astype(dtype) + else: + return xp.astype(converted, dtype) + + +def as_shared_dtype(scalars_or_arrays, xp=None): + """Cast a arrays to a shared dtype using xarray's type promotion rules.""" + if any(is_extension_array_dtype(x) for x in scalars_or_arrays): + extension_array_types = [ + x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) + ] + if len(extension_array_types) == len(scalars_or_arrays) and all( + isinstance(x, type(extension_array_types[0])) for x in extension_array_types + ): + return scalars_or_arrays + raise ValueError( + "Cannot cast arrays to shared type, found" + f" array types {[x.dtype for x in scalars_or_arrays]}" + ) + + # Avoid calling array_type("cupy") repeatidely in the any check + array_type_cupy = array_type("cupy") + if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): + import cupy as cp + + xp = cp + elif xp is None: + xp = get_array_namespace(scalars_or_arrays) + + # Pass arrays directly instead of dtypes to result_type so scalars + # get handled properly. + # Note that result_type() safely gets the dtype from dask arrays without + # evaluating them. + dtype = dtypes.result_type(*scalars_or_arrays, xp=xp) + + return [asarray(x, dtype=dtype, xp=xp) for x in scalars_or_arrays] + + +def broadcast_to(array, shape): + xp = get_array_namespace(array) + return xp.broadcast_to(array, shape) + + +def lazy_array_equiv(arr1, arr2): + """Like array_equal, but doesn't actually compare values. + Returns True when arr1, arr2 identical or their dask tokens are equal. + Returns False when shapes are not equal. + Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays; + or their dask tokens are not equal + """ + if arr1 is arr2: + return True + arr1 = asarray(arr1) + arr2 = asarray(arr2) + if arr1.shape != arr2.shape: + return False + if dask_available and is_duck_dask_array(arr1) and is_duck_dask_array(arr2): + from dask.base import tokenize + + # GH3068, GH4221 + if tokenize(arr1) == tokenize(arr2): + return True + else: + return None + return None + + +def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): + """Like np.allclose, but also allows values to be NaN in both arrays""" + arr1 = asarray(arr1) + arr2 = asarray(arr2) + + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + else: + return lazy_equiv + + +def array_equiv(arr1, arr2): + """Like np.array_equal, but also allows values to be NaN in both arrays""" + arr1 = asarray(arr1) + arr2 = asarray(arr2) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) + return bool(flag_array.all()) + else: + return lazy_equiv + + +def array_notnull_equiv(arr1, arr2): + """Like np.array_equal, but also allows values to be NaN in either or both + arrays + """ + arr1 = asarray(arr1) + arr2 = asarray(arr2) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) + return bool(flag_array.all()) + else: + return lazy_equiv + + +def count(data, axis=None): + """Count the number of non-NA in this array along the given axis or axes""" + return np.sum(np.logical_not(isnull(data)), axis=axis) + + +def sum_where(data, axis=None, dtype=None, where=None): + xp = get_array_namespace(data) + if where is not None: + a = where_method(xp.zeros_like(data), where, data) + else: + a = data + result = xp.sum(a, axis=axis, dtype=dtype) + return result + + +def where(condition, x, y): + """Three argument where() with better dtype promotion rules.""" + xp = get_array_namespace(condition) + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) + + +def where_method(data, cond, other=dtypes.NA): + if other is dtypes.NA: + other = dtypes.get_fill_value(data.dtype) + return where(cond, data, other) + + +def fillna(data, other): + # we need to pass data first so pint has a chance of returning the + # correct unit + # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed + return where(notnull(data), data, other) + + +def concatenate(arrays, axis=0): + """concatenate() with better dtype promotion rules.""" + # TODO: remove the additional check once `numpy` adds `concat` to its array namespace + if hasattr(arrays[0], "__array_namespace__") and not isinstance( + arrays[0], np.ndarray + ): + xp = get_array_namespace(arrays[0]) + return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) + return _concatenate(as_shared_dtype(arrays), axis=axis) + + +def stack(arrays, axis=0): + """stack() with better dtype promotion rules.""" + xp = get_array_namespace(arrays[0]) + return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis) + + +def reshape(array, shape): + xp = get_array_namespace(array) + return xp.reshape(array, shape) + + +def ravel(array): + return reshape(array, (-1,)) + + +@contextlib.contextmanager +def _ignore_warnings_if(condition): + if condition: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + else: + yield + + +def _create_nan_agg_method(name, coerce_strings=False, invariant_0d=False): + from xarray.core import nanops + + def f(values, axis=None, skipna=None, **kwargs): + if kwargs.pop("out", None) is not None: + raise TypeError(f"`out` is not valid for {name}") + + # The data is invariant in the case of 0d data, so do not + # change the data (and dtype) + # See https://github.com/pydata/xarray/issues/4885 + if invariant_0d and axis == (): + return values + + xp = get_array_namespace(values) + values = asarray(values, xp=xp) + + if coerce_strings and dtypes.is_string(values.dtype): + values = astype(values, object) + + func = None + if skipna or ( + skipna is None + and ( + dtypes.isdtype( + values.dtype, ("complex floating", "real floating"), xp=xp + ) + or dtypes.is_object(values.dtype) + ) + ): + nanname = "nan" + name + func = getattr(nanops, nanname) + else: + if name in ["sum", "prod"]: + kwargs.pop("min_count", None) + + xp = get_array_namespace(values) + func = getattr(xp, name) + + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "All-NaN slice encountered") + return func(values, axis=axis, **kwargs) + except AttributeError: + if not is_duck_dask_array(values): + raise + try: # dask/dask#3133 dask sometimes needs dtype argument + # if func does not accept dtype, then raises TypeError + return func(values, axis=axis, dtype=values.dtype, **kwargs) + except (AttributeError, TypeError): + raise NotImplementedError( + f"{name} is not yet implemented on dask arrays" + ) + + f.__name__ = name + return f + + +# Attributes `numeric_only`, `available_min_count` is used for docs. +# See ops.inject_reduce_methods +argmax = _create_nan_agg_method("argmax", coerce_strings=True) +argmin = _create_nan_agg_method("argmin", coerce_strings=True) +max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True) +min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True) +sum = _create_nan_agg_method("sum", invariant_0d=True) +sum.numeric_only = True +sum.available_min_count = True +std = _create_nan_agg_method("std") +std.numeric_only = True +var = _create_nan_agg_method("var") +var.numeric_only = True +median = _create_nan_agg_method("median", invariant_0d=True) +median.numeric_only = True +prod = _create_nan_agg_method("prod", invariant_0d=True) +prod.numeric_only = True +prod.available_min_count = True +cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True) +cumprod_1d.numeric_only = True +cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True) +cumsum_1d.numeric_only = True + + +_mean = _create_nan_agg_method("mean", invariant_0d=True) + + +def _datetime_nanmin(array): + """nanmin() function for datetime64. + + Caveats that this function deals with: + + - In numpy < 1.18, min() on datetime64 incorrectly ignores NaT + - numpy nanmin() don't work on datetime64 (all versions at the moment of writing) + - dask min() does not work on datetime64 (all versions at the moment of writing) + """ + dtype = array.dtype + assert dtypes.is_datetime_like(dtype) + # (NaT).astype(float) does not produce NaN... + array = where(pandas_isnull(array), np.nan, array.astype(float)) + array = min(array, skipna=True) + if isinstance(array, float): + array = np.array(array) + # ...but (NaN).astype("M8") does produce NaT + return array.astype(dtype) + + +def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): + """Convert an array containing datetime-like data to numerical values. + Convert the datetime array to a timedelta relative to an offset. + Parameters + ---------- + array : array-like + Input data + offset : None, datetime or cftime.datetime + Datetime offset. If None, this is set by default to the array's minimum + value to reduce round off errors. + datetime_unit : {None, Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} + If not None, convert output to a given datetime unit. Note that some + conversions are not allowed due to non-linear relationships between units. + dtype : dtype + Output dtype. + Returns + ------- + array + Numerical representation of datetime object relative to an offset. + Notes + ----- + Some datetime unit conversions won't work, for example from days to years, even + though some calendars would allow for them (e.g. no_leap). This is because there + is no `cftime.timedelta` object. + """ + # Set offset to minimum if not given + if offset is None: + if dtypes.is_datetime_like(array.dtype): + offset = _datetime_nanmin(array) + else: + offset = min(array) + + # Compute timedelta object. + # For np.datetime64, this can silently yield garbage due to overflow. + # One option is to enforce 1970-01-01 as the universal offset. + + # This map_blocks call is for backwards compatibility. + # dask == 2021.04.1 does not support subtracting object arrays + # which is required for cftime + if is_duck_dask_array(array) and dtypes.is_object(array.dtype): + array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta) + else: + array = array - offset + + # Scalar is converted to 0d-array + if not hasattr(array, "dtype"): + array = np.array(array) + + # Convert timedelta objects to float by first converting to microseconds. + if dtypes.is_object(array.dtype): + return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype) + + # Convert np.NaT to np.nan + elif dtypes.is_datetime_like(array.dtype): + # Convert to specified timedelta units. + if datetime_unit: + array = array / np.timedelta64(1, datetime_unit) + return np.where(isnull(array), np.nan, array.astype(dtype)) + + +def timedelta_to_numeric(value, datetime_unit="ns", dtype=float): + """Convert a timedelta-like object to numerical values. + + Parameters + ---------- + value : datetime.timedelta, numpy.timedelta64, pandas.Timedelta, str + Time delta representation. + datetime_unit : {Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} + The time units of the output values. Note that some conversions are not allowed due to + non-linear relationships between units. + dtype : type + The output data type. + + """ + import datetime as dt + + if isinstance(value, dt.timedelta): + out = py_timedelta_to_float(value, datetime_unit) + elif isinstance(value, np.timedelta64): + out = np_timedelta64_to_float(value, datetime_unit) + elif isinstance(value, pd.Timedelta): + out = pd_timedelta_to_float(value, datetime_unit) + elif isinstance(value, str): + try: + a = pd.to_timedelta(value) + except ValueError: + raise ValueError( + f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta" + ) + return py_timedelta_to_float(a, datetime_unit) + else: + raise TypeError( + f"Expected value of type str, pandas.Timedelta, datetime.timedelta " + f"or numpy.timedelta64, but received {type(value).__name__}" + ) + return out.astype(dtype) + + +def _to_pytimedelta(array, unit="us"): + return array.astype(f"timedelta64[{unit}]").astype(datetime.timedelta) + + +def np_timedelta64_to_float(array, datetime_unit): + """Convert numpy.timedelta64 to float. + + Notes + ----- + The array is first converted to microseconds, which is less likely to + cause overflow errors. + """ + array = array.astype("timedelta64[ns]").astype(np.float64) + conversion_factor = np.timedelta64(1, "ns") / np.timedelta64(1, datetime_unit) + return conversion_factor * array + + +def pd_timedelta_to_float(value, datetime_unit): + """Convert pandas.Timedelta to float. + + Notes + ----- + Built on the assumption that pandas timedelta values are in nanoseconds, + which is also the numpy default resolution. + """ + value = value.to_timedelta64() + return np_timedelta64_to_float(value, datetime_unit) + + +def _timedelta_to_seconds(array): + if isinstance(array, datetime.timedelta): + return array.total_seconds() * 1e6 + else: + return np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6 + + +def py_timedelta_to_float(array, datetime_unit): + """Convert a timedelta object to a float, possibly at a loss of resolution.""" + array = asarray(array) + if is_duck_dask_array(array): + array = array.map_blocks( + _timedelta_to_seconds, meta=np.array([], dtype=np.float64) + ) + else: + array = _timedelta_to_seconds(array) + conversion_factor = np.timedelta64(1, "us") / np.timedelta64(1, datetime_unit) + return conversion_factor * array + + +def mean(array, axis=None, skipna=None, **kwargs): + """inhouse mean that can handle np.datetime64 or cftime.datetime + dtypes""" + from xarray.core.common import _contains_cftime_datetimes + + array = asarray(array) + if dtypes.is_datetime_like(array.dtype): + offset = _datetime_nanmin(array) + + # xarray always uses np.datetime64[ns] for np.datetime64 data + dtype = "timedelta64[ns]" + return ( + _mean( + datetime_to_numeric(array, offset), axis=axis, skipna=skipna, **kwargs + ).astype(dtype) + + offset + ) + elif _contains_cftime_datetimes(array): + offset = min(array) + timedeltas = datetime_to_numeric(array, offset, datetime_unit="us") + mean_timedeltas = _mean(timedeltas, axis=axis, skipna=skipna, **kwargs) + return _to_pytimedelta(mean_timedeltas, unit="us") + offset + else: + return _mean(array, axis=axis, skipna=skipna, **kwargs) + + +mean.numeric_only = True # type: ignore[attr-defined] + + +def _nd_cum_func(cum_func, array, axis, **kwargs): + array = asarray(array) + if axis is None: + axis = tuple(range(array.ndim)) + if isinstance(axis, int): + axis = (axis,) + + out = array + for ax in axis: + out = cum_func(out, axis=ax, **kwargs) + return out + + +def cumprod(array, axis=None, **kwargs): + """N-dimensional version of cumprod.""" + return _nd_cum_func(cumprod_1d, array, axis, **kwargs) + + +def cumsum(array, axis=None, **kwargs): + """N-dimensional version of cumsum.""" + return _nd_cum_func(cumsum_1d, array, axis, **kwargs) + + +def first(values, axis, skipna=None): + """Return the first non-NA elements in this array along the given axis""" + if (skipna or skipna is None) and not ( + dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype) + ): + # only bother for dtypes that can hold NaN + if is_chunked_array(values): + return chunked_nanfirst(values, axis) + else: + return nputils.nanfirst(values, axis) + return take(values, 0, axis=axis) + + +def last(values, axis, skipna=None): + """Return the last non-NA elements in this array along the given axis""" + if (skipna or skipna is None) and not ( + dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype) + ): + # only bother for dtypes that can hold NaN + if is_chunked_array(values): + return chunked_nanlast(values, axis) + else: + return nputils.nanlast(values, axis) + return take(values, -1, axis=axis) + + +def least_squares(lhs, rhs, rcond=None, skipna=False): + """Return the coefficients and residuals of a least-squares fit.""" + if is_duck_dask_array(rhs): + return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) + else: + return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) + + +def _push(array, n: int | None = None, axis: int = -1): + """ + Use either bottleneck or numbagg depending on options & what's available + """ + + if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: + raise RuntimeError( + "ffill & bfill requires bottleneck or numbagg to be enabled." + " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." + ) + if OPTIONS["use_numbagg"] and module_available("numbagg"): + import numbagg + + if pycompat.mod_version("numbagg") < Version("0.6.2"): + warnings.warn( + f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead." + ) + else: + return numbagg.ffill(array, limit=n, axis=axis) + + # work around for bottleneck 178 + limit = n if n is not None else array.shape[axis] + + import bottleneck as bn + + return bn.push(array, limit, axis) + + +def push(array, n, axis): + if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: + raise RuntimeError( + "ffill & bfill requires bottleneck or numbagg to be enabled." + " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." + ) + if is_duck_dask_array(array): + return dask_array_ops.push(array, n, axis) + else: + return _push(array, n, axis) + + +def _first_last_wrapper(array, *, axis, op, keepdims): + return op(array, axis, keepdims=keepdims) + + +def _chunked_first_or_last(darray, axis, op): + chunkmanager = get_chunked_array_type(darray) + + # This will raise the same error message seen for numpy + axis = normalize_axis_index(axis, darray.ndim) + + wrapped_op = partial(_first_last_wrapper, op=op) + return chunkmanager.reduction( + darray, + func=wrapped_op, + aggregate_func=wrapped_op, + axis=axis, + dtype=darray.dtype, + keepdims=False, # match numpy version + ) + + +def chunked_nanfirst(darray, axis): + return _chunked_first_or_last(darray, axis, op=nputils.nanfirst) + + +def chunked_nanlast(darray, axis): + return _chunked_first_or_last(darray, axis, op=nputils.nanlast) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/extension_array.py b/test/fixtures/whole_applications/xarray/xarray/core/extension_array.py new file mode 100644 index 0000000..c8b4fa8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/extension_array.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Callable, Generic + +import numpy as np +import pandas as pd +from pandas.api.types import is_extension_array_dtype + +from xarray.core.types import DTypeLikeSave, T_ExtensionArray + +HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for MyArray objects.""" + + def decorator(func): + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.issubdtype) +def __extension_duck_array__issubdtype( + extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave +) -> bool: + return False # never want a function to think a pandas extension dtype is a subtype of numpy + + +@implements(np.broadcast_to) +def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): + if shape[0] == len(arr) and len(shape) == 1: + return arr + raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + + +@implements(np.stack) +def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): + raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + + +@implements(np.concatenate) +def __extension_duck_array__concatenate( + arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None +) -> T_ExtensionArray: + return type(arrays[0])._concat_same_type(arrays) + + +@implements(np.where) +def __extension_duck_array__where( + condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray +) -> T_ExtensionArray: + if ( + isinstance(x, pd.Categorical) + and isinstance(y, pd.Categorical) + and x.dtype != y.dtype + ): + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) + return pd.Series(x).where(condition, pd.Series(y)).array + + +class PandasExtensionArray(Generic[T_ExtensionArray]): + array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray): + """NEP-18 compliant wrapper for pandas extension arrays. + + Parameters + ---------- + array : T_ExtensionArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + ``` + """ + if not isinstance(array, pd.api.extensions.ExtensionArray): + raise TypeError(f"{array} is not an pandas ExtensionArray.") + self.array = array + + def __array_function__(self, func, types, args, kwargs): + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, PandasExtensionArray): + args_as_list[index] = value.array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable? + args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) + if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: + return func(*args, **kwargs) + res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return type(self)[type(res)](res) + return res + + def __array_ufunc__(ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"{type(self)}(array={repr(self.array)})" + + def __getattr__(self, attr: str) -> object: + return getattr(self.array, attr) + + def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: + item = self.array[key] + if is_extension_array_dtype(item): + return type(self)(item) + if np.isscalar(item): + return type(self)(type(self.array)([item])) + return item + + def __setitem__(self, key, val): + self.array[key] = val + + def __eq__(self, other): + if isinstance(other, PandasExtensionArray): + return self.array == other.array + return self.array == other + + def __ne__(self, other): + return ~(self == other) + + def __len__(self): + return len(self.array) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/extensions.py b/test/fixtures/whole_applications/xarray/xarray/core/extensions.py new file mode 100644 index 0000000..9ebbd56 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/extensions.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import warnings + +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree + + +class AccessorRegistrationWarning(Warning): + """Warning for conflicts in accessor registration.""" + + +class _CachedAccessor: + """Custom property-like object (descriptor) for caching accessors.""" + + def __init__(self, name, accessor): + self._name = name + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + # we're accessing the attribute of the class, i.e., Dataset.geo + return self._accessor + + # Use the same dict as @pandas.util.cache_readonly. + # It must be explicitly declared in obj.__slots__. + try: + cache = obj._cache + except AttributeError: + cache = obj._cache = {} + + try: + return cache[self._name] + except KeyError: + pass + + try: + accessor_obj = self._accessor(obj) + except AttributeError: + # __getattr__ on data object will swallow any AttributeErrors + # raised when initializing the accessor, so we need to raise as + # something else (GH933): + raise RuntimeError(f"error initializing {self._name!r} accessor.") + + cache[self._name] = accessor_obj + return accessor_obj + + +def _register_accessor(name, cls): + def decorator(accessor): + if hasattr(cls, name): + warnings.warn( + f"registration of accessor {accessor!r} under name {name!r} for type {cls!r} is " + "overriding a preexisting attribute with the same name.", + AccessorRegistrationWarning, + stacklevel=2, + ) + setattr(cls, name, _CachedAccessor(name, accessor)) + return accessor + + return decorator + + +def register_dataarray_accessor(name): + """Register a custom accessor on xarray.DataArray objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + See Also + -------- + register_dataset_accessor + """ + return _register_accessor(name, DataArray) + + +def register_dataset_accessor(name): + """Register a custom property on xarray.Dataset objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + Examples + -------- + In your library code: + + >>> @xr.register_dataset_accessor("geo") + ... class GeoAccessor: + ... def __init__(self, xarray_obj): + ... self._obj = xarray_obj + ... + ... @property + ... def center(self): + ... # return the geographic center point of this dataset + ... lon = self._obj.latitude + ... lat = self._obj.longitude + ... return (float(lon.mean()), float(lat.mean())) + ... + ... def plot(self): + ... # plot this array's data on a map, e.g., using Cartopy + ... pass + ... + + Back in an interactive IPython session: + + >>> ds = xr.Dataset( + ... {"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)} + ... ) + >>> ds.geo.center + (10.0, 5.0) + >>> ds.geo.plot() # plots data on a map + + See Also + -------- + register_dataarray_accessor + """ + return _register_accessor(name, Dataset) + + +def register_datatree_accessor(name): + """Register a custom accessor on DataTree objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + See Also + -------- + xarray.register_dataarray_accessor + xarray.register_dataset_accessor + """ + return _register_accessor(name, DataTree) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/formatting.py b/test/fixtures/whole_applications/xarray/xarray/core/formatting.py new file mode 100644 index 0000000..ad65a44 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/formatting.py @@ -0,0 +1,1116 @@ +"""String formatting routines for __repr__. +""" + +from __future__ import annotations + +import contextlib +import functools +import math +from collections import defaultdict +from collections.abc import Collection, Hashable, Sequence +from datetime import datetime, timedelta +from itertools import chain, zip_longest +from reprlib import recursive_repr +from textwrap import dedent +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +from pandas.errors import OutOfBoundsDatetime + +from xarray.core.datatree_render import RenderDataTree +from xarray.core.duck_array_ops import array_equiv, astype +from xarray.core.indexing import MemoryCachedArray +from xarray.core.iterators import LevelOrderIter +from xarray.core.options import OPTIONS, _get_boolean_with_default +from xarray.core.utils import is_duck_array +from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy + +if TYPE_CHECKING: + from xarray.core.coordinates import AbstractCoordinates + from xarray.core.datatree import DataTree + +UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") + + +def pretty_print(x, numchars: int): + """Given an object `x`, call `str(x)` and format the returned string so + that it is numchars long, padding with trailing spaces or truncating with + ellipses as necessary + """ + s = maybe_truncate(x, numchars) + return s + " " * max(numchars - len(s), 0) + + +def maybe_truncate(obj, maxlen=500): + s = str(obj) + if len(s) > maxlen: + s = s[: (maxlen - 3)] + "..." + return s + + +def wrap_indent(text, start="", length=None): + if length is None: + length = len(start) + indent = "\n" + " " * length + return start + indent.join(x for x in text.splitlines()) + + +def _get_indexer_at_least_n_items(shape, n_desired, from_end): + assert 0 < n_desired <= math.prod(shape) + cum_items = np.cumprod(shape[::-1]) + n_steps = np.argmax(cum_items >= n_desired) + stop = math.ceil(float(n_desired) / np.r_[1, cum_items][n_steps]) + indexer = ( + ((-1 if from_end else 0),) * (len(shape) - 1 - n_steps) + + ((slice(-stop, None) if from_end else slice(stop)),) + + (slice(None),) * n_steps + ) + return indexer + + +def first_n_items(array, n_desired): + """Returns the first n_desired items of an array""" + # Unfortunately, we can't just do array.flat[:n_desired] here because it + # might not be a numpy.ndarray. Moreover, access to elements of the array + # could be very expensive (e.g. if it's only available over DAP), so go out + # of our way to get them in a single call to __getitem__ using only slices. + from xarray.core.variable import Variable + + if n_desired < 1: + raise ValueError("must request at least one item") + + if array.size == 0: + # work around for https://github.com/numpy/numpy/issues/5195 + return [] + + if n_desired < array.size: + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=False) + array = array[indexer] + + # We pass variable objects in to handle indexing + # with indexer above. It would not work with our + # lazy indexing classes at the moment, so we cannot + # pass Variable._data + if isinstance(array, Variable): + array = array._data + return np.ravel(to_duck_array(array))[:n_desired] + + +def last_n_items(array, n_desired): + """Returns the last n_desired items of an array""" + # Unfortunately, we can't just do array.flat[-n_desired:] here because it + # might not be a numpy.ndarray. Moreover, access to elements of the array + # could be very expensive (e.g. if it's only available over DAP), so go out + # of our way to get them in a single call to __getitem__ using only slices. + from xarray.core.variable import Variable + + if (n_desired == 0) or (array.size == 0): + return [] + + if n_desired < array.size: + indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=True) + array = array[indexer] + + # We pass variable objects in to handle indexing + # with indexer above. It would not work with our + # lazy indexing classes at the moment, so we cannot + # pass Variable._data + if isinstance(array, Variable): + array = array._data + return np.ravel(to_duck_array(array))[-n_desired:] + + +def last_item(array): + """Returns the last item of an array in a list or an empty list.""" + if array.size == 0: + # work around for https://github.com/numpy/numpy/issues/5195 + return [] + + indexer = (slice(-1, None),) * array.ndim + # to_numpy since dask doesn't support tolist + return np.ravel(to_numpy(array[indexer])).tolist() + + +def calc_max_rows_first(max_rows: int) -> int: + """Calculate the first rows to maintain the max number of rows.""" + return max_rows // 2 + max_rows % 2 + + +def calc_max_rows_last(max_rows: int) -> int: + """Calculate the last rows to maintain the max number of rows.""" + return max_rows // 2 + + +def format_timestamp(t): + """Cast given object to a Timestamp and return a nicely formatted string""" + try: + timestamp = pd.Timestamp(t) + datetime_str = timestamp.isoformat(sep=" ") + except OutOfBoundsDatetime: + datetime_str = str(t) + + try: + date_str, time_str = datetime_str.split() + except ValueError: + # catch NaT and others that don't split nicely + return datetime_str + else: + if time_str == "00:00:00": + return date_str + else: + return f"{date_str}T{time_str}" + + +def format_timedelta(t, timedelta_format=None): + """Cast given object to a Timestamp and return a nicely formatted string""" + timedelta_str = str(pd.Timedelta(t)) + try: + days_str, time_str = timedelta_str.split(" days ") + except ValueError: + # catch NaT and others that don't split nicely + return timedelta_str + else: + if timedelta_format == "date": + return days_str + " days" + elif timedelta_format == "time": + return time_str + else: + return timedelta_str + + +def format_item(x, timedelta_format=None, quote_strings=True): + """Returns a succinct summary of an object as a string""" + if isinstance(x, (np.datetime64, datetime)): + return format_timestamp(x) + if isinstance(x, (np.timedelta64, timedelta)): + return format_timedelta(x, timedelta_format=timedelta_format) + elif isinstance(x, (str, bytes)): + if hasattr(x, "dtype"): + x = x.item() + return repr(x) if quote_strings else x + elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating): + return f"{x.item():.4}" + else: + return str(x) + + +def format_items(x): + """Returns a succinct summaries of all items in a sequence as strings""" + x = to_duck_array(x) + timedelta_format = "datetime" + if np.issubdtype(x.dtype, np.timedelta64): + x = astype(x, dtype="timedelta64[ns]") + day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") + time_needed = x[~pd.isnull(x)] != day_part + day_needed = day_part != np.timedelta64(0, "ns") + if np.logical_not(day_needed).all(): + timedelta_format = "time" + elif np.logical_not(time_needed).all(): + timedelta_format = "date" + + formatted = [format_item(xi, timedelta_format) for xi in x] + return formatted + + +def format_array_flat(array, max_width: int): + """Return a formatted string for as many items in the flattened version of + array that will fit within max_width characters. + """ + # every item will take up at least two characters, but we always want to + # print at least first and last items + max_possibly_relevant = min(max(array.size, 1), max(math.ceil(max_width / 2.0), 2)) + relevant_front_items = format_items( + first_n_items(array, (max_possibly_relevant + 1) // 2) + ) + relevant_back_items = format_items(last_n_items(array, max_possibly_relevant // 2)) + # interleave relevant front and back items: + # [a, b, c] and [y, z] -> [a, z, b, y, c] + relevant_items = sum( + zip_longest(relevant_front_items, reversed(relevant_back_items)), () + )[:max_possibly_relevant] + + cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1 + if (array.size > 2) and ( + (max_possibly_relevant < array.size) or (cum_len > max_width).any() + ): + padding = " ... " + max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2) + count = min(array.size, max_len) + else: + count = array.size + padding = "" if (count <= 1) else " " + + num_front = (count + 1) // 2 + num_back = count - num_front + # note that num_back is 0 <--> array.size is 0 or 1 + # <--> relevant_back_items is [] + pprint_str = "".join( + [ + " ".join(relevant_front_items[:num_front]), + padding, + " ".join(relevant_back_items[-num_back:]), + ] + ) + + # As a final check, if it's still too long even with the limit in values, + # replace the end with an ellipsis + # NB: this will still returns a full 3-character ellipsis when max_width < 3 + if len(pprint_str) > max_width: + pprint_str = pprint_str[: max(max_width - 3, 0)] + "..." + + return pprint_str + + +# mapping of tuple[modulename, classname] to repr +_KNOWN_TYPE_REPRS = { + ("numpy", "ndarray"): "np.ndarray", + ("sparse._coo.core", "COO"): "sparse.COO", +} + + +def inline_dask_repr(array): + """Similar to dask.array.DataArray.__repr__, but without + redundant information that's already printed by the repr + function of the xarray wrapper. + """ + assert isinstance(array, array_type("dask")), array + + chunksize = tuple(c[0] for c in array.chunks) + + if hasattr(array, "_meta"): + meta = array._meta + identifier = (type(meta).__module__, type(meta).__name__) + meta_repr = _KNOWN_TYPE_REPRS.get(identifier, ".".join(identifier)) + meta_string = f", meta={meta_repr}" + else: + meta_string = "" + + return f"dask.array" + + +def inline_sparse_repr(array): + """Similar to sparse.COO.__repr__, but without the redundant shape/dtype.""" + sparse_array_type = array_type("sparse") + assert isinstance(array, sparse_array_type), array + return ( + f"<{type(array).__name__}: nnz={array.nnz:d}, fill_value={array.fill_value!s}>" + ) + + +def inline_variable_array_repr(var, max_width): + """Build a one-line summary of a variable's data.""" + if hasattr(var._data, "_repr_inline_"): + return var._data._repr_inline_(max_width) + if var._in_memory: + return format_array_flat(var, max_width) + dask_array_type = array_type("dask") + if isinstance(var._data, dask_array_type): + return inline_dask_repr(var.data) + sparse_array_type = array_type("sparse") + if isinstance(var._data, sparse_array_type): + return inline_sparse_repr(var.data) + if hasattr(var._data, "__array_function__"): + return maybe_truncate(repr(var._data).replace("\n", " "), max_width) + # internal xarray array type + return "..." + + +def summarize_variable( + name: Hashable, + var, + col_width: int, + max_width: int | None = None, + is_index: bool = False, +): + """Summarize a variable in one line, e.g., for the Dataset.__repr__.""" + variable = getattr(var, "variable", var) + + if max_width is None: + max_width_options = OPTIONS["display_width"] + if not isinstance(max_width_options, int): + raise TypeError(f"`max_width` value of `{max_width}` is not a valid int") + else: + max_width = max_width_options + + marker = "*" if is_index else " " + first_col = pretty_print(f" {marker} {name} ", col_width) + + if variable.dims: + dims_str = "({}) ".format(", ".join(map(str, variable.dims))) + else: + dims_str = "" + + nbytes_str = f" {render_human_readable_nbytes(variable.nbytes)}" + front_str = f"{first_col}{dims_str}{variable.dtype}{nbytes_str} " + + values_width = max_width - len(front_str) + values_str = inline_variable_array_repr(variable, values_width) + + return front_str + values_str + + +def summarize_attr(key, value, col_width=None): + """Summary for __repr__ - use ``X.attrs[key]`` for full value.""" + # Indent key and add ':', then right-pad if col_width is not None + k_str = f" {key}:" + if col_width is not None: + k_str = pretty_print(k_str, col_width) + # Replace tabs and newlines, so we print on one line in known width + v_str = str(value).replace("\t", "\\t").replace("\n", "\\n") + # Finally, truncate to the desired display width + return maybe_truncate(f"{k_str} {v_str}", OPTIONS["display_width"]) + + +EMPTY_REPR = " *empty*" + + +def _calculate_col_width(col_items): + max_name_length = max((len(str(s)) for s in col_items), default=0) + col_width = max(max_name_length, 7) + 6 + return col_width + + +def _mapping_repr( + mapping, + title, + summarizer, + expand_option_name, + col_width=None, + max_rows=None, + indexes=None, +): + if col_width is None: + col_width = _calculate_col_width(mapping) + + summarizer_kwargs = defaultdict(dict) + if indexes is not None: + summarizer_kwargs = {k: {"is_index": k in indexes} for k in mapping} + + summary = [f"{title}:"] + if mapping: + len_mapping = len(mapping) + if not _get_boolean_with_default(expand_option_name, default=True): + summary = [f"{summary[0]} ({len_mapping})"] + elif max_rows is not None and len_mapping > max_rows: + summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] + first_rows = calc_max_rows_first(max_rows) + keys = list(mapping.keys()) + summary += [ + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) + for k in keys[:first_rows] + ] + if max_rows > 1: + last_rows = calc_max_rows_last(max_rows) + summary += [pretty_print(" ...", col_width) + " ..."] + summary += [ + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) + for k in keys[-last_rows:] + ] + else: + summary += [ + summarizer(k, v, col_width, **summarizer_kwargs[k]) + for k, v in mapping.items() + ] + else: + summary += [EMPTY_REPR] + return "\n".join(summary) + + +data_vars_repr = functools.partial( + _mapping_repr, + title="Data variables", + summarizer=summarize_variable, + expand_option_name="display_expand_data_vars", +) + +attrs_repr = functools.partial( + _mapping_repr, + title="Attributes", + summarizer=summarize_attr, + expand_option_name="display_expand_attrs", +) + + +def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): + if col_width is None: + col_width = _calculate_col_width(coords) + return _mapping_repr( + coords, + title="Coordinates", + summarizer=summarize_variable, + expand_option_name="display_expand_coords", + col_width=col_width, + indexes=coords.xindexes, + max_rows=max_rows, + ) + + +def inline_index_repr(index: pd.Index, max_width=None): + if hasattr(index, "_repr_inline_"): + repr_ = index._repr_inline_(max_width=max_width) + else: + # fallback for the `pandas.Index` subclasses from + # `Indexes.get_pandas_indexes` / `xr_obj.indexes` + repr_ = repr(index) + + return repr_ + + +def summarize_index( + names: tuple[Hashable, ...], + index, + col_width: int, + max_width: int | None = None, +) -> str: + if max_width is None: + max_width = OPTIONS["display_width"] + + def prefixes(length: int) -> list[str]: + if length in (0, 1): + return [" "] + + return ["┌"] + ["│"] * max(length - 2, 0) + ["└"] + + preformatted = [ + pretty_print(f" {prefix} {name}", col_width) + for prefix, name in zip(prefixes(len(names)), names) + ] + + head, *tail = preformatted + index_width = max_width - len(head) + repr_ = inline_index_repr(index, max_width=index_width) + return "\n".join([head + repr_] + [line.rstrip() for line in tail]) + + +def filter_nondefault_indexes(indexes, filter_indexes: bool): + from xarray.core.indexes import PandasIndex, PandasMultiIndex + + if not filter_indexes: + return indexes + + default_indexes = (PandasIndex, PandasMultiIndex) + + return { + key: index + for key, index in indexes.items() + if not isinstance(index, default_indexes) + } + + +def indexes_repr(indexes, max_rows: int | None = None) -> str: + col_width = _calculate_col_width(chain.from_iterable(indexes)) + + return _mapping_repr( + indexes, + "Indexes", + summarize_index, + "display_expand_indexes", + col_width=col_width, + max_rows=max_rows, + ) + + +def dim_summary(obj): + elements = [f"{k}: {v}" for k, v in obj.sizes.items()] + return ", ".join(elements) + + +def _element_formatter( + elements: Collection[Hashable], + col_width: int, + max_rows: int | None = None, + delimiter: str = ", ", +) -> str: + """ + Formats elements for better readability. + + Once it becomes wider than the display width it will create a newline and + continue indented to col_width. + Once there are more rows than the maximum displayed rows it will start + removing rows. + + Parameters + ---------- + elements : Collection of hashable + Elements to join together. + col_width : int + The width to indent to if a newline has been made. + max_rows : int, optional + The maximum number of allowed rows. The default is None. + delimiter : str, optional + Delimiter to use between each element. The default is ", ". + """ + elements_len = len(elements) + out = [""] + length_row = 0 + for i, v in enumerate(elements): + delim = delimiter if i < elements_len - 1 else "" + v_delim = f"{v}{delim}" + length_element = len(v_delim) + length_row += length_element + + # Create a new row if the next elements makes the print wider than + # the maximum display width: + if col_width + length_row > OPTIONS["display_width"]: + out[-1] = out[-1].rstrip() # Remove trailing whitespace. + out.append("\n" + pretty_print("", col_width) + v_delim) + length_row = length_element + else: + out[-1] += v_delim + + # If there are too many rows of dimensions trim some away: + if max_rows and (len(out) > max_rows): + first_rows = calc_max_rows_first(max_rows) + last_rows = calc_max_rows_last(max_rows) + out = ( + out[:first_rows] + + ["\n" + pretty_print("", col_width) + "..."] + + (out[-last_rows:] if max_rows > 1 else []) + ) + return "".join(out) + + +def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str: + elements = [f"{k}: {v}" for k, v in obj.sizes.items()] + return _element_formatter(elements, col_width, max_rows) + + +def unindexed_dims_repr(dims, coords, max_rows: int | None = None): + unindexed_dims = [d for d in dims if d not in coords] + if unindexed_dims: + dims_start = "Dimensions without coordinates: " + dims_str = _element_formatter( + unindexed_dims, col_width=len(dims_start), max_rows=max_rows + ) + return dims_start + dims_str + else: + return None + + +@contextlib.contextmanager +def set_numpy_options(*args, **kwargs): + original = np.get_printoptions() + np.set_printoptions(*args, **kwargs) + try: + yield + finally: + np.set_printoptions(**original) + + +def limit_lines(string: str, *, limit: int): + """ + If the string is more lines than the limit, + this returns the middle lines replaced by an ellipsis + """ + lines = string.splitlines() + if len(lines) > limit: + string = "\n".join(chain(lines[: limit // 2], ["..."], lines[-limit // 2 :])) + return string + + +def short_array_repr(array): + from xarray.core.common import AbstractArray + + if isinstance(array, AbstractArray): + array = array.data + array = to_duck_array(array) + + # default to lower precision so a full (abbreviated) line can fit on + # one line with the default display_width + options = { + "precision": 6, + "linewidth": OPTIONS["display_width"], + "threshold": OPTIONS["display_values_threshold"], + } + if array.ndim < 3: + edgeitems = 3 + elif array.ndim == 3: + edgeitems = 2 + else: + edgeitems = 1 + options["edgeitems"] = edgeitems + with set_numpy_options(**options): + return repr(array) + + +def short_data_repr(array): + """Format "data" for DataArray and Variable.""" + internal_data = getattr(array, "variable", array)._data + if isinstance(array, np.ndarray): + return short_array_repr(array) + elif is_duck_array(internal_data): + return limit_lines(repr(array.data), limit=40) + elif getattr(array, "_in_memory", None): + return short_array_repr(array) + else: + # internal xarray array type + return f"[{array.size} values with dtype={array.dtype}]" + + +def _get_indexes_dict(indexes): + return { + tuple(index_vars.keys()): idx for idx, index_vars in indexes.group_by_index() + } + + +@recursive_repr("") +def array_repr(arr): + from xarray.core.variable import Variable + + max_rows = OPTIONS["display_max_rows"] + + # used for DataArray, Variable and IndexVariable + if hasattr(arr, "name") and arr.name is not None: + name_str = f"{arr.name!r} " + else: + name_str = "" + + if ( + isinstance(arr, Variable) + or _get_boolean_with_default("display_expand_data", default=True) + or isinstance(arr.variable._data, MemoryCachedArray) + ): + data_repr = short_data_repr(arr) + else: + data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"]) + + start = f" Size: {nbytes_str}", + data_repr, + ] + if hasattr(arr, "coords"): + if arr.coords: + col_width = _calculate_col_width(arr.coords) + summary.append( + coords_repr(arr.coords, col_width=col_width, max_rows=max_rows) + ) + + unindexed_dims_str = unindexed_dims_repr( + arr.dims, arr.coords, max_rows=max_rows + ) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + display_default_indexes = _get_boolean_with_default( + "display_default_indexes", False + ) + + xindexes = filter_nondefault_indexes( + _get_indexes_dict(arr.xindexes), not display_default_indexes + ) + + if xindexes: + summary.append(indexes_repr(xindexes, max_rows=max_rows)) + + if arr.attrs: + summary.append(attrs_repr(arr.attrs, max_rows=max_rows)) + + return "\n".join(summary) + + +@recursive_repr("") +def dataset_repr(ds): + nbytes_str = render_human_readable_nbytes(ds.nbytes) + summary = [f" Size: {nbytes_str}"] + + col_width = _calculate_col_width(ds.variables) + max_rows = OPTIONS["display_max_rows"] + + dims_start = pretty_print("Dimensions:", col_width) + dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + summary.append(f"{dims_start}({dims_values})") + + if ds.coords: + summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows)) + + unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + summary.append(data_vars_repr(ds.data_vars, col_width=col_width, max_rows=max_rows)) + + display_default_indexes = _get_boolean_with_default( + "display_default_indexes", False + ) + xindexes = filter_nondefault_indexes( + _get_indexes_dict(ds.xindexes), not display_default_indexes + ) + if xindexes: + summary.append(indexes_repr(xindexes, max_rows=max_rows)) + + if ds.attrs: + summary.append(attrs_repr(ds.attrs, max_rows=max_rows)) + + return "\n".join(summary) + + +def diff_dim_summary(a, b): + if a.sizes != b.sizes: + return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})" + else: + return "" + + +def _diff_mapping_repr( + a_mapping, + b_mapping, + compat, + title, + summarizer, + col_width=None, + a_indexes=None, + b_indexes=None, +): + def extra_items_repr(extra_keys, mapping, ab_side, kwargs): + extra_repr = [ + summarizer(k, mapping[k], col_width, **kwargs[k]) for k in extra_keys + ] + if extra_repr: + header = f"{title} only on the {ab_side} object:" + return [header] + extra_repr + else: + return [] + + a_keys = set(a_mapping) + b_keys = set(b_mapping) + + summary = [] + + diff_items = [] + + a_summarizer_kwargs = defaultdict(dict) + if a_indexes is not None: + a_summarizer_kwargs = {k: {"is_index": k in a_indexes} for k in a_mapping} + b_summarizer_kwargs = defaultdict(dict) + if b_indexes is not None: + b_summarizer_kwargs = {k: {"is_index": k in b_indexes} for k in b_mapping} + + for k in a_keys & b_keys: + try: + # compare xarray variable + if not callable(compat): + compatible = getattr(a_mapping[k].variable, compat)( + b_mapping[k].variable + ) + else: + compatible = compat(a_mapping[k].variable, b_mapping[k].variable) + is_variable = True + except AttributeError: + # compare attribute value + if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]): + compatible = array_equiv(a_mapping[k], b_mapping[k]) + else: + compatible = a_mapping[k] == b_mapping[k] + + is_variable = False + + if not compatible: + temp = [ + summarizer(k, a_mapping[k], col_width, **a_summarizer_kwargs[k]), + summarizer(k, b_mapping[k], col_width, **b_summarizer_kwargs[k]), + ] + + if compat == "identical" and is_variable: + attrs_summary = [] + a_attrs = a_mapping[k].attrs + b_attrs = b_mapping[k].attrs + + attrs_to_print = set(a_attrs) ^ set(b_attrs) + attrs_to_print.update( + {k for k in set(a_attrs) & set(b_attrs) if a_attrs[k] != b_attrs[k]} + ) + for m in (a_mapping, b_mapping): + attr_s = "\n".join( + " " + summarize_attr(ak, av) + for ak, av in m[k].attrs.items() + if ak in attrs_to_print + ) + if attr_s: + attr_s = " Differing variable attributes:\n" + attr_s + attrs_summary.append(attr_s) + + temp = [ + "\n".join([var_s, attr_s]) if attr_s else var_s + for var_s, attr_s in zip(temp, attrs_summary) + ] + + # TODO: It should be possible recursively use _diff_mapping_repr + # instead of explicitly handling variable attrs specially. + # That would require some refactoring. + # newdiff = _diff_mapping_repr( + # {k: v for k,v in a_attrs.items() if k in attrs_to_print}, + # {k: v for k,v in b_attrs.items() if k in attrs_to_print}, + # compat=compat, + # summarizer=summarize_attr, + # title="Variable Attributes" + # ) + # temp += [newdiff] + + diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)] + + if diff_items: + summary += [f"Differing {title.lower()}:"] + diff_items + + summary += extra_items_repr(a_keys - b_keys, a_mapping, "left", a_summarizer_kwargs) + summary += extra_items_repr( + b_keys - a_keys, b_mapping, "right", b_summarizer_kwargs + ) + + return "\n".join(summary) + + +def diff_coords_repr(a, b, compat, col_width=None): + return _diff_mapping_repr( + a, + b, + compat, + "Coordinates", + summarize_variable, + col_width=col_width, + a_indexes=a.indexes, + b_indexes=b.indexes, + ) + + +diff_data_vars_repr = functools.partial( + _diff_mapping_repr, title="Data variables", summarizer=summarize_variable +) + + +diff_attrs_repr = functools.partial( + _diff_mapping_repr, title="Attributes", summarizer=summarize_attr +) + + +def _compat_to_str(compat): + if callable(compat): + compat = compat.__name__ + + if compat == "equals": + return "equal" + elif compat == "allclose": + return "close" + else: + return compat + + +def diff_array_repr(a, b, compat): + # used for DataArray, Variable and IndexVariable + summary = [ + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" + ] + + summary.append(diff_dim_summary(a, b)) + if callable(compat): + equiv = compat + else: + equiv = array_equiv + + if not equiv(a.data, b.data): + temp = [wrap_indent(short_array_repr(obj), start=" ") for obj in (a, b)] + diff_data_repr = [ + ab_side + "\n" + ab_data_repr + for ab_side, ab_data_repr in zip(("L", "R"), temp) + ] + summary += ["Differing values:"] + diff_data_repr + + if hasattr(a, "coords"): + col_width = _calculate_col_width(set(a.coords) | set(b.coords)) + summary.append( + diff_coords_repr(a.coords, b.coords, compat, col_width=col_width) + ) + + if compat == "identical": + summary.append(diff_attrs_repr(a.attrs, b.attrs, compat)) + + return "\n".join(summary) + + +def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: + """ + Return a summary of why two trees are not isomorphic. + If they are isomorphic return an empty string. + """ + + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. + # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree + # (which it is so long as children are stored in a tuple or list rather than in a set). + for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): + path_a, path_b = node_a.path, node_b.path + + if require_names_equal and node_a.name != node_b.name: + diff = dedent( + f"""\ + Node '{path_a}' in the left object has name '{node_a.name}' + Node '{path_b}' in the right object has name '{node_b.name}'""" + ) + return diff + + if len(node_a.children) != len(node_b.children): + diff = dedent( + f"""\ + Number of children on node '{path_a}' of the left object: {len(node_a.children)} + Number of children on node '{path_b}' of the right object: {len(node_b.children)}""" + ) + return diff + + return "" + + +def diff_dataset_repr(a, b, compat): + summary = [ + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" + ] + + col_width = _calculate_col_width(set(list(a.variables) + list(b.variables))) + + summary.append(diff_dim_summary(a, b)) + summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)) + summary.append( + diff_data_vars_repr(a.data_vars, b.data_vars, compat, col_width=col_width) + ) + + if compat == "identical": + summary.append(diff_attrs_repr(a.attrs, b.attrs, compat)) + + return "\n".join(summary) + + +def diff_nodewise_summary(a: DataTree, b: DataTree, compat): + """Iterates over all corresponding nodes, recording differences between data at each location.""" + + compat_str = _compat_to_str(compat) + + summary = [] + for node_a, node_b in zip(a.subtree, b.subtree): + a_ds, b_ds = node_a.ds, node_b.ds + + if not a_ds._all_compat(b_ds, compat): + dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str) + data_diff = "\n".join(dataset_diff.split("\n", 1)[1:]) + + nodediff = ( + f"\nData in nodes at position '{node_a.path}' do not match:" + f"{data_diff}" + ) + summary.append(nodediff) + + return "\n".join(summary) + + +def diff_datatree_repr(a: DataTree, b: DataTree, compat): + summary = [ + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" + ] + + strict_names = True if compat in ["equals", "identical"] else False + treestructure_diff = diff_treestructure(a, b, strict_names) + + # If the trees structures are different there is no point comparing each node + # TODO we could show any differences in nodes up to the first place that structure differs? + if treestructure_diff or compat == "isomorphic": + summary.append("\n" + treestructure_diff) + else: + nodewise_diff = diff_nodewise_summary(a, b, compat) + summary.append("\n" + nodewise_diff) + + return "\n".join(summary) + + +def _single_node_repr(node: DataTree) -> str: + """Information about this node, not including its relationships to other nodes.""" + node_info = f"DataTree('{node.name}')" + + if node.has_data or node.has_attrs: + ds_info = "\n" + repr(node.ds) + else: + ds_info = "" + return node_info + ds_info + + +def datatree_repr(dt: DataTree): + """A printable representation of the structure of this entire tree.""" + renderer = RenderDataTree(dt) + + lines = [] + for pre, fill, node in renderer: + node_repr = _single_node_repr(node) + + node_line = f"{pre}{node_repr.splitlines()[0]}" + lines.append(node_line) + + if node.has_data or node.has_attrs: + ds_repr = node_repr.splitlines()[2:] + for line in ds_repr: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + + # Tack on info about whether or not root node has a parent at the start + first_line = lines[0] + parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" + first_line_with_parent = first_line[:-1] + f", parent={parent})" + lines[0] = first_line_with_parent + + return "\n".join(lines) + + +def shorten_list_repr(items: Sequence, max_items: int) -> str: + if len(items) <= max_items: + return repr(items) + else: + first_half = repr(items[: max_items // 2])[ + 1:-1 + ] # Convert to string and remove brackets + second_half = repr(items[-max_items // 2 :])[ + 1:-1 + ] # Convert to string and remove brackets + return f"[{first_half}, ..., {second_half}]" + + +def render_human_readable_nbytes( + nbytes: int, + /, + *, + attempt_constant_width: bool = False, +) -> str: + """Renders simple human-readable byte count representation + + This is only a quick representation that should not be relied upon for precise needs. + + To get the exact byte count, please use the ``nbytes`` attribute directly. + + Parameters + ---------- + nbytes + Byte count + attempt_constant_width + For reasonable nbytes sizes, tries to render a fixed-width representation. + + Returns + ------- + Human-readable representation of the byte count + """ + dividend = float(nbytes) + divisor = 1000.0 + last_unit_available = UNITS[-1] + + for unit in UNITS: + if dividend < divisor or unit == last_unit_available: + break + dividend /= divisor + + dividend_str = f"{dividend:.0f}" + unit_str = f"{unit}" + + if attempt_constant_width: + dividend_str = dividend_str.rjust(3) + unit_str = unit_str.ljust(2) + + string = f"{dividend_str}{unit_str}" + return string diff --git a/test/fixtures/whole_applications/xarray/xarray/core/formatting_html.py b/test/fixtures/whole_applications/xarray/xarray/core/formatting_html.py new file mode 100644 index 0000000..9bf5bef --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/formatting_html.py @@ -0,0 +1,474 @@ +from __future__ import annotations + +import uuid +from collections import OrderedDict +from collections.abc import Mapping +from functools import lru_cache, partial +from html import escape +from importlib.resources import files +from typing import TYPE_CHECKING + +from xarray.core.formatting import ( + inline_index_repr, + inline_variable_array_repr, + short_data_repr, +) +from xarray.core.options import _get_boolean_with_default + +STATIC_FILES = ( + ("xarray.static.html", "icons-svg-inline.html"), + ("xarray.static.css", "style.css"), +) + +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + + +@lru_cache(None) +def _load_static_files(): + """Lazily load the resource files into memory the first time they are needed""" + return [ + files(package).joinpath(resource).read_text(encoding="utf-8") + for package, resource in STATIC_FILES + ] + + +def short_data_repr_html(array) -> str: + """Format "data" for DataArray and Variable.""" + internal_data = getattr(array, "variable", array)._data + if hasattr(internal_data, "_repr_html_"): + return internal_data._repr_html_() + text = escape(short_data_repr(array)) + return f"
{text}
" + + +def format_dims(dim_sizes, dims_with_index) -> str: + if not dim_sizes: + return "" + + dim_css_map = { + dim: " class='xr-has-index'" if dim in dims_with_index else "" + for dim in dim_sizes + } + + dims_li = "".join( + f"
  • {escape(str(dim))}: {size}
  • " + for dim, size in dim_sizes.items() + ) + + return f"
      {dims_li}
    " + + +def summarize_attrs(attrs) -> str: + attrs_dl = "".join( + f"
    {escape(str(k))} :
    {escape(str(v))}
    " + for k, v in attrs.items() + ) + + return f"
    {attrs_dl}
    " + + +def _icon(icon_name) -> str: + # icon_name should be defined in xarray/static/html/icon-svg-inline.html + return ( + f"" + f"" + "" + "" + ) + + +def summarize_variable(name, var, is_index=False, dtype=None) -> str: + variable = var.variable if hasattr(var, "variable") else var + + cssclass_idx = " class='xr-has-index'" if is_index else "" + dims_str = f"({', '.join(escape(dim) for dim in var.dims)})" + name = escape(str(name)) + dtype = dtype or escape(str(var.dtype)) + + # "unique" ids required to expand/collapse subsections + attrs_id = "attrs-" + str(uuid.uuid4()) + data_id = "data-" + str(uuid.uuid4()) + disabled = "" if len(var.attrs) else "disabled" + + preview = escape(inline_variable_array_repr(variable, 35)) + attrs_ul = summarize_attrs(var.attrs) + data_repr = short_data_repr_html(variable) + + attrs_icon = _icon("icon-file-text2") + data_icon = _icon("icon-database") + + return ( + f"
    {name}
    " + f"
    {dims_str}
    " + f"
    {dtype}
    " + f"
    {preview}
    " + f"" + f"" + f"" + f"" + f"
    {attrs_ul}
    " + f"
    {data_repr}
    " + ) + + +def summarize_coords(variables) -> str: + li_items = [] + for k, v in variables.items(): + li_content = summarize_variable(k, v, is_index=k in variables.xindexes) + li_items.append(f"
  • {li_content}
  • ") + + vars_li = "".join(li_items) + + return f"
      {vars_li}
    " + + +def summarize_vars(variables) -> str: + vars_li = "".join( + f"
  • {summarize_variable(k, v)}
  • " + for k, v in variables.items() + ) + + return f"
      {vars_li}
    " + + +def short_index_repr_html(index) -> str: + if hasattr(index, "_repr_html_"): + return index._repr_html_() + + return f"
    {escape(repr(index))}
    " + + +def summarize_index(coord_names, index) -> str: + name = "
    ".join([escape(str(n)) for n in coord_names]) + + index_id = f"index-{uuid.uuid4()}" + preview = escape(inline_index_repr(index)) + details = short_index_repr_html(index) + + data_icon = _icon("icon-database") + + return ( + f"
    {name}
    " + f"
    {preview}
    " + f"
    " + f"" + f"" + f"
    {details}
    " + ) + + +def summarize_indexes(indexes) -> str: + indexes_li = "".join( + f"
  • {summarize_index(v, i)}
  • " + for v, i in indexes.items() + ) + return f"
      {indexes_li}
    " + + +def collapsible_section( + name, inline_details="", details="", n_items=None, enabled=True, collapsed=False +) -> str: + # "unique" id to expand/collapse the section + data_id = "section-" + str(uuid.uuid4()) + + has_items = n_items is not None and n_items + n_items_span = "" if n_items is None else f" ({n_items})" + enabled = "" if enabled and has_items else "disabled" + collapsed = "" if collapsed or not has_items else "checked" + tip = " title='Expand/collapse section'" if enabled else "" + + return ( + f"" + f"" + f"
    {inline_details}
    " + f"
    {details}
    " + ) + + +def _mapping_section( + mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True +) -> str: + n_items = len(mapping) + expanded = _get_boolean_with_default( + expand_option_name, n_items < max_items_collapse + ) + collapsed = not expanded + + return collapsible_section( + name, + details=details_func(mapping), + n_items=n_items, + enabled=enabled, + collapsed=collapsed, + ) + + +def dim_section(obj) -> str: + dim_list = format_dims(obj.sizes, obj.xindexes.dims) + + return collapsible_section( + "Dimensions", inline_details=dim_list, enabled=False, collapsed=True + ) + + +def array_section(obj) -> str: + # "unique" id to expand/collapse the section + data_id = "section-" + str(uuid.uuid4()) + collapsed = ( + "checked" + if _get_boolean_with_default("display_expand_data", default=True) + else "" + ) + variable = getattr(obj, "variable", obj) + preview = escape(inline_variable_array_repr(variable, max_width=70)) + data_repr = short_data_repr_html(obj) + data_icon = _icon("icon-database") + + return ( + "
    " + f"" + f"" + f"
    {preview}
    " + f"
    {data_repr}
    " + "
    " + ) + + +coord_section = partial( + _mapping_section, + name="Coordinates", + details_func=summarize_coords, + max_items_collapse=25, + expand_option_name="display_expand_coords", +) + + +datavar_section = partial( + _mapping_section, + name="Data variables", + details_func=summarize_vars, + max_items_collapse=15, + expand_option_name="display_expand_data_vars", +) + +index_section = partial( + _mapping_section, + name="Indexes", + details_func=summarize_indexes, + max_items_collapse=0, + expand_option_name="display_expand_indexes", +) + +attr_section = partial( + _mapping_section, + name="Attributes", + details_func=summarize_attrs, + max_items_collapse=10, + expand_option_name="display_expand_attrs", +) + + +def _get_indexes_dict(indexes): + return { + tuple(index_vars.keys()): idx for idx, index_vars in indexes.group_by_index() + } + + +def _obj_repr(obj, header_components, sections): + """Return HTML repr of an xarray object. + + If CSS is not injected (untrusted notebook), fallback to the plain text repr. + + """ + header = f"
    {''.join(h for h in header_components)}
    " + sections = "".join(f"
  • {s}
  • " for s in sections) + + icons_svg, css_style = _load_static_files() + return ( + "
    " + f"{icons_svg}" + f"
    {escape(repr(obj))}
    " + "" + "
    " + ) + + +def array_repr(arr) -> str: + dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) + if hasattr(arr, "xindexes"): + indexed_dims = arr.xindexes.dims + else: + indexed_dims = {} + + obj_type = f"xarray.{type(arr).__name__}" + arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" + + header_components = [ + f"
    {obj_type}
    ", + f"
    {arr_name}
    ", + format_dims(dims, indexed_dims), + ] + + sections = [array_section(arr)] + + if hasattr(arr, "coords"): + sections.append(coord_section(arr.coords)) + + if hasattr(arr, "xindexes"): + indexes = _get_indexes_dict(arr.xindexes) + sections.append(index_section(indexes)) + + sections.append(attr_section(arr.attrs)) + + return _obj_repr(arr, header_components, sections) + + +def dataset_repr(ds) -> str: + obj_type = f"xarray.{type(ds).__name__}" + + header_components = [f"
    {escape(obj_type)}
    "] + + sections = [ + dim_section(ds), + coord_section(ds.coords), + datavar_section(ds.data_vars), + index_section(_get_indexes_dict(ds.xindexes)), + attr_section(ds.attrs), + ] + + return _obj_repr(ds, header_components, sections) + + +def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: + N_CHILDREN = len(children) - 1 + + # Get result from datatree_node_repr and wrap it + lines_callback = lambda n, c, end: _wrap_datatree_repr( + datatree_node_repr(n, c), end=end + ) + + children_html = "".join( + ( + lines_callback(n, c, end=False) # Long lines + if i < N_CHILDREN + else lines_callback(n, c, end=True) + ) # Short lines + for i, (n, c) in enumerate(children.items()) + ) + + return "".join( + [ + "
    ", + children_html, + "
    ", + ] + ) + + +children_section = partial( + _mapping_section, + name="Groups", + details_func=summarize_datatree_children, + max_items_collapse=1, + expand_option_name="display_expand_groups", +) + + +def datatree_node_repr(group_title: str, dt: DataTree) -> str: + header_components = [f"
    {escape(group_title)}
    "] + + ds = dt.ds + + sections = [ + children_section(dt.children), + dim_section(ds), + coord_section(ds.coords), + datavar_section(ds.data_vars), + attr_section(ds.attrs), + ] + + return _obj_repr(ds, header_components, sections) + + +def _wrap_datatree_repr(r: str, end: bool = False) -> str: + """ + Wrap HTML representation with a tee to the left of it. + + Enclosing HTML tag is a
    with :code:`display: inline-grid` style. + + Turns: + [ title ] + | details | + |_____________| + + into (A): + |─ [ title ] + | | details | + | |_____________| + + or (B): + └─ [ title ] + | details | + |_____________| + + Parameters + ---------- + r: str + HTML representation to wrap. + end: bool + Specify if the line on the left should continue or end. + + Default is True. + + Returns + ------- + str + Wrapped HTML representation. + + Tee color is set to the variable :code:`--xr-border-color`. + """ + # height of line + end = bool(end) + height = "100%" if end is False else "1.2em" + return "".join( + [ + "
    ", + "
    ", + "
    ", + "
    ", + "
    ", + "
    ", + r, + "
    ", + "
    ", + ] + ) + + +def datatree_repr(dt: DataTree) -> str: + obj_type = f"datatree.{type(dt).__name__}" + return datatree_node_repr(obj_type, dt) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/groupby.py b/test/fixtures/whole_applications/xarray/xarray/core/groupby.py new file mode 100644 index 0000000..5966c32 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/groupby.py @@ -0,0 +1,2011 @@ +from __future__ import annotations + +import copy +import datetime +import warnings +from abc import ABC, abstractmethod +from collections.abc import Hashable, Iterator, Mapping, Sequence +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union + +import numpy as np +import pandas as pd +from packaging.version import Version + +from xarray.coding.cftime_offsets import _new_to_legacy_freq +from xarray.core import dtypes, duck_array_ops, nputils, ops +from xarray.core._aggregations import ( + DataArrayGroupByAggregations, + DatasetGroupByAggregations, +) +from xarray.core.alignment import align +from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic +from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce +from xarray.core.concat import concat +from xarray.core.formatting import format_array_flat +from xarray.core.indexes import ( + create_default_index_implicit, + filter_indexes_from_coords, + safe_cast_to_index, +) +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import ( + Dims, + QuantileMethods, + T_DataArray, + T_DataWithCoords, + T_Xarray, +) +from xarray.core.utils import ( + FrozenMappingWarningOnValuesAccess, + contains_only_chunked_or_numpy, + either_dict_or_kwargs, + emit_user_level_warning, + hashable, + is_scalar, + maybe_wrap_array, + module_available, + peek_at, +) +from xarray.core.variable import IndexVariable, Variable +from xarray.util.deprecation_helpers import _deprecate_positional_args + +if TYPE_CHECKING: + from numpy.typing import ArrayLike + + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.resample_cftime import CFTimeGrouper + from xarray.core.types import DatetimeLike, SideOptions + from xarray.core.utils import Frozen + + GroupKey = Any + GroupIndex = Union[int, slice, list[int]] + T_GroupIndices = list[GroupIndex] + + +def check_reduce_dims(reduce_dims, dimensions): + if reduce_dims is not ...: + if is_scalar(reduce_dims): + reduce_dims = [reduce_dims] + if any(dim not in dimensions for dim in reduce_dims): + raise ValueError( + f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' " + f"to reduce over all dimensions or one or more of {dimensions!r}." + f" Try passing .groupby(..., squeeze=False)" + ) + + +def _maybe_squeeze_indices( + indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool +): + is_unique_grouper = isinstance(grouper.grouper, UniqueGrouper) + can_squeeze = is_unique_grouper and grouper.grouper.can_squeeze + if squeeze in [None, True] and can_squeeze: + if isinstance(indices, slice): + if indices.stop - indices.start == 1: + if (squeeze is None and warn) or squeeze is True: + emit_user_level_warning( + "The `squeeze` kwarg to GroupBy is being removed." + "Pass .groupby(..., squeeze=False) to disable squeezing," + " which is the new default, and to silence this warning." + ) + + indices = indices.start + return indices + + +def unique_value_groups( + ar, sort: bool = True +) -> tuple[np.ndarray | pd.Index, np.ndarray]: + """Group an array by its unique values. + + Parameters + ---------- + ar : array-like + Input array. This will be flattened if it is not already 1-D. + sort : bool, default: True + Whether or not to sort unique values. + + Returns + ------- + values : np.ndarray + Sorted, unique values as returned by `np.unique`. + indices : list of lists of int + Each element provides the integer indices in `ar` with values given by + the corresponding value in `unique_values`. + """ + inverse, values = pd.factorize(ar, sort=sort) + if isinstance(values, pd.MultiIndex): + values.names = ar.names + return values, inverse + + +def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices: + assert inverse.ndim == 1 + groups: T_GroupIndices = [[] for _ in range(N)] + for n, g in enumerate(inverse): + if g >= 0: + groups[g].append(n) + return groups + + +def _dummy_copy(xarray_obj): + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + if isinstance(xarray_obj, Dataset): + res = Dataset( + { + k: dtypes.get_fill_value(v.dtype) + for k, v in xarray_obj.data_vars.items() + }, + { + k: dtypes.get_fill_value(v.dtype) + for k, v in xarray_obj.coords.items() + if k not in xarray_obj.dims + }, + xarray_obj.attrs, + ) + elif isinstance(xarray_obj, DataArray): + res = DataArray( + dtypes.get_fill_value(xarray_obj.dtype), + { + k: dtypes.get_fill_value(v.dtype) + for k, v in xarray_obj.coords.items() + if k not in xarray_obj.dims + }, + dims=[], + name=xarray_obj.name, + attrs=xarray_obj.attrs, + ) + else: # pragma: no cover + raise AssertionError + return res + + +def _is_one_or_none(obj) -> bool: + return obj == 1 or obj is None + + +def _consolidate_slices(slices: list[slice]) -> list[slice]: + """Consolidate adjacent slices in a list of slices.""" + result: list[slice] = [] + last_slice = slice(None) + for slice_ in slices: + if not isinstance(slice_, slice): + raise ValueError(f"list element is not a slice: {slice_!r}") + if ( + result + and last_slice.stop == slice_.start + and _is_one_or_none(last_slice.step) + and _is_one_or_none(slice_.step) + ): + last_slice = slice(last_slice.start, slice_.stop, slice_.step) + result[-1] = last_slice + else: + result.append(slice_) + last_slice = slice_ + return result + + +def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray | None: + """Like inverse_permutation, but also handles slices. + + Parameters + ---------- + positions : list of ndarray or slice + If slice objects, all are assumed to be slices. + + Returns + ------- + np.ndarray of indices or None, if no permutation is necessary. + """ + if not positions: + return None + + if isinstance(positions[0], slice): + positions = _consolidate_slices(positions) + if positions == slice(None): + return None + positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] + + newpositions = nputils.inverse_permutation(np.concatenate(positions), N) + return newpositions[newpositions != -1] + + +class _DummyGroup(Generic[T_Xarray]): + """Class for keeping track of grouped dimensions without coordinates. + + Should not be user visible. + """ + + __slots__ = ("name", "coords", "size", "dataarray") + + def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: + self.name = name + self.coords = coords + self.size = obj.sizes[name] + + @property + def dims(self) -> tuple[Hashable]: + return (self.name,) + + @property + def ndim(self) -> Literal[1]: + return 1 + + @property + def values(self) -> range: + return range(self.size) + + @property + def data(self) -> range: + return range(self.size) + + def __array__(self) -> np.ndarray: + return np.arange(self.size) + + @property + def shape(self) -> tuple[int]: + return (self.size,) + + @property + def attrs(self) -> dict: + return {} + + def __getitem__(self, key): + if isinstance(key, tuple): + key = key[0] + return self.values[key] + + def to_index(self) -> pd.Index: + # could be pd.RangeIndex? + return pd.Index(np.arange(self.size)) + + def copy(self, deep: bool = True, data: Any = None): + raise NotImplementedError + + def to_dataarray(self) -> DataArray: + from xarray.core.dataarray import DataArray + + return DataArray( + data=self.data, dims=(self.name,), coords=self.coords, name=self.name + ) + + def to_array(self) -> DataArray: + """Deprecated version of to_dataarray.""" + return self.to_dataarray() + + +T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] + + +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ + T_Group, + T_DataWithCoords, + Hashable | None, + list[Hashable], +]: + # 1D cases: do nothing + if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: + return group, obj, None, [] + + from xarray.core.dataarray import DataArray + + if isinstance(group, DataArray): + # try to stack the dims of the group into a single dim + orig_dims = group.dims + stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) + # these dimensions get created by the stack operation + inserted_dims = [dim for dim in group.dims if dim not in group.coords] + newgroup = group.stack({stacked_dim: orig_dims}) + newobj = obj.stack({stacked_dim: orig_dims}) + return newgroup, newobj, stacked_dim, inserted_dims + + raise TypeError( + f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." + ) + + +def _apply_loffset( + loffset: str | pd.DateOffset | datetime.timedelta | pd.Timedelta, + result: pd.Series | pd.DataFrame, +): + """ + (copied from pandas) + if loffset is set, offset the result index + + This is NOT an idempotent routine, it will be applied + exactly once to the result. + + Parameters + ---------- + result : Series or DataFrame + the result of resample + """ + # pd.Timedelta is a subclass of datetime.timedelta so we do not need to + # include it in instance checks. + if not isinstance(loffset, (str, pd.DateOffset, datetime.timedelta)): + raise ValueError( + f"`loffset` must be a str, pd.DateOffset, datetime.timedelta, or pandas.Timedelta object. " + f"Got {loffset}." + ) + + if isinstance(loffset, str): + loffset = pd.tseries.frequencies.to_offset(loffset) + + needs_offset = ( + isinstance(loffset, (pd.DateOffset, datetime.timedelta)) + and isinstance(result.index, pd.DatetimeIndex) + and len(result.index) > 0 + ) + + if needs_offset: + result.index = result.index + loffset + + +class Grouper(ABC): + """Base class for Grouper objects that allow specializing GroupBy instructions.""" + + @property + def can_squeeze(self) -> bool: + """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` + should override it.""" + return False + + @abstractmethod + def factorize(self, group) -> EncodedGroups: + """ + Takes the group, and creates intermediates necessary for GroupBy. + These intermediates are + 1. codes - Same shape as `group` containing a unique integer code for each group. + 2. group_indices - Indexes that let us index out the members of each group. + 3. unique_coord - Unique groups present in the dataset. + 4. full_index - Unique groups in the output. This differs from `unique_coord` in the + case of resampling and binning, where certain groups in the output are not present in + the input. + """ + pass + + +class Resampler(Grouper): + """Base class for Grouper objects that allow specializing resampling-type GroupBy instructions. + Currently only used for TimeResampler, but could be used for SpaceResampler in the future. + """ + + pass + + +@dataclass +class EncodedGroups: + """ + Dataclass for storing intermediate values for GroupBy operation. + Returned by factorize method on Grouper objects. + + Parameters + ---------- + codes: integer codes for each group + full_index: pandas Index for the group coordinate + group_indices: optional, List of indices of array elements belonging + to each group. Inferred if not provided. + unique_coord: Unique group values present in dataset. Inferred if not provided + """ + + codes: DataArray + full_index: pd.Index + group_indices: T_GroupIndices | None = field(default=None) + unique_coord: IndexVariable | _DummyGroup | None = field(default=None) + + +@dataclass +class ResolvedGrouper(Generic[T_DataWithCoords]): + """ + Wrapper around a Grouper object. + + The Grouper object represents an abstract instruction to group an object. + The ResolvedGrouper object is a concrete version that contains all the common + logic necessary for a GroupBy problem including the intermediates necessary for + executing a GroupBy calculation. Specialization to the grouping problem at hand, + is accomplished by calling the `factorize` method on the encapsulated Grouper + object. + + This class is private API, while Groupers are public. + """ + + grouper: Grouper + group: T_Group + obj: T_DataWithCoords + + # returned by factorize: + codes: DataArray = field(init=False) + full_index: pd.Index = field(init=False) + group_indices: T_GroupIndices = field(init=False) + unique_coord: IndexVariable | _DummyGroup = field(init=False) + + # _ensure_1d: + group1d: T_Group = field(init=False) + stacked_obj: T_DataWithCoords = field(init=False) + stacked_dim: Hashable | None = field(init=False) + inserted_dims: list[Hashable] = field(init=False) + + def __post_init__(self) -> None: + # This copy allows the BinGrouper.factorize() method + # to update BinGrouper.bins when provided as int, using the output + # of pd.cut + # We do not want to modify the original object, since the same grouper + # might be used multiple times. + self.grouper = copy.deepcopy(self.grouper) + + self.group: T_Group = _resolve_group(self.obj, self.group) + + ( + self.group1d, + self.stacked_obj, + self.stacked_dim, + self.inserted_dims, + ) = _ensure_1d(group=self.group, obj=self.obj) + + self.factorize() + + @property + def name(self) -> Hashable: + # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper + return self.unique_coord.name + + @property + def size(self) -> int: + return len(self) + + def __len__(self) -> int: + return len(self.full_index) + + @property + def dims(self): + return self.group1d.dims + + def factorize(self) -> None: + encoded = self.grouper.factorize(self.group1d) + + self.codes = encoded.codes + self.full_index = encoded.full_index + + if encoded.group_indices is not None: + self.group_indices = encoded.group_indices + else: + self.group_indices = [ + g + for g in _codes_to_group_indices(self.codes.data, len(self.full_index)) + if g + ] + if encoded.unique_coord is None: + unique_values = self.full_index[np.unique(encoded.codes)] + self.unique_coord = IndexVariable( + self.group.name, unique_values, attrs=self.group.attrs + ) + else: + self.unique_coord = encoded.unique_coord + + +@dataclass +class UniqueGrouper(Grouper): + """Grouper object for grouping by a categorical variable.""" + + _group_as_index: pd.Index | None = None + + @property + def is_unique_and_monotonic(self) -> bool: + if isinstance(self.group, _DummyGroup): + return True + index = self.group_as_index + return index.is_unique and index.is_monotonic_increasing + + @property + def group_as_index(self) -> pd.Index: + if self._group_as_index is None: + self._group_as_index = self.group.to_index() + return self._group_as_index + + @property + def can_squeeze(self) -> bool: + is_dimension = self.group.dims == (self.group.name,) + return is_dimension and self.is_unique_and_monotonic + + def factorize(self, group1d) -> EncodedGroups: + self.group = group1d + + if self.can_squeeze: + return self._factorize_dummy() + else: + return self._factorize_unique() + + def _factorize_unique(self) -> EncodedGroups: + # look through group to find the unique values + sort = not isinstance(self.group_as_index, pd.MultiIndex) + unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) + if (codes_ == -1).all(): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + codes = self.group.copy(data=codes_) + unique_coord = IndexVariable( + self.group.name, unique_values, attrs=self.group.attrs + ) + full_index = unique_coord + + return EncodedGroups( + codes=codes, full_index=full_index, unique_coord=unique_coord + ) + + def _factorize_dummy(self) -> EncodedGroups: + size = self.group.size + # no need to factorize + # use slices to do views instead of fancy indexing + # equivalent to: group_indices = group_indices.reshape(-1, 1) + group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)] + size_range = np.arange(size) + if isinstance(self.group, _DummyGroup): + codes = self.group.to_dataarray().copy(data=size_range) + else: + codes = self.group.copy(data=size_range) + unique_coord = self.group + full_index = IndexVariable( + self.group.name, unique_coord.values, self.group.attrs + ) + return EncodedGroups( + codes=codes, + group_indices=group_indices, + full_index=full_index, + unique_coord=unique_coord, + ) + + +@dataclass +class BinGrouper(Grouper): + """Grouper object for binning numeric data.""" + + bins: Any # TODO: What is the typing? + cut_kwargs: Mapping = field(default_factory=dict) + binned: Any = None + name: Any = None + + def __post_init__(self) -> None: + if duck_array_ops.isnull(self.bins).all(): + raise ValueError("All bin edges are NaN.") + + def factorize(self, group) -> EncodedGroups: + from xarray.core.dataarray import DataArray + + data = group.data + + binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + + binned_codes = binned.codes + if (binned_codes == -1).all(): + raise ValueError( + f"None of the data falls within bins with edges {self.bins!r}" + ) + + new_dim_name = f"{group.name}_bins" + + full_index = binned.categories + uniques = np.sort(pd.unique(binned_codes)) + unique_values = full_index[uniques[uniques != -1]] + + codes = DataArray( + binned_codes, getattr(group, "coords", None), name=new_dim_name + ) + unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) + return EncodedGroups( + codes=codes, full_index=full_index, unique_coord=unique_coord + ) + + +@dataclass +class TimeResampler(Resampler): + """Grouper object specialized to resampling the time coordinate.""" + + freq: str + closed: SideOptions | None = field(default=None) + label: SideOptions | None = field(default=None) + origin: str | DatetimeLike = field(default="start_day") + offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) + loffset: datetime.timedelta | str | None = field(default=None) + base: int | None = field(default=None) + + index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) + group_as_index: pd.Index = field(init=False) + + def __post_init__(self): + if self.loffset is not None: + emit_user_level_warning( + "Following pandas, the `loffset` parameter to resample is deprecated. " + "Switch to updating the resampled dataset time coordinate using " + "time offset arithmetic. For example:\n" + " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" + ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', + FutureWarning, + ) + + if self.base is not None: + emit_user_level_warning( + "Following pandas, the `base` parameter to resample will be deprecated in " + "a future version of xarray. Switch to using `origin` or `offset` instead.", + FutureWarning, + ) + + if self.base is not None and self.offset is not None: + raise ValueError("base and offset cannot be present at the same time") + + def _init_properties(self, group: T_Group) -> None: + from xarray import CFTimeIndex + from xarray.core.pdcompat import _convert_base_to_offset + + group_as_index = safe_cast_to_index(group) + + if self.base is not None: + # grouper constructor verifies that grouper.offset is None at this point + offset = _convert_base_to_offset(self.base, self.freq, group_as_index) + else: + offset = self.offset + + if not group_as_index.is_monotonic_increasing: + # TODO: sort instead of raising an error + raise ValueError("index must be monotonic for resampling") + + if isinstance(group_as_index, CFTimeIndex): + from xarray.core.resample_cftime import CFTimeGrouper + + index_grouper = CFTimeGrouper( + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, + loffset=self.loffset, + ) + else: + index_grouper = pd.Grouper( + # TODO remove once requiring pandas >= 2.2 + freq=_new_to_legacy_freq(self.freq), + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, + ) + self.index_grouper = index_grouper + self.group_as_index = group_as_index + + def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: + first_items, codes = self.first_items() + full_index = first_items.index + if first_items.isnull().any(): + first_items = first_items.dropna() + + full_index = full_index.rename("__resample_dim__") + return full_index, first_items, codes + + def first_items(self) -> tuple[pd.Series, np.ndarray]: + from xarray import CFTimeIndex + + if isinstance(self.group_as_index, CFTimeIndex): + return self.index_grouper.first_items(self.group_as_index) + else: + s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) + grouped = s.groupby(self.index_grouper) + first_items = grouped.first() + counts = grouped.count() + # This way we generate codes for the final output index: full_index. + # So for _flox_reduce we avoid one reindex and copy by avoiding + # _maybe_restore_empty_groups + codes = np.repeat(np.arange(len(first_items)), counts) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) + return first_items, codes + + def factorize(self, group) -> EncodedGroups: + self._init_properties(group) + full_index, first_items, codes_ = self._get_index_and_items() + sbins = first_items.values.astype(np.int64) + group_indices: T_GroupIndices = [ + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:]) + ] + group_indices += [slice(sbins[-1], None)] + + unique_coord = IndexVariable(group.name, first_items.index, group.attrs) + codes = group.copy(data=codes_) + + return EncodedGroups( + codes=codes, + group_indices=group_indices, + full_index=full_index, + unique_coord=unique_coord, + ) + + +def _validate_groupby_squeeze(squeeze: bool | None) -> None: + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if squeeze is not None and not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be None, True or False, but {squeeze} was supplied" + ) + + +def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: + from xarray.core.dataarray import DataArray + + error_msg = ( + "the group variable's length does not " + "match the length of this variable along its " + "dimensions" + ) + + newgroup: T_Group + if isinstance(group, DataArray): + try: + align(obj, group, join="exact", copy=False) + except ValueError: + raise ValueError(error_msg) + + newgroup = group.copy(deep=False) + newgroup.name = group.name or "group" + + elif isinstance(group, IndexVariable): + # This assumption is built in to _ensure_1d. + if group.ndim != 1: + raise ValueError( + "Grouping by multi-dimensional IndexVariables is not allowed." + "Convert to and pass a DataArray instead." + ) + (group_dim,) = group.dims + if len(group) != obj.sizes[group_dim]: + raise ValueError(error_msg) + newgroup = DataArray(group) + + else: + if not hashable(group): + raise TypeError( + "`group` must be an xarray.DataArray or the " + "name of an xarray variable or dimension. " + f"Received {group!r} instead." + ) + group_da: DataArray = obj[group] + if group_da.name not in obj._indexes and group_da.name in obj.dims: + # DummyGroups should not appear on groupby results + newgroup = _DummyGroup(obj, group_da.name, group_da.coords) + else: + newgroup = group_da + + if newgroup.size == 0: + raise ValueError(f"{newgroup.name} must not be empty") + + return newgroup + + +class GroupBy(Generic[T_Xarray]): + """A object that implements the split-apply-combine pattern. + + Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over + (unique_value, grouped_array) pairs, but the main way to interact with a + groupby object are with the `apply` or `reduce` methods. You can also + directly call numpy methods like `mean` or `std`. + + You should create a GroupBy object by using the `DataArray.groupby` or + `Dataset.groupby` methods. + + See Also + -------- + Dataset.groupby + DataArray.groupby + """ + + __slots__ = ( + "_full_index", + "_inserted_dims", + "_group", + "_group_dim", + "_group_indices", + "_groups", + "groupers", + "_obj", + "_restore_coord_dims", + "_stacked_dim", + "_unique_coord", + "_dims", + "_sizes", + "_squeeze", + # Save unstacked object for flox + "_original_obj", + "_original_group", + "_bins", + "_codes", + ) + _obj: T_Xarray + groupers: tuple[ResolvedGrouper] + _squeeze: bool | None + _restore_coord_dims: bool + + _original_obj: T_Xarray + _original_group: T_Group + _group_indices: T_GroupIndices + _codes: DataArray + _group_dim: Hashable + + _groups: dict[GroupKey, GroupIndex] | None + _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None + _sizes: Mapping[Hashable, int] | None + + def __init__( + self, + obj: T_Xarray, + groupers: tuple[ResolvedGrouper], + squeeze: bool | None = False, + restore_coord_dims: bool = True, + ) -> None: + """Create a GroupBy object + + Parameters + ---------- + obj : Dataset or DataArray + Object to group. + grouper : Grouper + Grouper object + restore_coord_dims : bool, default: True + If True, also restore the dimension order of multi-dimensional + coordinates. + """ + self.groupers = groupers + + self._original_obj = obj + + (grouper,) = self.groupers + self._original_group = grouper.group + + # specification for the groupby operation + self._obj = grouper.stacked_obj + self._restore_coord_dims = restore_coord_dims + self._squeeze = squeeze + + # These should generalize to multiple groupers + self._group_indices = grouper.group_indices + self._codes = self._maybe_unstack(grouper.codes) + + (self._group_dim,) = grouper.group1d.dims + # cached attributes + self._groups = None + self._dims = None + self._sizes = None + + @property + def sizes(self) -> Mapping[Hashable, int]: + """Ordered mapping from dimension names to lengths. + + Immutable. + + See Also + -------- + DataArray.sizes + Dataset.sizes + """ + if self._sizes is None: + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self._group_indices[0], + self._squeeze, + grouper, + warn=True, + ) + self._sizes = self._obj.isel({self._group_dim: index}).sizes + + return self._sizes + + def map( + self, + func: Callable, + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> T_Xarray: + raise NotImplementedError() + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> T_Xarray: + raise NotImplementedError() + + @property + def groups(self) -> dict[GroupKey, GroupIndex]: + """ + Mapping from group labels to indices. The indices can be used to index the underlying object. + """ + # provided to mimic pandas.groupby + if self._groups is None: + (grouper,) = self.groupers + squeezed_indices = ( + _maybe_squeeze_indices(ind, self._squeeze, grouper, warn=idx > 0) + for idx, ind in enumerate(self._group_indices) + ) + self._groups = dict(zip(grouper.unique_coord.values, squeezed_indices)) + return self._groups + + def __getitem__(self, key: GroupKey) -> T_Xarray: + """ + Get DataArray or Dataset corresponding to a particular group label. + """ + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self.groups[key], self._squeeze, grouper, warn=True + ) + return self._obj.isel({self._group_dim: index}) + + def __len__(self) -> int: + (grouper,) = self.groupers + return grouper.size + + def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: + (grouper,) = self.groupers + return zip(grouper.unique_coord.data, self._iter_grouped()) + + def __repr__(self) -> str: + (grouper,) = self.groupers + return "{}, grouped over {!r}\n{!r} groups with labels {}.".format( + self.__class__.__name__, + grouper.name, + grouper.full_index.size, + ", ".join(format_array_flat(grouper.full_index, 30).split()), + ) + + def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]: + """Iterate over each element in this group""" + (grouper,) = self.groupers + for idx, indices in enumerate(self._group_indices): + indices = _maybe_squeeze_indices( + indices, self._squeeze, grouper, warn=warn_squeeze and idx == 0 + ) + yield self._obj.isel({self._group_dim: indices}) + + def _infer_concat_args(self, applied_example): + (grouper,) = self.groupers + if self._group_dim in applied_example.dims: + coord = grouper.group1d + positions = self._group_indices + else: + coord = grouper.unique_coord + positions = None + (dim,) = coord.dims + if isinstance(coord, _DummyGroup): + coord = None + coord = getattr(coord, "variable", coord) + return coord, dim, positions + + def _binary_op(self, other, f, reflexive=False): + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + g = f if not reflexive else lambda x, y: f(y, x) + + (grouper,) = self.groupers + obj = self._original_obj + group = grouper.group + codes = self._codes + dims = group.dims + + if isinstance(group, _DummyGroup): + group = coord = group.to_dataarray() + else: + coord = grouper.unique_coord + if not isinstance(coord, DataArray): + coord = DataArray(grouper.unique_coord) + name = grouper.name + + if not isinstance(other, (Dataset, DataArray)): + raise TypeError( + "GroupBy objects only support binary ops " + "when the other argument is a Dataset or " + "DataArray" + ) + + if name not in other.dims: + raise ValueError( + "incompatible dimensions for a grouped " + f"binary operation: the group variable {name!r} " + "is not a dimension on the other argument " + f"with dimensions {other.dims!r}" + ) + + # Broadcast out scalars for backwards compatibility + # TODO: get rid of this when fixing GH2145 + for var in other.coords: + if other[var].ndim == 0: + other[var] = ( + other[var].drop_vars(var).expand_dims({name: other.sizes[name]}) + ) + + # need to handle NaNs in group or elements that don't belong to any bins + mask = codes == -1 + if mask.any(): + obj = obj.where(~mask, drop=True) + group = group.where(~mask, drop=True) + codes = codes.where(~mask, drop=True).astype(int) + + # if other is dask-backed, that's a hint that the + # "expanded" dataset is too big to hold in memory. + # this can be the case when `other` was read from disk + # and contains our lazy indexing classes + # We need to check for dask-backed Datasets + # so utils.is_duck_dask_array does not work for this check + if obj.chunks and not other.chunks: + # TODO: What about datasets with some dask vars, and others not? + # This handles dims other than `name`` + chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims} + # a chunk size of 1 seems reasonable since we expect individual elements of + # other to be repeated multiple times across the reduced dimension(s) + chunks[name] = 1 + other = other.chunk(chunks) + + # codes are defined for coord, so we align `other` with `coord` + # before indexing + other, _ = align(other, coord, join="right", copy=False) + expanded = other.isel({name: codes}) + + result = g(obj, expanded) + + if group.ndim > 1: + # backcompat: + # TODO: get rid of this when fixing GH2145 + for var in set(obj.coords) - set(obj.xindexes): + if set(obj[var].dims) < set(group.dims): + result[var] = obj[var].reset_coords(drop=True).broadcast_like(group) + + if isinstance(result, Dataset) and isinstance(obj, Dataset): + for var in set(result): + for d in dims: + if d not in obj[var].dims: + result[var] = result[var].transpose(d, ...) + return result + + def _restore_dim_order(self, stacked): + raise NotImplementedError + + def _maybe_restore_empty_groups(self, combined): + """Our index contained empty groups (e.g., from a resampling or binning). If we + reduced on that dimension, we want to restore the full index. + """ + (grouper,) = self.groupers + if ( + isinstance(grouper.grouper, (BinGrouper, TimeResampler)) + and grouper.name in combined.dims + ): + indexers = {grouper.name: grouper.full_index} + combined = combined.reindex(**indexers) + return combined + + def _maybe_unstack(self, obj): + """This gets called if we are applying on an array with a + multidimensional group.""" + (grouper,) = self.groupers + stacked_dim = grouper.stacked_dim + inserted_dims = grouper.inserted_dims + if stacked_dim is not None and stacked_dim in obj.dims: + obj = obj.unstack(stacked_dim) + for dim in inserted_dims: + if dim in obj.coords: + del obj.coords[dim] + obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) + return obj + + def _flox_reduce( + self, + dim: Dims, + keep_attrs: bool | None = None, + **kwargs: Any, + ): + """Adaptor function that translates our groupby API to that of flox.""" + import flox + from flox.xarray import xarray_reduce + + from xarray.core.dataset import Dataset + + obj = self._original_obj + (grouper,) = self.groupers + isbin = isinstance(grouper.grouper, BinGrouper) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + if Version(flox.__version__) < Version("0.9"): + # preserve current strategy (approximately) for dask groupby + # on older flox versions to prevent surprises. + # flox >=0.9 will choose this on its own. + kwargs.setdefault("method", "cohorts") + + numeric_only = kwargs.pop("numeric_only", None) + if numeric_only: + non_numeric = { + name: var + for name, var in obj.data_vars.items() + if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_)) + } + else: + non_numeric = {} + + if "min_count" in kwargs: + if kwargs["func"] not in ["sum", "prod"]: + raise TypeError("Received an unexpected keyword argument 'min_count'") + elif kwargs["min_count"] is None: + # set explicitly to avoid unnecessarily accumulating count + kwargs["min_count"] = 0 + + # weird backcompat + # reducing along a unique indexed dimension with squeeze=True + # should raise an error + if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes: + index = obj.indexes[grouper.name] + if index.is_unique and self._squeeze: + raise ValueError(f"cannot reduce over dimensions {grouper.name!r}") + + unindexed_dims: tuple[Hashable, ...] = tuple() + if isinstance(grouper.group, _DummyGroup) and not isbin: + unindexed_dims = (grouper.name,) + + parsed_dim: tuple[Hashable, ...] + if isinstance(dim, str): + parsed_dim = (dim,) + elif dim is None: + parsed_dim = grouper.group.dims + elif dim is ...: + parsed_dim = tuple(obj.dims) + else: + parsed_dim = tuple(dim) + + # Do this so we raise the same error message whether flox is present or not. + # Better to control it here than in flox. + if any(d not in grouper.group.dims and d not in obj.dims for d in parsed_dim): + raise ValueError(f"cannot reduce over dimensions {dim}.") + + if kwargs["func"] not in ["all", "any", "count"]: + kwargs.setdefault("fill_value", np.nan) + if isbin and kwargs["func"] == "count": + # This is an annoying hack. Xarray returns np.nan + # when there are no observations in a bin, instead of 0. + # We can fake that here by forcing min_count=1. + # note min_count makes no sense in the xarray world + # as a kwarg for count, so this should be OK + kwargs.setdefault("fill_value", np.nan) + kwargs.setdefault("min_count", 1) + + output_index = grouper.full_index + result = xarray_reduce( + obj.drop_vars(non_numeric.keys()), + self._codes, + dim=parsed_dim, + # pass RangeIndex as a hint to flox that `by` is already factorized + expected_groups=(pd.RangeIndex(len(output_index)),), + isbin=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + # we did end up reducing over dimension(s) that are + # in the grouped variable + group_dims = grouper.group.dims + if set(group_dims).issubset(set(parsed_dim)): + result[grouper.name] = output_index + result = result.drop_vars(unindexed_dims) + + # broadcast and restore non-numeric data variables (backcompat) + for name, var in non_numeric.items(): + if all(d not in var.dims for d in parsed_dim): + result[name] = var.variable.set_dims( + (grouper.name,) + var.dims, + (result.sizes[grouper.name],) + var.shape, + ) + + if not isinstance(result, Dataset): + # only restore dimension order for arrays + result = self._restore_dim_order(result) + + return result + + def fillna(self, value: Any) -> T_Xarray: + """Fill missing values in this object by group. + + This operation follows the normal broadcasting and alignment rules that + xarray uses for binary arithmetic, except the result is aligned to this + object (``join='left'``) instead of aligned to the intersection of + index coordinates (``join='inner'``). + + Parameters + ---------- + value + Used to fill all matching missing values by group. Needs + to be of a valid type for the wrapped object's fillna + method. + + Returns + ------- + same type as the grouped object + + See Also + -------- + Dataset.fillna + DataArray.fillna + """ + return ops.fillna(self, value) + + @_deprecate_positional_args("v2023.10.0") + def quantile( + self, + q: ArrayLike, + dim: Dims = None, + *, + method: QuantileMethods = "linear", + keep_attrs: bool | None = None, + skipna: bool | None = None, + interpolation: QuantileMethods | None = None, + ) -> T_Xarray: + """Compute the qth quantile over each array in the groups and + concatenate them together into a new array. + + Parameters + ---------- + q : float or sequence of float + Quantile to compute, which must be between 0 and 1 + inclusive. + dim : str or Iterable of Hashable, optional + Dimension(s) over which to apply quantile. + Defaults to the grouped dimension. + method : str, default: "linear" + This optional parameter specifies the interpolation method to use when the + desired quantile lies between two data points. The options sorted by their R + type as summarized in the H&F paper [1]_ are: + + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" + 7. "linear" (default) + 8. "median_unbiased" + 9. "normal_unbiased" + + The first three methods are discontiuous. The following discontinuous + variations of the default "linear" (7.) option are also available: + + * "lower" + * "higher" + * "midpoint" + * "nearest" + + See :py:func:`numpy.quantile` or [1]_ for details. The "method" argument + was previously called "interpolation", renamed in accordance with numpy + version 1.22.0. + keep_attrs : bool or None, default: None + If True, the dataarray's attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + skipna : bool or None, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + quantiles : Variable + If `q` is a single quantile, then the result is a + scalar. If multiple percentiles are given, first axis of + the result corresponds to the quantile. In either case a + quantile dimension is added to the return array. The other + dimensions are the dimensions that remain after the + reduction of the array. + + See Also + -------- + numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile + DataArray.quantile + + Examples + -------- + >>> da = xr.DataArray( + ... [[1.3, 8.4, 0.7, 6.9], [0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]], + ... coords={"x": [0, 0, 1], "y": [1, 1, 2, 2]}, + ... dims=("x", "y"), + ... ) + >>> ds = xr.Dataset({"a": da}) + >>> da.groupby("x").quantile(0) + Size: 64B + array([[0.7, 4.2, 0.7, 1.5], + [6.5, 7.3, 2.6, 1.9]]) + Coordinates: + * y (y) int64 32B 1 1 2 2 + quantile float64 8B 0.0 + * x (x) int64 16B 0 1 + >>> ds.groupby("y").quantile(0, dim=...) + Size: 40B + Dimensions: (y: 2) + Coordinates: + quantile float64 8B 0.0 + * y (y) int64 16B 1 2 + Data variables: + a (y) float64 16B 0.7 0.7 + >>> da.groupby("x").quantile([0, 0.5, 1]) + Size: 192B + array([[[0.7 , 1. , 1.3 ], + [4.2 , 6.3 , 8.4 ], + [0.7 , 5.05, 9.4 ], + [1.5 , 4.2 , 6.9 ]], + + [[6.5 , 6.5 , 6.5 ], + [7.3 , 7.3 , 7.3 ], + [2.6 , 2.6 , 2.6 ], + [1.9 , 1.9 , 1.9 ]]]) + Coordinates: + * y (y) int64 32B 1 1 2 2 + * quantile (quantile) float64 24B 0.0 0.5 1.0 + * x (x) int64 16B 0 1 + >>> ds.groupby("y").quantile([0, 0.5, 1], dim=...) + Size: 88B + Dimensions: (y: 2, quantile: 3) + Coordinates: + * quantile (quantile) float64 24B 0.0 0.5 1.0 + * y (y) int64 16B 1 2 + Data variables: + a (y, quantile) float64 48B 0.7 5.35 8.4 0.7 2.25 9.4 + + References + ---------- + .. [1] R. J. Hyndman and Y. Fan, + "Sample quantiles in statistical packages," + The American Statistician, 50(4), pp. 361-365, 1996 + """ + if dim is None: + (grouper,) = self.groupers + dim = grouper.group1d.dims + + # Dataset.quantile does this, do it for flox to ensure same output. + q = np.asarray(q, dtype=np.float64) + + if ( + method == "linear" + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) + and module_available("flox", minversion="0.9.4") + ): + result = self._flox_reduce( + func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna + ) + return result + else: + return self.map( + self._obj.__class__.quantile, + shortcut=False, + q=q, + dim=dim, + method=method, + keep_attrs=keep_attrs, + skipna=skipna, + interpolation=interpolation, + ) + + def where(self, cond, other=dtypes.NA) -> T_Xarray: + """Return elements from `self` or `other` depending on `cond`. + + Parameters + ---------- + cond : DataArray or Dataset + Locations at which to preserve this objects values. dtypes have to be `bool` + other : scalar, DataArray or Dataset, optional + Value to use for locations in this object where ``cond`` is False. + By default, inserts missing values. + + Returns + ------- + same type as the grouped object + + See Also + -------- + Dataset.where + """ + return ops.where_method(self, cond, other) + + def _first_or_last(self, op, skipna, keep_attrs): + if all( + isinstance(maybe_slice, slice) + and (maybe_slice.stop == maybe_slice.start + 1) + for maybe_slice in self._group_indices + ): + # NB. this is currently only used for reductions along an existing + # dimension + return self._obj + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + return self.reduce( + op, dim=[self._group_dim], skipna=skipna, keep_attrs=keep_attrs + ) + + def first(self, skipna: bool | None = None, keep_attrs: bool | None = None): + """Return the first element of each group along the group dimension""" + return self._first_or_last(duck_array_ops.first, skipna, keep_attrs) + + def last(self, skipna: bool | None = None, keep_attrs: bool | None = None): + """Return the last element of each group along the group dimension""" + return self._first_or_last(duck_array_ops.last, skipna, keep_attrs) + + def assign_coords(self, coords=None, **coords_kwargs): + """Assign coordinates by group. + + See Also + -------- + Dataset.assign_coords + Dataset.swap_dims + """ + coords_kwargs = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords") + return self.map(lambda ds: ds.assign_coords(**coords_kwargs)) + + +def _maybe_reorder(xarray_obj, dim, positions, N: int | None): + order = _inverse_permutation_indices(positions, N) + + if order is None or len(order) != xarray_obj.sizes[dim]: + return xarray_obj + else: + return xarray_obj[{dim: order}] + + +class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): + """GroupBy object specialized to grouping DataArray objects""" + + __slots__ = () + _dims: tuple[Hashable, ...] | None + + @property + def dims(self) -> tuple[Hashable, ...]: + if self._dims is None: + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self._group_indices[0], self._squeeze, grouper, warn=True + ) + self._dims = self._obj.isel({self._group_dim: index}).dims + + return self._dims + + def _iter_grouped_shortcut(self, warn_squeeze=True): + """Fast version of `_iter_grouped` that yields Variables without + metadata + """ + var = self._obj.variable + (grouper,) = self.groupers + for idx, indices in enumerate(self._group_indices): + indices = _maybe_squeeze_indices( + indices, self._squeeze, grouper, warn=warn_squeeze and idx == 0 + ) + yield var[{self._group_dim: indices}] + + def _concat_shortcut(self, applied, dim, positions=None): + # nb. don't worry too much about maintaining this method -- it does + # speed things up, but it's not very interpretable and there are much + # faster alternatives (e.g., doing the grouped aggregation in a + # compiled language) + # TODO: benbovy - explicit indexes: this fast implementation doesn't + # create an explicit index for the stacked dim coordinate + stacked = Variable.concat(applied, dim, shortcut=True) + + (grouper,) = self.groupers + reordered = _maybe_reorder(stacked, dim, positions, N=grouper.group.size) + return self._obj._replace_maybe_drop_dims(reordered) + + def _restore_dim_order(self, stacked: DataArray) -> DataArray: + (grouper,) = self.groupers + group = grouper.group1d + + groupby_coord = ( + f"{group.name}_bins" + if isinstance(grouper.grouper, BinGrouper) + else group.name + ) + + def lookup_order(dimension): + if dimension == groupby_coord: + (dimension,) = group.dims + if dimension in self._obj.dims: + axis = self._obj.get_axis_num(dimension) + else: + axis = 1e6 # some arbitrarily high value + return axis + + new_order = sorted(stacked.dims, key=lookup_order) + return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims) + + def map( + self, + func: Callable[..., DataArray], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> DataArray: + """Apply a function to each array in the group and concatenate them + together into a new array. + + `func` is called like `func(ar, *args, **kwargs)` for each array `ar` + in this group. + + Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how + to stack together the array. The rule is: + + 1. If the dimension along which the group coordinate is defined is + still in the first grouped array after applying `func`, then stack + over this dimension. + 2. Otherwise, stack over the new dimension given by name of this + grouping (the argument to the `groupby` function). + + Parameters + ---------- + func : callable + Callable to apply to each array. + shortcut : bool, optional + Whether or not to shortcut evaluation under the assumptions that: + + (1) The action of `func` does not depend on any of the array + metadata (attributes or coordinates) but only on the data and + dimensions. + (2) The action of `func` creates arrays with homogeneous metadata, + that is, with the same dimensions and attributes. + + If these conditions are satisfied `shortcut` provides significant + speedup. This should be the case for many common groupby operations + (e.g., applying numpy ufuncs). + *args : tuple, optional + Positional arguments passed to `func`. + **kwargs + Used to call `func(ar, **kwargs)` for each array `ar`. + + Returns + ------- + applied : DataArray + The result of splitting, applying and combining this array. + """ + return self._map_maybe_warn( + func, args, warn_squeeze=True, shortcut=shortcut, **kwargs + ) + + def _map_maybe_warn( + self, + func: Callable[..., DataArray], + args: tuple[Any, ...] = (), + *, + warn_squeeze: bool = True, + shortcut: bool | None = None, + **kwargs: Any, + ) -> DataArray: + grouped = ( + self._iter_grouped_shortcut(warn_squeeze) + if shortcut + else self._iter_grouped(warn_squeeze) + ) + applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) + return self._combine(applied, shortcut=shortcut) + + def apply(self, func, shortcut=False, args=(), **kwargs): + """ + Backward compatible implementation of ``map`` + + See Also + -------- + DataArrayGroupBy.map + """ + warnings.warn( + "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged", + PendingDeprecationWarning, + stacklevel=2, + ) + return self.map(func, shortcut=shortcut, args=args, **kwargs) + + def _combine(self, applied, shortcut=False): + """Recombine the applied objects like the original.""" + applied_example, applied = peek_at(applied) + coord, dim, positions = self._infer_concat_args(applied_example) + if shortcut: + combined = self._concat_shortcut(applied, dim, positions) + else: + combined = concat(applied, dim) + (grouper,) = self.groupers + combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) + + if isinstance(combined, type(self._obj)): + # only restore dimension order for arrays + combined = self._restore_dim_order(combined) + # assign coord and index when the applied function does not return that coord + if coord is not None and dim not in applied_example.dims: + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, index_vars) + combined = self._maybe_restore_empty_groups(combined) + combined = self._maybe_unstack(combined) + return combined + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> DataArray: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. If None, apply over the + groupby dimension, if "..." apply over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dimension' + and 'axis' arguments can be supplied. If neither are supplied, then + `func` is calculated over all dimension for each group item. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim is None: + dim = [self._group_dim] + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + def reduce_array(ar: DataArray) -> DataArray: + return ar.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) + + check_reduce_dims(dim, self.dims) + + return self.map(reduce_array, shortcut=shortcut) + + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> DataArray: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. If None, apply over the + groupby dimension, if "..." apply over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dimension' + and 'axis' arguments can be supplied. If neither are supplied, then + `func` is calculated over all dimension for each group item. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim is None: + dim = [self._group_dim] + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + def reduce_array(ar: DataArray) -> DataArray: + return ar.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The `squeeze` kwarg") + check_reduce_dims(dim, self.dims) + + return self._map_maybe_warn(reduce_array, shortcut=shortcut, warn_squeeze=False) + + +# https://github.com/python/mypy/issues/9031 +class DataArrayGroupBy( # type: ignore[misc] + DataArrayGroupByBase, + DataArrayGroupByAggregations, + ImplementsArrayReduce, +): + __slots__ = () + + +class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): + __slots__ = () + _dims: Frozen[Hashable, int] | None + + @property + def dims(self) -> Frozen[Hashable, int]: + if self._dims is None: + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self._group_indices[0], + self._squeeze, + grouper, + warn=True, + ) + self._dims = self._obj.isel({self._group_dim: index}).dims + + return FrozenMappingWarningOnValuesAccess(self._dims) + + def map( + self, + func: Callable[..., Dataset], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> Dataset: + """Apply a function to each Dataset in the group and concatenate them + together into a new Dataset. + + `func` is called like `func(ds, *args, **kwargs)` for each dataset `ds` + in this group. + + Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how + to stack together the datasets. The rule is: + + 1. If the dimension along which the group coordinate is defined is + still in the first grouped item after applying `func`, then stack + over this dimension. + 2. Otherwise, stack over the new dimension given by name of this + grouping (the argument to the `groupby` function). + + Parameters + ---------- + func : callable + Callable to apply to each sub-dataset. + args : tuple, optional + Positional arguments to pass to `func`. + **kwargs + Used to call `func(ds, **kwargs)` for each sub-dataset `ar`. + + Returns + ------- + applied : Dataset + The result of splitting, applying and combining this dataset. + """ + return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) + + def _map_maybe_warn( + self, + func: Callable[..., Dataset], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + warn_squeeze: bool = False, + **kwargs: Any, + ) -> Dataset: + # ignore shortcut if set (for now) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped(warn_squeeze)) + return self._combine(applied) + + def apply(self, func, args=(), shortcut=None, **kwargs): + """ + Backward compatible implementation of ``map`` + + See Also + -------- + DatasetGroupBy.map + """ + + warnings.warn( + "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged", + PendingDeprecationWarning, + stacklevel=2, + ) + return self.map(func, shortcut=shortcut, args=args, **kwargs) + + def _combine(self, applied): + """Recombine the applied objects like the original.""" + applied_example, applied = peek_at(applied) + coord, dim, positions = self._infer_concat_args(applied_example) + combined = concat(applied, dim) + (grouper,) = self.groupers + combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) + # assign coord when the applied function does not return that coord + if coord is not None and dim not in applied_example.dims: + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, index_vars) + combined = self._maybe_restore_empty_groups(combined) + combined = self._maybe_unstack(combined) + return combined + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : ..., str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default apply over the + groupby dimension, with "..." apply over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dimension' + and 'axis' arguments can be supplied. If neither are supplied, then + `func` is calculated over all dimension for each group item. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Dataset + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim is None: + dim = [self._group_dim] + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + def reduce_dataset(ds: Dataset) -> Dataset: + return ds.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) + + check_reduce_dims(dim, self.dims) + + return self.map(reduce_dataset) + + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : ..., str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default apply over the + groupby dimension, with "..." apply over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dimension' + and 'axis' arguments can be supplied. If neither are supplied, then + `func` is calculated over all dimension for each group item. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Dataset + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim is None: + dim = [self._group_dim] + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + def reduce_dataset(ds: Dataset) -> Dataset: + return ds.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The `squeeze` kwarg") + check_reduce_dims(dim, self.dims) + + return self._map_maybe_warn(reduce_dataset, warn_squeeze=False) + + def assign(self, **kwargs: Any) -> Dataset: + """Assign data variables by group. + + See Also + -------- + Dataset.assign + """ + return self.map(lambda ds: ds.assign(**kwargs)) + + +# https://github.com/python/mypy/issues/9031 +class DatasetGroupBy( # type: ignore[misc] + DatasetGroupByBase, + DatasetGroupByAggregations, + ImplementsDatasetReduce, +): + __slots__ = () diff --git a/test/fixtures/whole_applications/xarray/xarray/core/indexes.py b/test/fixtures/whole_applications/xarray/xarray/core/indexes.py new file mode 100644 index 0000000..f25c0ec --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/indexes.py @@ -0,0 +1,1908 @@ +from __future__ import annotations + +import collections.abc +import copy +from collections import defaultdict +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast + +import numpy as np +import pandas as pd + +from xarray.core import formatting, nputils, utils +from xarray.core.indexing import ( + IndexSelResult, + PandasIndexingAdapter, + PandasMultiIndexingAdapter, +) +from xarray.core.utils import ( + Frozen, + emit_user_level_warning, + get_valid_numpy_dtype, + is_dict_like, + is_scalar, +) + +if TYPE_CHECKING: + from xarray.core.types import ErrorOptions, JoinOptions, Self + from xarray.core.variable import Variable + + +IndexVars = dict[Any, "Variable"] + + +class Index: + """ + Base class inherited by all xarray-compatible indexes. + + Do not use this class directly for creating index objects. Xarray indexes + are created exclusively from subclasses of ``Index``, mostly via Xarray's + public API like ``Dataset.set_xindex``. + + Every subclass must at least implement :py:meth:`Index.from_variables`. The + (re)implementation of the other methods of this base class is optional but + mostly required in order to support operations relying on indexes such as + label-based selection or alignment. + + The ``Index`` API closely follows the :py:meth:`Dataset` and + :py:meth:`DataArray` API, e.g., for an index to support ``.sel()`` it needs + to implement :py:meth:`Index.sel`, to support ``.stack()`` and + ``.unstack()`` it needs to implement :py:meth:`Index.stack` and + :py:meth:`Index.unstack`, etc. + + When a method is not (re)implemented, depending on the case the + corresponding operation on a :py:meth:`Dataset` or :py:meth:`DataArray` + either will raise a ``NotImplementedError`` or will simply drop/pass/copy + the index from/to the result. + + Do not use this class directly for creating index objects. + """ + + @classmethod + def from_variables( + cls, + variables: Mapping[Any, Variable], + *, + options: Mapping[str, Any], + ) -> Self: + """Create a new index object from one or more coordinate variables. + + This factory method must be implemented in all subclasses of Index. + + The coordinate variables may be passed here in an arbitrary number and + order and each with arbitrary dimensions. It is the responsibility of + the index to check the consistency and validity of these coordinates. + + Parameters + ---------- + variables : dict-like + Mapping of :py:class:`Variable` objects holding the coordinate labels + to index. + + Returns + ------- + index : Index + A new Index object. + """ + raise NotImplementedError() + + @classmethod + def concat( + cls, + indexes: Sequence[Self], + dim: Hashable, + positions: Iterable[Iterable[int]] | None = None, + ) -> Self: + """Create a new index by concatenating one or more indexes of the same + type. + + Implementation is optional but required in order to support + ``concat``. Otherwise it will raise an error if the index needs to be + updated during the operation. + + Parameters + ---------- + indexes : sequence of Index objects + Indexes objects to concatenate together. All objects must be of the + same type. + dim : Hashable + Name of the dimension to concatenate along. + positions : None or list of integer arrays, optional + List of integer arrays which specifies the integer positions to which + to assign each dataset along the concatenated dimension. If not + supplied, objects are concatenated in the provided order. + + Returns + ------- + index : Index + A new Index object. + """ + raise NotImplementedError() + + @classmethod + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self: + """Create a new index by stacking coordinate variables into a single new + dimension. + + Implementation is optional but required in order to support ``stack``. + Otherwise it will raise an error when trying to pass the Index subclass + as argument to :py:meth:`Dataset.stack`. + + Parameters + ---------- + variables : dict-like + Mapping of :py:class:`Variable` objects to stack together. + dim : Hashable + Name of the new, stacked dimension. + + Returns + ------- + index + A new Index object. + """ + raise NotImplementedError( + f"{cls!r} cannot be used for creating an index of stacked coordinates" + ) + + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: + """Unstack a (multi-)index into multiple (single) indexes. + + Implementation is optional but required in order to support unstacking + the coordinates from which this index has been built. + + Returns + ------- + indexes : tuple + A 2-length tuple where the 1st item is a dictionary of unstacked + Index objects and the 2nd item is a :py:class:`pandas.MultiIndex` + object used to unstack unindexed coordinate variables or data + variables. + """ + raise NotImplementedError() + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + """Maybe create new coordinate variables from this index. + + This method is useful if the index data can be reused as coordinate + variable data. It is often the case when the underlying index structure + has an array-like interface, like :py:class:`pandas.Index` objects. + + The variables given as argument (if any) are either returned as-is + (default behavior) or can be used to copy their metadata (attributes and + encoding) into the new returned coordinate variables. + + Note: the input variables may or may not have been filtered for this + index. + + Parameters + ---------- + variables : dict-like, optional + Mapping of :py:class:`Variable` objects. + + Returns + ------- + index_variables : dict-like + Dictionary of :py:class:`Variable` or :py:class:`IndexVariable` + objects. + """ + if variables is not None: + # pass through + return dict(**variables) + else: + return {} + + def to_pandas_index(self) -> pd.Index: + """Cast this xarray index to a pandas.Index object or raise a + ``TypeError`` if this is not supported. + + This method is used by all xarray operations that still rely on + pandas.Index objects. + + By default it raises a ``TypeError``, unless it is re-implemented in + subclasses of Index. + """ + raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: + """Maybe returns a new index from the current index itself indexed by + positional indexers. + + This method should be re-implemented in subclasses of Index if the + wrapped index structure supports indexing operations. For example, + indexing a ``pandas.Index`` is pretty straightforward as it behaves very + much like an array. By contrast, it may be harder doing so for a + structure like a kd-tree that differs much from a simple array. + + If not re-implemented in subclasses of Index, this method returns + ``None``, i.e., calling :py:meth:`Dataset.isel` will either drop the + index in the resulting dataset or pass it unchanged if its corresponding + coordinate(s) are not indexed. + + Parameters + ---------- + indexers : dict + A dictionary of positional indexers as passed from + :py:meth:`Dataset.isel` and where the entries have been filtered + for the current index. + + Returns + ------- + maybe_index : Index + A new Index object or ``None``. + """ + return None + + def sel(self, labels: dict[Any, Any]) -> IndexSelResult: + """Query the index with arbitrary coordinate label indexers. + + Implementation is optional but required in order to support label-based + selection. Otherwise it will raise an error when trying to call + :py:meth:`Dataset.sel` with labels for this index coordinates. + + Coordinate label indexers can be of many kinds, e.g., scalar, list, + tuple, array-like, slice, :py:class:`Variable`, :py:class:`DataArray`, etc. + It is the responsibility of the index to handle those indexers properly. + + Parameters + ---------- + labels : dict + A dictionary of coordinate label indexers passed from + :py:meth:`Dataset.sel` and where the entries have been filtered + for the current index. + + Returns + ------- + sel_results : :py:class:`IndexSelResult` + An index query result object that contains dimension positional indexers. + It may also contain new indexes, coordinate variables, etc. + """ + raise NotImplementedError(f"{self!r} doesn't support label-based selection") + + def join(self, other: Self, how: JoinOptions = "inner") -> Self: + """Return a new index from the combination of this index with another + index of the same type. + + Implementation is optional but required in order to support alignment. + + Parameters + ---------- + other : Index + The other Index object to combine with this index. + join : str, optional + Method for joining the two indexes (see :py:func:`~xarray.align`). + + Returns + ------- + joined : Index + A new Index object. + """ + raise NotImplementedError( + f"{self!r} doesn't support alignment with inner/outer join method" + ) + + def reindex_like(self, other: Self) -> dict[Hashable, Any]: + """Query the index with another index of the same type. + + Implementation is optional but required in order to support alignment. + + Parameters + ---------- + other : Index + The other Index object used to query this index. + + Returns + ------- + dim_positional_indexers : dict + A dictionary where keys are dimension names and values are positional + indexers. + """ + raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") + + def equals(self, other: Self) -> bool: + """Compare this index with another index of the same type. + + Implementation is optional but required in order to support alignment. + + Parameters + ---------- + other : Index + The other Index object to compare with this object. + + Returns + ------- + is_equal : bool + ``True`` if the indexes are equal, ``False`` otherwise. + """ + raise NotImplementedError() + + def roll(self, shifts: Mapping[Any, int]) -> Self | None: + """Roll this index by an offset along one or more dimensions. + + This method can be re-implemented in subclasses of Index, e.g., when the + index can be itself indexed. + + If not re-implemented, this method returns ``None``, i.e., calling + :py:meth:`Dataset.roll` will either drop the index in the resulting + dataset or pass it unchanged if its corresponding coordinate(s) are not + rolled. + + Parameters + ---------- + shifts : mapping of hashable to int, optional + A dict with keys matching dimensions and values given + by integers to rotate each of the given dimensions, as passed + :py:meth:`Dataset.roll`. + + Returns + ------- + rolled : Index + A new index with rolled data. + """ + return None + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + """Maybe update the index with new coordinate and dimension names. + + This method should be re-implemented in subclasses of Index if it has + attributes that depend on coordinate or dimension names. + + By default (if not re-implemented), it returns the index itself. + + Warning: the input names are not filtered for this method, they may + correspond to any variable or dimension of a Dataset or a DataArray. + + Parameters + ---------- + name_dict : dict-like + Mapping of current variable or coordinate names to the desired names, + as passed from :py:meth:`Dataset.rename_vars`. + dims_dict : dict-like + Mapping of current dimension names to the desired names, as passed + from :py:meth:`Dataset.rename_dims`. + + Returns + ------- + renamed : Index + Index with renamed attributes. + """ + return self + + def copy(self, deep: bool = True) -> Self: + """Return a (deep) copy of this index. + + Implementation in subclasses of Index is optional. The base class + implements the default (deep) copy semantics. + + Parameters + ---------- + deep : bool, optional + If true (default), a copy of the internal structures + (e.g., wrapped index) is returned with the new object. + + Returns + ------- + index : Index + A new Index object. + """ + return self._copy(deep=deep) + + def __copy__(self) -> Self: + return self.copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index: + return self._copy(deep=True, memo=memo) + + def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self: + cls = self.__class__ + copied = cls.__new__(cls) + if deep: + for k, v in self.__dict__.items(): + setattr(copied, k, copy.deepcopy(v, memo)) + else: + copied.__dict__.update(self.__dict__) + return copied + + def __getitem__(self, indexer: Any) -> Self: + raise NotImplementedError() + + def _repr_inline_(self, max_width): + return self.__class__.__name__ + + +def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: + from xarray.coding.cftimeindex import CFTimeIndex + + if len(index) > 0 and index.dtype == "O" and not isinstance(index, CFTimeIndex): + try: + return CFTimeIndex(index) + except (ImportError, TypeError): + return index + else: + return index + + +def safe_cast_to_index(array: Any) -> pd.Index: + """Given an array, safely cast it to a pandas.Index. + + If it is already a pandas.Index, return it unchanged. + + Unlike pandas.Index, if the array has dtype=object or dtype=timedelta64, + this function will not attempt to do automatic type conversion but will + always return an index with dtype=object. + """ + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if isinstance(array, pd.Index): + index = array + elif isinstance(array, (DataArray, Variable)): + # returns the original multi-index for pandas.MultiIndex level coordinates + index = array._to_index() + elif isinstance(array, Index): + index = array.to_pandas_index() + elif isinstance(array, PandasIndexingAdapter): + index = array.array + else: + kwargs: dict[str, str] = {} + if hasattr(array, "dtype"): + if array.dtype.kind == "O": + kwargs["dtype"] = "object" + elif array.dtype == "float16": + emit_user_level_warning( + ( + "`pandas.Index` does not support the `float16` dtype." + " Casting to `float64` for you, but in the future please" + " manually cast to either `float32` and `float64`." + ), + category=DeprecationWarning, + ) + kwargs["dtype"] = "float64" + + index = pd.Index(np.asarray(array), **kwargs) + + return _maybe_cast_to_cftimeindex(index) + + +def _sanitize_slice_element(x): + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if not isinstance(x, tuple) and len(np.shape(x)) != 0: + raise ValueError( + f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" + ) + + if isinstance(x, (Variable, DataArray)): + x = x.values + + if isinstance(x, np.ndarray): + x = x[()] + + return x + + +def _query_slice(index, label, coord_name="", method=None, tolerance=None): + if method is not None or tolerance is not None: + raise NotImplementedError( + "cannot use ``method`` argument if any indexers are slice objects" + ) + indexer = index.slice_indexer( + _sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step), + ) + if not isinstance(indexer, slice): + # unlike pandas, in xarray we never want to silently convert a + # slice indexer into an array indexer + raise KeyError( + "cannot represent labeled-based slice indexer for coordinate " + f"{coord_name!r} with a slice over integer positions; the index is " + "unsorted or non-unique" + ) + return indexer + + +def _asarray_tuplesafe(values): + """ + Convert values into a numpy array of at most 1-dimension, while preserving + tuples. + + Adapted from pandas.core.common._asarray_tuplesafe + """ + if isinstance(values, tuple): + result = utils.to_0d_object_array(values) + else: + result = np.asarray(values) + if result.ndim == 2: + result = np.empty(len(values), dtype=object) + result[:] = values + + return result + + +def _is_nested_tuple(possible_tuple): + return isinstance(possible_tuple, tuple) and any( + isinstance(value, (tuple, list, slice)) for value in possible_tuple + ) + + +def normalize_label(value, dtype=None) -> np.ndarray: + if getattr(value, "ndim", 1) <= 1: + value = _asarray_tuplesafe(value) + if dtype is not None and dtype.kind == "f" and value.dtype.kind != "b": + # pd.Index built from coordinate with float precision != 64 + # see https://github.com/pydata/xarray/pull/3153 for details + # bypass coercing dtype for boolean indexers (ignore index) + # see https://github.com/pydata/xarray/issues/5727 + value = np.asarray(value, dtype=dtype) + return value + + +def as_scalar(value: np.ndarray): + # see https://github.com/pydata/xarray/pull/4292 for details + return value[()] if value.dtype.kind in "mM" else value.item() + + +def get_indexer_nd(index, labels, method=None, tolerance=None): + """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional + labels + """ + flat_labels = np.ravel(labels) + if flat_labels.dtype == "float16": + flat_labels = flat_labels.astype("float64") + flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) + indexer = flat_indexer.reshape(labels.shape) + return indexer + + +T_PandasIndex = TypeVar("T_PandasIndex", bound="PandasIndex") + + +class PandasIndex(Index): + """Wrap a pandas.Index as an xarray compatible index.""" + + index: pd.Index + dim: Hashable + coord_dtype: Any + + __slots__ = ("index", "dim", "coord_dtype") + + def __init__( + self, + array: Any, + dim: Hashable, + coord_dtype: Any = None, + *, + fastpath: bool = False, + ): + if fastpath: + index = array + else: + index = safe_cast_to_index(array) + + if index.name is None: + # make a shallow copy: cheap and because the index name may be updated + # here or in other constructors (cannot use pd.Index.rename as this + # constructor is also called from PandasMultiIndex) + index = index.copy() + index.name = dim + + self.index = index + self.dim = dim + + if coord_dtype is None: + coord_dtype = get_valid_numpy_dtype(index) + self.coord_dtype = coord_dtype + + def _replace(self, index, dim=None, coord_dtype=None): + if dim is None: + dim = self.dim + if coord_dtype is None: + coord_dtype = self.coord_dtype + return type(self)(index, dim, coord_dtype, fastpath=True) + + @classmethod + def from_variables( + cls, + variables: Mapping[Any, Variable], + *, + options: Mapping[str, Any], + ) -> PandasIndex: + if len(variables) != 1: + raise ValueError( + f"PandasIndex only accepts one variable, found {len(variables)} variables" + ) + + name, var = next(iter(variables.items())) + + if var.ndim == 0: + raise ValueError( + f"cannot set a PandasIndex from the scalar variable {name!r}, " + "only 1-dimensional variables are supported. " + f"Note: you might want to use `obj.expand_dims({name!r})` to create a " + f"new dimension and turn {name!r} as an indexed dimension coordinate." + ) + elif var.ndim != 1: + raise ValueError( + "PandasIndex only accepts a 1-dimensional variable, " + f"variable {name!r} has {var.ndim} dimensions" + ) + + dim = var.dims[0] + + # TODO: (benbovy - explicit indexes): add __index__ to ExplicitlyIndexesNDArrayMixin? + # this could be eventually used by Variable.to_index() and would remove the need to perform + # the checks below. + + # preserve wrapped pd.Index (if any) + # accessing `.data` can load data from disk, so we only access if needed + data = getattr(var._data, "array") if hasattr(var._data, "array") else var.data + # multi-index level variable: get level index + if isinstance(var._data, PandasMultiIndexingAdapter): + level = var._data.level + if level is not None: + data = var._data.array.get_level_values(level) + + obj = cls(data, dim, coord_dtype=var.dtype) + assert not isinstance(obj.index, pd.MultiIndex) + # Rename safely + # make a shallow copy: cheap and because the index name may be updated + # here or in other constructors (cannot use pd.Index.rename as this + # constructor is also called from PandasMultiIndex) + obj.index = obj.index.copy() + obj.index.name = name + + return obj + + @staticmethod + def _concat_indexes(indexes, dim, positions=None) -> pd.Index: + new_pd_index: pd.Index + + if not indexes: + new_pd_index = pd.Index([]) + else: + if not all(idx.dim == dim for idx in indexes): + dims = ",".join({f"{idx.dim!r}" for idx in indexes}) + raise ValueError( + f"Cannot concatenate along dimension {dim!r} indexes with " + f"dimensions: {dims}" + ) + pd_indexes = [idx.index for idx in indexes] + new_pd_index = pd_indexes[0].append(pd_indexes[1:]) + + if positions is not None: + indices = nputils.inverse_permutation(np.concatenate(positions)) + new_pd_index = new_pd_index.take(indices) + + return new_pd_index + + @classmethod + def concat( + cls, + indexes: Sequence[Self], + dim: Hashable, + positions: Iterable[Iterable[int]] | None = None, + ) -> Self: + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + coord_dtype = None + else: + coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) + + return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + from xarray.core.variable import IndexVariable + + name = self.index.name + attrs: Mapping[Hashable, Any] | None + encoding: Mapping[Hashable, Any] | None + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + encoding = var.encoding + else: + attrs = None + encoding = None + + data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) + var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) + return {name: var} + + def to_pandas_index(self) -> pd.Index: + return self.index + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> PandasIndex | None: + from xarray.core.variable import Variable + + indxr = indexers[self.dim] + if isinstance(indxr, Variable): + if indxr.dims != (self.dim,): + # can't preserve a index if result has new dimensions + return None + else: + indxr = indxr.data + if not isinstance(indxr, slice) and is_scalar(indxr): + # scalar indexer: drop index + return None + + return self._replace(self.index[indxr]) + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if method is not None and not isinstance(method, str): + raise TypeError("``method`` must be a string") + + assert len(labels) == 1 + coord_name, label = next(iter(labels.items())) + + if isinstance(label, slice): + indexer = _query_slice(self.index, label, coord_name, method, tolerance) + elif is_dict_like(label): + raise ValueError( + "cannot use a dict-like object for selection on " + "a dimension that does not have a MultiIndex" + ) + else: + label_array = normalize_label(label, dtype=self.coord_dtype) + if label_array.ndim == 0: + label_value = as_scalar(label_array) + if isinstance(self.index, pd.CategoricalIndex): + if method is not None: + raise ValueError( + "'method' is not supported when indexing using a CategoricalIndex." + ) + if tolerance is not None: + raise ValueError( + "'tolerance' is not supported when indexing using a CategoricalIndex." + ) + indexer = self.index.get_loc(label_value) + else: + if method is not None: + indexer = get_indexer_nd( + self.index, label_array, method, tolerance + ) + if np.any(indexer < 0): + raise KeyError( + f"not all values found in index {coord_name!r}" + ) + else: + try: + indexer = self.index.get_loc(label_value) + except KeyError as e: + raise KeyError( + f"not all values found in index {coord_name!r}. " + "Try setting the `method` keyword argument (example: method='nearest')." + ) from e + + elif label_array.dtype.kind == "b": + indexer = label_array + else: + indexer = get_indexer_nd(self.index, label_array, method, tolerance) + if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + indexer = DataArray(indexer, coords=label._coords, dims=label.dims) + + return IndexSelResult({self.dim: indexer}) + + def equals(self, other: Index): + if not isinstance(other, PandasIndex): + return False + return self.index.equals(other.index) and self.dim == other.dim + + def join( + self, + other: Self, + how: str = "inner", + ) -> Self: + if how == "outer": + index = self.index.union(other.index) + else: + # how = "inner" + index = self.index.intersection(other.index) + + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + return type(self)(index, self.dim, coord_dtype=coord_dtype) + + def reindex_like( + self, other: Self, method=None, tolerance=None + ) -> dict[Hashable, Any]: + if not self.index.is_unique: + raise ValueError( + f"cannot reindex or align along dimension {self.dim!r} because the " + "(pandas) index has duplicate values" + ) + + return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} + + def roll(self, shifts: Mapping[Any, int]) -> PandasIndex: + shift = shifts[self.dim] % self.index.shape[0] + + if shift != 0: + new_pd_idx = self.index[-shift:].append(self.index[:-shift]) + else: + new_pd_idx = self.index[:] + + return self._replace(new_pd_idx) + + def rename(self, name_dict, dims_dict): + if self.index.name not in name_dict and self.dim not in dims_dict: + return self + + new_name = name_dict.get(self.index.name, self.index.name) + index = self.index.rename(new_name) + new_dim = dims_dict.get(self.dim, self.dim) + return self._replace(index, dim=new_dim) + + def _copy( + self: T_PandasIndex, deep: bool = True, memo: dict[int, Any] | None = None + ) -> T_PandasIndex: + if deep: + # pandas is not using the memo + index = self.index.copy(deep=True) + else: + # index will be copied in constructor + index = self.index + return self._replace(index) + + def __getitem__(self, indexer: Any): + return self._replace(self.index[indexer]) + + def __repr__(self): + return f"PandasIndex({repr(self.index)})" + + +def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal"): + """Check that all multi-index variable candidates are 1-dimensional and + either share the same (single) dimension or each have a different dimension. + + """ + if any([var.ndim != 1 for var in variables.values()]): + raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") + + dims = {var.dims for var in variables.values()} + + if all_dims == "equal" and len(dims) > 1: + raise ValueError( + "unmatched dimensions for multi-index variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) + ) + + if all_dims == "different" and len(dims) < len(variables): + raise ValueError( + "conflicting dimensions for multi-index product variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) + ) + + +def remove_unused_levels_categories(index: pd.Index) -> pd.Index: + """ + Remove unused levels from MultiIndex and unused categories from CategoricalIndex + """ + if isinstance(index, pd.MultiIndex): + index = index.remove_unused_levels() + # if it contains CategoricalIndex, we need to remove unused categories + # manually. See https://github.com/pandas-dev/pandas/issues/30846 + if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + levels = [] + for i, level in enumerate(index.levels): + if isinstance(level, pd.CategoricalIndex): + level = level[index.codes[i]].remove_unused_categories() + else: + level = level[index.codes[i]] + levels.append(level) + # TODO: calling from_array() reorders MultiIndex levels. It would + # be best to avoid this, if possible, e.g., by using + # MultiIndex.remove_unused_levels() (which does not reorder) on the + # part of the MultiIndex that is not categorical, or by fixing this + # upstream in pandas. + index = pd.MultiIndex.from_arrays(levels, names=index.names) + elif isinstance(index, pd.CategoricalIndex): + index = index.remove_unused_categories() + return index + + +class PandasMultiIndex(PandasIndex): + """Wrap a pandas.MultiIndex as an xarray compatible index.""" + + level_coords_dtype: dict[str, Any] + + __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") + + def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): + super().__init__(array, dim) + + # default index level names + names = [] + for i, idx in enumerate(self.index.levels): + name = idx.name or f"{dim}_level_{i}" + if name == dim: + raise ValueError( + f"conflicting multi-index level name {name!r} with dimension {dim!r}" + ) + names.append(name) + self.index.names = names + + if level_coords_dtype is None: + level_coords_dtype = { + idx.name: get_valid_numpy_dtype(idx) for idx in self.index.levels + } + self.level_coords_dtype = level_coords_dtype + + def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex: + if dim is None: + dim = self.dim + index.name = dim + if level_coords_dtype is None: + level_coords_dtype = self.level_coords_dtype + return type(self)(index, dim, level_coords_dtype) + + @classmethod + def from_variables( + cls, + variables: Mapping[Any, Variable], + *, + options: Mapping[str, Any], + ) -> PandasMultiIndex: + _check_dim_compat(variables) + dim = next(iter(variables.values())).dims[0] + + index = pd.MultiIndex.from_arrays( + [var.values for var in variables.values()], names=variables.keys() + ) + index.name = dim + level_coords_dtype = {name: var.dtype for name, var in variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + + return obj + + @classmethod + def concat( + cls, + indexes: Sequence[Self], + dim: Hashable, + positions: Iterable[Iterable[int]] | None = None, + ) -> Self: + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + level_coords_dtype = None + else: + level_coords_dtype = {} + for name in indexes[0].level_coords_dtype: + level_coords_dtype[name] = np.result_type( + *[idx.level_coords_dtype[name] for idx in indexes] + ) + + return cls(new_pd_index, dim=dim, level_coords_dtype=level_coords_dtype) + + @classmethod + def stack( + cls, variables: Mapping[Any, Variable], dim: Hashable + ) -> PandasMultiIndex: + """Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a + new dimension. + + Level variables must have a dimension distinct from each other. + + Keeps levels the same (doesn't refactorize them) so that it gives back the original + labels after a stack/unstack roundtrip. + + """ + _check_dim_compat(variables, all_dims="different") + + level_indexes = [safe_cast_to_index(var) for var in variables.values()] + for name, idx in zip(variables, level_indexes): + if isinstance(idx, pd.MultiIndex): + raise ValueError( + f"cannot create a multi-index along stacked dimension {dim!r} " + f"from variable {name!r} that wraps a multi-index" + ) + + split_labels, levels = zip(*[lev.factorize() for lev in level_indexes]) + labels_mesh = np.meshgrid(*split_labels, indexing="ij") + labels = [x.ravel() for x in labels_mesh] + + index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + level_coords_dtype = {k: var.dtype for k, var in variables.items()} + + return cls(index, dim, level_coords_dtype=level_coords_dtype) + + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: + clean_index = remove_unused_levels_categories(self.index) + + if not clean_index.is_unique: + raise ValueError( + "Cannot unstack MultiIndex containing duplicates. Make sure entries " + f"are unique, e.g., by calling ``.drop_duplicates('{self.dim}')``, " + "before unstacking." + ) + + new_indexes: dict[Hashable, Index] = {} + for name, lev in zip(clean_index.names, clean_index.levels): + idx = PandasIndex( + lev.copy(), name, coord_dtype=self.level_coords_dtype[name] + ) + new_indexes[name] = idx + + return new_indexes, clean_index + + @classmethod + def from_variables_maybe_expand( + cls, + dim: Hashable, + current_variables: Mapping[Any, Variable], + variables: Mapping[Any, Variable], + ) -> tuple[PandasMultiIndex, IndexVars]: + """Create a new multi-index maybe by expanding an existing one with + new variables as index levels. + + The index and its corresponding coordinates may be created along a new dimension. + """ + names: list[Hashable] = [] + codes: list[list[int]] = [] + levels: list[list[int]] = [] + level_variables: dict[Any, Variable] = {} + + _check_dim_compat({**current_variables, **variables}) + + if len(current_variables) > 1: + # expand from an existing multi-index + data = cast( + PandasMultiIndexingAdapter, next(iter(current_variables.values()))._data + ) + current_index = data.array + names.extend(current_index.names) + codes.extend(current_index.codes) + levels.extend(current_index.levels) + for name in current_index.names: + level_variables[name] = current_variables[name] + + elif len(current_variables) == 1: + # expand from one 1D variable (no multi-index): convert it to an index level + var = next(iter(current_variables.values())) + new_var_name = f"{dim}_level_0" + names.append(new_var_name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + level_variables[new_var_name] = var + + for name, var in variables.items(): + names.append(name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + level_variables[name] = var + + index = pd.MultiIndex(levels, codes, names=names) + level_coords_dtype = {k: var.dtype for k, var in level_variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + index_vars = obj.create_variables(level_variables) + + return obj, index_vars + + def keep_levels( + self, level_variables: Mapping[Any, Variable] + ) -> PandasMultiIndex | PandasIndex: + """Keep only the provided levels and return a new multi-index with its + corresponding coordinates. + + """ + index = self.index.droplevel( + [k for k in self.index.names if k not in level_variables] + ) + + if isinstance(index, pd.MultiIndex): + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) + else: + # backward compatibility: rename the level coordinate to the dimension name + return PandasIndex( + index.rename(self.dim), + self.dim, + coord_dtype=self.level_coords_dtype[index.name], + ) + + def reorder_levels( + self, level_variables: Mapping[Any, Variable] + ) -> PandasMultiIndex: + """Re-arrange index levels using input order and return a new multi-index with + its corresponding coordinates. + + """ + index = self.index.reorder_levels(level_variables.keys()) + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + from xarray.core.variable import IndexVariable + + if variables is None: + variables = {} + + index_vars: IndexVars = {} + for name in (self.dim,) + self.index.names: + if name == self.dim: + level = None + dtype = None + else: + level = name + dtype = self.level_coords_dtype[name] + + var = variables.get(name, None) + if var is not None: + attrs = var.attrs + encoding = var.encoding + else: + attrs = {} + encoding = {} + + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + index_vars[name] = IndexVariable( + self.dim, + data, + attrs=attrs, + encoding=encoding, + fastpath=True, + ) + + return index_vars + + def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if method is not None or tolerance is not None: + raise ValueError( + "multi-index does not support ``method`` and ``tolerance``" + ) + + new_index = None + scalar_coord_values = {} + + # label(s) given for multi-index level(s) + if all([lbl in self.index.names for lbl in labels]): + label_values = {} + for k, v in labels.items(): + label_array = normalize_label(v, dtype=self.level_coords_dtype[k]) + try: + label_values[k] = as_scalar(label_array) + except ValueError: + # label should be an item not an array-like + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) + + has_slice = any([isinstance(v, slice) for v in label_values.values()]) + + if len(label_values) == self.index.nlevels and not has_slice: + indexer = self.index.get_loc( + tuple(label_values[k] for k in self.index.names) + ) + else: + indexer, new_index = self.index.get_loc_level( + tuple(label_values.values()), level=tuple(label_values.keys()) + ) + scalar_coord_values.update(label_values) + # GH2619. Raise a KeyError if nothing is chosen + if indexer.dtype.kind == "b" and indexer.sum() == 0: + raise KeyError(f"{labels} not found") + + # assume one label value given for the multi-index "array" (dimension) + else: + if len(labels) > 1: + coord_name = next(iter(set(labels) - set(self.index.names))) + raise ValueError( + f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) " + f"and one or more coordinates among {self.index.names!r} (multi-index levels)" + ) + + coord_name, label = next(iter(labels.items())) + + if is_dict_like(label): + invalid_levels = tuple( + name for name in label if name not in self.index.names + ) + if invalid_levels: + raise ValueError( + f"multi-index level names {invalid_levels} not found in indexes {tuple(self.index.names)}" + ) + return self.sel(label) + + elif isinstance(label, slice): + indexer = _query_slice(self.index, label, coord_name) + + elif isinstance(label, tuple): + if _is_nested_tuple(label): + indexer = self.index.get_locs(label) + elif len(label) == self.index.nlevels: + indexer = self.index.get_loc(label) + else: + levels = [self.index.names[i] for i in range(len(label))] + indexer, new_index = self.index.get_loc_level(label, level=levels) + scalar_coord_values.update({k: v for k, v in zip(levels, label)}) + + else: + label_array = normalize_label(label) + if label_array.ndim == 0: + label_value = as_scalar(label_array) + indexer, new_index = self.index.get_loc_level(label_value, level=0) + scalar_coord_values[self.index.names[0]] = label_value + elif label_array.dtype.kind == "b": + indexer = label_array + else: + if label_array.ndim > 1: + raise ValueError( + "Vectorized selection is not available along " + f"coordinate {coord_name!r} with a multi-index" + ) + indexer = get_indexer_nd(self.index, label_array) + if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + # do not include label-indexer DataArray coordinates that conflict + # with the level names of this index + coords = { + k: v + for k, v in label._coords.items() + if k not in self.index.names + } + indexer = DataArray(indexer, coords=coords, dims=label.dims) + + if new_index is not None: + if isinstance(new_index, pd.MultiIndex): + level_coords_dtype = { + k: self.level_coords_dtype[k] for k in new_index.names + } + new_index = self._replace( + new_index, level_coords_dtype=level_coords_dtype + ) + dims_dict = {} + drop_coords = [] + else: + new_index = PandasIndex( + new_index, + new_index.name, + coord_dtype=self.level_coords_dtype[new_index.name], + ) + dims_dict = {self.dim: new_index.index.name} + drop_coords = [self.dim] + + # variable(s) attrs and encoding metadata are propagated + # when replacing the indexes in the resulting xarray object + new_vars = new_index.create_variables() + indexes = cast(dict[Any, Index], {k: new_index for k in new_vars}) + + # add scalar variable for each dropped level + variables = new_vars + for name, val in scalar_coord_values.items(): + variables[name] = Variable([], val) + + return IndexSelResult( + {self.dim: indexer}, + indexes=indexes, + variables=variables, + drop_indexes=list(scalar_coord_values), + drop_coords=drop_coords, + rename_dims=dims_dict, + ) + + else: + return IndexSelResult({self.dim: indexer}) + + def join(self, other, how: str = "inner"): + if how == "outer": + # bug in pandas? need to reset index.name + other_index = other.index.copy() + other_index.name = None + index = self.index.union(other_index) + index.name = self.dim + else: + # how = "inner" + index = self.index.intersection(other.index) + + level_coords_dtype = { + k: np.result_type(lvl_dtype, other.level_coords_dtype[k]) + for k, lvl_dtype in self.level_coords_dtype.items() + } + + return type(self)(index, self.dim, level_coords_dtype=level_coords_dtype) + + def rename(self, name_dict, dims_dict): + if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict: + return self + + # pandas 1.3.0: could simply do `self.index.rename(names_dict)` + new_names = [name_dict.get(k, k) for k in self.index.names] + index = self.index.rename(new_names) + + new_dim = dims_dict.get(self.dim, self.dim) + new_level_coords_dtype = { + k: v for k, v in zip(new_names, self.level_coords_dtype.values()) + } + return self._replace( + index, dim=new_dim, level_coords_dtype=new_level_coords_dtype + ) + + +def create_default_index_implicit( + dim_variable: Variable, + all_variables: Mapping | Iterable[Hashable] | None = None, +) -> tuple[PandasIndex, IndexVars]: + """Create a default index from a dimension variable. + + Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex, + otherwise create a PandasIndex (note that this will become obsolete once we + depreciate implicitly passing a pandas.MultiIndex as a coordinate). + + """ + if all_variables is None: + all_variables = {} + if not isinstance(all_variables, Mapping): + all_variables = {k: None for k in all_variables} + + name = dim_variable.dims[0] + array = getattr(dim_variable._data, "array", None) + index: PandasIndex + + if isinstance(array, pd.MultiIndex): + index = PandasMultiIndex(array, name) + index_vars = index.create_variables() + # check for conflict between level names and variable names + duplicate_names = [k for k in index_vars if k in all_variables and k != name] + if duplicate_names: + # dirty workaround for an edge case where both the dimension + # coordinate and the level coordinates are given for the same + # multi-index object => do not raise an error + # TODO: remove this check when removing the multi-index dimension coordinate + if len(duplicate_names) < len(index.index.names): + conflict = True + else: + duplicate_vars = [all_variables[k] for k in duplicate_names] + conflict = any( + v is None or not dim_variable.equals(v) for v in duplicate_vars + ) + + if conflict: + conflict_str = "\n".join(duplicate_names) + raise ValueError( + f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" + ) + else: + dim_var = {name: dim_variable} + index = PandasIndex.from_variables(dim_var, options={}) + index_vars = index.create_variables(dim_var) + + return index, index_vars + + +# generic type that represents either a pandas or an xarray index +T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex", Index, pd.Index) + + +class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): + """Immutable proxy for Dataset or DataArray indexes. + + It is a mapping where keys are coordinate names and values are either pandas + or xarray indexes. + + It also contains the indexed coordinate variables and provides some utility + methods. + + """ + + _index_type: type[Index] | type[pd.Index] + _indexes: dict[Any, T_PandasOrXarrayIndex] + _variables: dict[Any, Variable] + + __slots__ = ( + "_index_type", + "_indexes", + "_variables", + "_dims", + "__coord_name_id", + "__id_index", + "__id_coord_names", + ) + + def __init__( + self, + indexes: Mapping[Any, T_PandasOrXarrayIndex] | None = None, + variables: Mapping[Any, Variable] | None = None, + index_type: type[Index] | type[pd.Index] = Index, + ): + """Constructor not for public consumption. + + Parameters + ---------- + indexes : dict + Indexes held by this object. + variables : dict + Indexed coordinate variables in this object. Entries must + match those of `indexes`. + index_type : type + The type of all indexes, i.e., either :py:class:`xarray.indexes.Index` + or :py:class:`pandas.Index`. + + """ + if indexes is None: + indexes = {} + if variables is None: + variables = {} + + unmatched_keys = set(indexes) ^ set(variables) + if unmatched_keys: + raise ValueError( + f"unmatched keys found in indexes and variables: {unmatched_keys}" + ) + + if any(not isinstance(idx, index_type) for idx in indexes.values()): + index_type_str = f"{index_type.__module__}.{index_type.__name__}" + raise TypeError( + f"values of indexes must all be instances of {index_type_str}" + ) + + self._index_type = index_type + self._indexes = dict(**indexes) + self._variables = dict(**variables) + + self._dims: Mapping[Hashable, int] | None = None + self.__coord_name_id: dict[Any, int] | None = None + self.__id_index: dict[int, T_PandasOrXarrayIndex] | None = None + self.__id_coord_names: dict[int, tuple[Hashable, ...]] | None = None + + @property + def _coord_name_id(self) -> dict[Any, int]: + if self.__coord_name_id is None: + self.__coord_name_id = {k: id(idx) for k, idx in self._indexes.items()} + return self.__coord_name_id + + @property + def _id_index(self) -> dict[int, T_PandasOrXarrayIndex]: + if self.__id_index is None: + self.__id_index = {id(idx): idx for idx in self.get_unique()} + return self.__id_index + + @property + def _id_coord_names(self) -> dict[int, tuple[Hashable, ...]]: + if self.__id_coord_names is None: + id_coord_names: Mapping[int, list[Hashable]] = defaultdict(list) + for k, v in self._coord_name_id.items(): + id_coord_names[v].append(k) + self.__id_coord_names = {k: tuple(v) for k, v in id_coord_names.items()} + + return self.__id_coord_names + + @property + def variables(self) -> Mapping[Hashable, Variable]: + return Frozen(self._variables) + + @property + def dims(self) -> Mapping[Hashable, int]: + from xarray.core.variable import calculate_dimensions + + if self._dims is None: + self._dims = calculate_dimensions(self._variables) + + return Frozen(self._dims) + + def copy(self) -> Indexes: + return type(self)(dict(self._indexes), dict(self._variables)) + + def get_unique(self) -> list[T_PandasOrXarrayIndex]: + """Return a list of unique indexes, preserving order.""" + + unique_indexes: list[T_PandasOrXarrayIndex] = [] + seen: set[int] = set() + + for index in self._indexes.values(): + index_id = id(index) + if index_id not in seen: + unique_indexes.append(index) + seen.add(index_id) + + return unique_indexes + + def is_multi(self, key: Hashable) -> bool: + """Return True if ``key`` maps to a multi-coordinate index, + False otherwise. + """ + return len(self._id_coord_names[self._coord_name_id[key]]) > 1 + + def get_all_coords( + self, key: Hashable, errors: ErrorOptions = "raise" + ) -> dict[Hashable, Variable]: + """Return all coordinates having the same index. + + Parameters + ---------- + key : hashable + Index key. + errors : {"raise", "ignore"}, default: "raise" + If "raise", raises a ValueError if `key` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + coords : dict + A dictionary of all coordinate variables having the same index. + + """ + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + if key not in self._indexes: + if errors == "raise": + raise ValueError(f"no index found for {key!r} coordinate") + else: + return {} + + all_coord_names = self._id_coord_names[self._coord_name_id[key]] + return {k: self._variables[k] for k in all_coord_names} + + def get_all_dims( + self, key: Hashable, errors: ErrorOptions = "raise" + ) -> Mapping[Hashable, int]: + """Return all dimensions shared by an index. + + Parameters + ---------- + key : hashable + Index key. + errors : {"raise", "ignore"}, default: "raise" + If "raise", raises a ValueError if `key` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + dims : dict + A dictionary of all dimensions shared by an index. + + """ + from xarray.core.variable import calculate_dimensions + + return calculate_dimensions(self.get_all_coords(key, errors=errors)) + + def group_by_index( + self, + ) -> list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]]: + """Returns a list of unique indexes and their corresponding coordinates.""" + + index_coords = [] + + for i in self._id_index: + index = self._id_index[i] + coords = {k: self._variables[k] for k in self._id_coord_names[i]} + index_coords.append((index, coords)) + + return index_coords + + def to_pandas_indexes(self) -> Indexes[pd.Index]: + """Returns an immutable proxy for Dataset or DataArrary pandas indexes. + + Raises an error if this proxy contains indexes that cannot be coerced to + pandas.Index objects. + + """ + indexes: dict[Hashable, pd.Index] = {} + + for k, idx in self._indexes.items(): + if isinstance(idx, pd.Index): + indexes[k] = idx + elif isinstance(idx, Index): + indexes[k] = idx.to_pandas_index() + + return Indexes(indexes, self._variables, index_type=pd.Index) + + def copy_indexes( + self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None + ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: + """Return a new dictionary with copies of indexes, preserving + unique indexes. + + Parameters + ---------- + deep : bool, default: True + Whether the indexes are deep or shallow copied onto the new object. + memo : dict if object id to copied objects or None, optional + To prevent infinite recursion deepcopy stores all copied elements + in this dict. + + """ + new_indexes = {} + new_index_vars = {} + + idx: T_PandasOrXarrayIndex + for idx, coords in self.group_by_index(): + if isinstance(idx, pd.Index): + convert_new_idx = True + dim = next(iter(coords.values())).dims[0] + if isinstance(idx, pd.MultiIndex): + idx = PandasMultiIndex(idx, dim) + else: + idx = PandasIndex(idx, dim) + else: + convert_new_idx = False + + new_idx = idx._copy(deep=deep, memo=memo) + idx_vars = idx.create_variables(coords) + + if convert_new_idx: + new_idx = cast(PandasIndex, new_idx).index + + new_indexes.update({k: new_idx for k in coords}) + new_index_vars.update(idx_vars) + + return new_indexes, new_index_vars + + def __iter__(self) -> Iterator[T_PandasOrXarrayIndex]: + return iter(self._indexes) + + def __len__(self) -> int: + return len(self._indexes) + + def __contains__(self, key) -> bool: + return key in self._indexes + + def __getitem__(self, key) -> T_PandasOrXarrayIndex: + return self._indexes[key] + + def __repr__(self): + indexes = formatting._get_indexes_dict(self) + return formatting.indexes_repr(indexes) + + +def default_indexes( + coords: Mapping[Any, Variable], dims: Iterable +) -> dict[Hashable, Index]: + """Default indexes for a Dataset/DataArray. + + Parameters + ---------- + coords : Mapping[Any, xarray.Variable] + Coordinate variables from which to draw default indexes. + dims : iterable + Iterable of dimension names. + + Returns + ------- + Mapping from indexing keys (levels/dimension names) to indexes used for + indexing along that dimension. + """ + indexes: dict[Hashable, Index] = {} + coord_names = set(coords) + + for name, var in coords.items(): + if name in dims and var.ndim == 1: + index, index_vars = create_default_index_implicit(var, coords) + if set(index_vars) <= coord_names: + indexes.update({k: index for k in index_vars}) + + return indexes + + +def indexes_equal( + index: Index, + other_index: Index, + variable: Variable, + other_variable: Variable, + cache: dict[tuple[int, int], bool | None] | None = None, +) -> bool: + """Check if two indexes are equal, possibly with cached results. + + If the two indexes are not of the same type or they do not implement + equality, fallback to coordinate labels equality check. + + """ + if cache is None: + # dummy cache + cache = {} + + key = (id(index), id(other_index)) + equal: bool | None = None + + if key not in cache: + if type(index) is type(other_index): + try: + equal = index.equals(other_index) + except NotImplementedError: + equal = None + else: + cache[key] = equal + else: + equal = None + else: + equal = cache[key] + + if equal is None: + equal = variable.equals(other_variable) + + return cast(bool, equal) + + +def indexes_all_equal( + elements: Sequence[tuple[Index, dict[Hashable, Variable]]] +) -> bool: + """Check if indexes are all equal. + + If they are not of the same type or they do not implement this check, check + if their coordinate variables are all equal instead. + + """ + + def check_variables(): + variables = [e[1] for e in elements] + return any( + not variables[0][k].equals(other_vars[k]) + for other_vars in variables[1:] + for k in variables[0] + ) + + indexes = [e[0] for e in elements] + + same_objects = all(indexes[0] is other_idx for other_idx in indexes[1:]) + if same_objects: + return True + + same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) + if same_type: + try: + not_equal = any( + not indexes[0].equals(other_idx) for other_idx in indexes[1:] + ) + except NotImplementedError: + not_equal = check_variables() + else: + not_equal = check_variables() + + return not not_equal + + +def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str): + # This function avoids the call to indexes.group_by_index + # which is really slow when repeatidly iterating through + # an array. However, it fails to return the correct ID for + # multi-index arrays + indexes_fast, coords = indexes._indexes, indexes._variables + + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes_fast.items()} + new_index_variables: dict[Hashable, Variable] = {} + for name, index in indexes_fast.items(): + coord = coords[name] + if hasattr(coord, "_indexes"): + index_vars = {n: coords[n] for n in coord._indexes} + else: + index_vars = {name: coord} + index_dims = {d for var in index_vars.values() for d in var.dims} + index_args = {k: v for k, v in args.items() if k in index_dims} + + if index_args: + new_index = getattr(index, func)(index_args) + if new_index is not None: + new_indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars) + new_index_variables.update(new_index_vars) + else: + for k in index_vars: + new_indexes.pop(k, None) + return new_indexes, new_index_variables + + +def _apply_indexes( + indexes: Indexes[Index], + args: Mapping[Any, Any], + func: str, +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()} + new_index_variables: dict[Hashable, Variable] = {} + + for index, index_vars in indexes.group_by_index(): + index_dims = {d for var in index_vars.values() for d in var.dims} + index_args = {k: v for k, v in args.items() if k in index_dims} + if index_args: + new_index = getattr(index, func)(index_args) + if new_index is not None: + new_indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars) + new_index_variables.update(new_index_vars) + else: + for k in index_vars: + new_indexes.pop(k, None) + + return new_indexes, new_index_variables + + +def isel_indexes( + indexes: Indexes[Index], + indexers: Mapping[Any, Any], +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + # TODO: remove if clause in the future. It should be unnecessary. + # See failure introduced when removed + # https://github.com/pydata/xarray/pull/9002#discussion_r1590443756 + if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()): + return _apply_indexes(indexes, indexers, "isel") + else: + return _apply_indexes_fast(indexes, indexers, "isel") + + +def roll_indexes( + indexes: Indexes[Index], + shifts: Mapping[Any, int], +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_indexes(indexes, shifts, "roll") + + +def filter_indexes_from_coords( + indexes: Mapping[Any, Index], + filtered_coord_names: set, +) -> dict[Hashable, Index]: + """Filter index items given a (sub)set of coordinate names. + + Drop all multi-coordinate related index items for any key missing in the set + of coordinate names. + + """ + filtered_indexes: dict[Any, Index] = dict(indexes) + + index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) + for name, idx in indexes.items(): + index_coord_names[id(idx)].add(name) + + for idx_coord_names in index_coord_names.values(): + if not idx_coord_names <= filtered_coord_names: + for k in idx_coord_names: + del filtered_indexes[k] + + return filtered_indexes + + +def assert_no_index_corrupted( + indexes: Indexes[Index], + coord_names: set[Hashable], + action: str = "remove coordinate(s)", +) -> None: + """Assert removing coordinates or indexes will not corrupt indexes.""" + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of coordinate names to remove + for index, index_coords in indexes.group_by_index(): + common_names = set(index_coords) & coord_names + if common_names and len(common_names) != len(index_coords): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coords) + raise ValueError( + f"cannot {action} {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{index}" + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/indexing.py b/test/fixtures/whole_applications/xarray/xarray/core/indexing.py new file mode 100644 index 0000000..06e7efd --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/indexing.py @@ -0,0 +1,1916 @@ +from __future__ import annotations + +import enum +import functools +import operator +from collections import Counter, defaultdict +from collections.abc import Hashable, Iterable, Mapping +from contextlib import suppress +from dataclasses import dataclass, field +from datetime import timedelta +from html import escape +from typing import TYPE_CHECKING, Any, Callable, overload + +import numpy as np +import pandas as pd + +from xarray.core import duck_array_ops +from xarray.core.nputils import NumpyVIndexAdapter +from xarray.core.options import OPTIONS +from xarray.core.types import T_Xarray +from xarray.core.utils import ( + NDArrayMixin, + either_dict_or_kwargs, + get_valid_numpy_dtype, + is_duck_array, + is_duck_dask_array, + is_scalar, + to_0d_array, +) +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array + +if TYPE_CHECKING: + from numpy.typing import DTypeLike + + from xarray.core.indexes import Index + from xarray.core.variable import Variable + from xarray.namedarray._typing import _Shape, duckarray + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + +@dataclass +class IndexSelResult: + """Index query results. + + Attributes + ---------- + dim_indexers: dict + A dictionary where keys are array dimensions and values are + location-based indexers. + indexes: dict, optional + New indexes to replace in the resulting DataArray or Dataset. + variables : dict, optional + New variables to replace in the resulting DataArray or Dataset. + drop_coords : list, optional + Coordinate(s) to drop in the resulting DataArray or Dataset. + drop_indexes : list, optional + Index(es) to drop in the resulting DataArray or Dataset. + rename_dims : dict, optional + A dictionary in the form ``{old_dim: new_dim}`` for dimension(s) to + rename in the resulting DataArray or Dataset. + + """ + + dim_indexers: dict[Any, Any] + indexes: dict[Any, Index] = field(default_factory=dict) + variables: dict[Any, Variable] = field(default_factory=dict) + drop_coords: list[Hashable] = field(default_factory=list) + drop_indexes: list[Hashable] = field(default_factory=list) + rename_dims: dict[Any, Hashable] = field(default_factory=dict) + + def as_tuple(self): + """Unlike ``dataclasses.astuple``, return a shallow copy. + + See https://stackoverflow.com/a/51802661 + + """ + return ( + self.dim_indexers, + self.indexes, + self.variables, + self.drop_coords, + self.drop_indexes, + self.rename_dims, + ) + + +def merge_sel_results(results: list[IndexSelResult]) -> IndexSelResult: + all_dims_count = Counter([dim for res in results for dim in res.dim_indexers]) + duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} + + if duplicate_dims: + # TODO: this message is not right when combining indexe(s) queries with + # location-based indexing on a dimension with no dimension-coordinate (failback) + fmt_dims = [ + f"{dim!r}: {count} indexes involved" + for dim, count in duplicate_dims.items() + ] + raise ValueError( + "Xarray does not support label-based selection with more than one index " + "over the following dimension(s):\n" + + "\n".join(fmt_dims) + + "\nSuggestion: use a multi-index for each of those dimension(s)." + ) + + dim_indexers = {} + indexes = {} + variables = {} + drop_coords = [] + drop_indexes = [] + rename_dims = {} + + for res in results: + dim_indexers.update(res.dim_indexers) + indexes.update(res.indexes) + variables.update(res.variables) + drop_coords += res.drop_coords + drop_indexes += res.drop_indexes + rename_dims.update(res.rename_dims) + + return IndexSelResult( + dim_indexers, indexes, variables, drop_coords, drop_indexes, rename_dims + ) + + +def group_indexers_by_index( + obj: T_Xarray, + indexers: Mapping[Any, Any], + options: Mapping[str, Any], +) -> list[tuple[Index, dict[Any, Any]]]: + """Returns a list of unique indexes and their corresponding indexers.""" + unique_indexes = {} + grouped_indexers: Mapping[int | None, dict] = defaultdict(dict) + + for key, label in indexers.items(): + index: Index = obj.xindexes.get(key, None) + + if index is not None: + index_id = id(index) + unique_indexes[index_id] = index + grouped_indexers[index_id][key] = label + elif key in obj.coords: + raise KeyError(f"no index found for coordinate {key!r}") + elif key not in obj.dims: + raise KeyError( + f"{key!r} is not a valid dimension or coordinate for " + f"{obj.__class__.__name__} with dimensions {obj.dims!r}" + ) + elif len(options): + raise ValueError( + f"cannot supply selection options {options!r} for dimension {key!r}" + "that has no associated coordinate or index" + ) + else: + # key is a dimension without a "dimension-coordinate" + # failback to location-based selection + # TODO: depreciate this implicit behavior and suggest using isel instead? + unique_indexes[None] = None + grouped_indexers[None][key] = label + + return [(unique_indexes[k], grouped_indexers[k]) for k in unique_indexes] + + +def map_index_queries( + obj: T_Xarray, + indexers: Mapping[Any, Any], + method=None, + tolerance: int | float | Iterable[int | float] | None = None, + **indexers_kwargs: Any, +) -> IndexSelResult: + """Execute index queries from a DataArray / Dataset and label-based indexers + and return the (merged) query results. + + """ + from xarray.core.dataarray import DataArray + + # TODO benbovy - flexible indexes: remove when custom index options are available + if method is None and tolerance is None: + options = {} + else: + options = {"method": method, "tolerance": tolerance} + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") + grouped_indexers = group_indexers_by_index(obj, indexers, options) + + results = [] + for index, labels in grouped_indexers: + if index is None: + # forward dimension indexers with no index/coordinate + results.append(IndexSelResult(labels)) + else: + results.append(index.sel(labels, **options)) + + merged = merge_sel_results(results) + + # drop dimension coordinates found in dimension indexers + # (also drop multi-index if any) + # (.sel() already ensures alignment) + for k, v in merged.dim_indexers.items(): + if isinstance(v, DataArray): + if k in v._indexes: + v = v.reset_index(k) + drop_coords = [name for name in v._coords if name in merged.dim_indexers] + merged.dim_indexers[k] = v.drop_vars(drop_coords) + + return merged + + +def expanded_indexer(key, ndim): + """Given a key for indexing an ndarray, return an equivalent key which is a + tuple with length equal to the number of dimensions. + + The expansion is done by replacing all `Ellipsis` items with the right + number of full slices and then padding the key with full slices so that it + reaches the appropriate dimensionality. + """ + if not isinstance(key, tuple): + # numpy treats non-tuple keys equivalent to tuples of length 1 + key = (key,) + new_key = [] + # handling Ellipsis right is a little tricky, see: + # https://numpy.org/doc/stable/reference/arrays.indexing.html#advanced-indexing + found_ellipsis = False + for k in key: + if k is Ellipsis: + if not found_ellipsis: + new_key.extend((ndim + 1 - len(key)) * [slice(None)]) + found_ellipsis = True + else: + new_key.append(slice(None)) + else: + new_key.append(k) + if len(new_key) > ndim: + raise IndexError("too many indices") + new_key.extend((ndim - len(new_key)) * [slice(None)]) + return tuple(new_key) + + +def _normalize_slice(sl: slice, size: int) -> slice: + """ + Ensure that given slice only contains positive start and stop values + (stop can be -1 for full-size slices with negative steps, e.g. [-10::-1]) + + Examples + -------- + >>> _normalize_slice(slice(0, 9), 10) + slice(0, 9, 1) + >>> _normalize_slice(slice(0, -1), 10) + slice(0, 9, 1) + """ + return slice(*sl.indices(size)) + + +def _expand_slice(slice_: slice, size: int) -> np.ndarray[Any, np.dtype[np.integer]]: + """ + Expand slice to an array containing only positive integers. + + Examples + -------- + >>> _expand_slice(slice(0, 9), 10) + array([0, 1, 2, 3, 4, 5, 6, 7, 8]) + >>> _expand_slice(slice(0, -1), 10) + array([0, 1, 2, 3, 4, 5, 6, 7, 8]) + """ + sl = _normalize_slice(slice_, size) + return np.arange(sl.start, sl.stop, sl.step) + + +def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice: + """Given a slice and the size of the dimension to which it will be applied, + index it with another slice to return a new slice equivalent to applying + the slices sequentially + """ + old_slice = _normalize_slice(old_slice, size) + + size_after_old_slice = len(range(old_slice.start, old_slice.stop, old_slice.step)) + if size_after_old_slice == 0: + # nothing left after applying first slice + return slice(0) + + applied_slice = _normalize_slice(applied_slice, size_after_old_slice) + + start = old_slice.start + applied_slice.start * old_slice.step + if start < 0: + # nothing left after applying second slice + # (can only happen for old_slice.step < 0, e.g. [10::-1], [20:]) + return slice(0) + + stop = old_slice.start + applied_slice.stop * old_slice.step + if stop < 0: + stop = None + + step = old_slice.step * applied_slice.step + + return slice(start, stop, step) + + +def _index_indexer_1d(old_indexer, applied_indexer, size: int): + if isinstance(applied_indexer, slice) and applied_indexer == slice(None): + # shortcut for the usual case + return old_indexer + if isinstance(old_indexer, slice): + if isinstance(applied_indexer, slice): + indexer = slice_slice(old_indexer, applied_indexer, size) + elif isinstance(applied_indexer, integer_types): + indexer = range(*old_indexer.indices(size))[applied_indexer] # type: ignore[assignment] + else: + indexer = _expand_slice(old_indexer, size)[applied_indexer] + else: + indexer = old_indexer[applied_indexer] + return indexer + + +class ExplicitIndexer: + """Base class for explicit indexer objects. + + ExplicitIndexer objects wrap a tuple of values given by their ``tuple`` + property. These tuples should always have length equal to the number of + dimensions on the indexed array. + + Do not instantiate BaseIndexer objects directly: instead, use one of the + sub-classes BasicIndexer, OuterIndexer or VectorizedIndexer. + """ + + __slots__ = ("_key",) + + def __init__(self, key: tuple[Any, ...]): + if type(self) is ExplicitIndexer: + raise TypeError("cannot instantiate base ExplicitIndexer objects") + self._key = tuple(key) + + @property + def tuple(self) -> tuple[Any, ...]: + return self._key + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.tuple})" + + +@overload +def as_integer_or_none(value: int) -> int: ... +@overload +def as_integer_or_none(value: None) -> None: ... +def as_integer_or_none(value: int | None) -> int | None: + return None if value is None else operator.index(value) + + +def as_integer_slice(value: slice) -> slice: + start = as_integer_or_none(value.start) + stop = as_integer_or_none(value.stop) + step = as_integer_or_none(value.step) + return slice(start, stop, step) + + +class IndexCallable: + """Provide getitem and setitem syntax for callable objects.""" + + __slots__ = ("getter", "setter") + + def __init__( + self, getter: Callable[..., Any], setter: Callable[..., Any] | None = None + ): + self.getter = getter + self.setter = setter + + def __getitem__(self, key: Any) -> Any: + return self.getter(key) + + def __setitem__(self, key: Any, value: Any) -> None: + if self.setter is None: + raise NotImplementedError( + "Setting values is not supported for this indexer." + ) + self.setter(key, value) + + +class BasicIndexer(ExplicitIndexer): + """Tuple for basic indexing. + + All elements should be int or slice objects. Indexing follows NumPy's + rules for basic indexing: each axis is independently sliced and axes + indexed with an integer are dropped from the result. + """ + + __slots__ = () + + def __init__(self, key: tuple[int | np.integer | slice, ...]): + if not isinstance(key, tuple): + raise TypeError(f"key must be a tuple: {key!r}") + + new_key = [] + for k in key: + if isinstance(k, integer_types): + k = int(k) + elif isinstance(k, slice): + k = as_integer_slice(k) + else: + raise TypeError( + f"unexpected indexer type for {type(self).__name__}: {k!r}" + ) + new_key.append(k) + + super().__init__(tuple(new_key)) + + +class OuterIndexer(ExplicitIndexer): + """Tuple for outer/orthogonal indexing. + + All elements should be int, slice or 1-dimensional np.ndarray objects with + an integer dtype. Indexing is applied independently along each axis, and + axes indexed with an integer are dropped from the result. This type of + indexing works like MATLAB/Fortran. + """ + + __slots__ = () + + def __init__( + self, + key: tuple[ + int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ... + ], + ): + if not isinstance(key, tuple): + raise TypeError(f"key must be a tuple: {key!r}") + + new_key = [] + for k in key: + if isinstance(k, integer_types): + k = int(k) + elif isinstance(k, slice): + k = as_integer_slice(k) + elif is_duck_array(k): + if not np.issubdtype(k.dtype, np.integer): + raise TypeError( + f"invalid indexer array, does not have integer dtype: {k!r}" + ) + if k.ndim > 1: # type: ignore[union-attr] + raise TypeError( + f"invalid indexer array for {type(self).__name__}; must be scalar " + f"or have 1 dimension: {k!r}" + ) + k = k.astype(np.int64) # type: ignore[union-attr] + else: + raise TypeError( + f"unexpected indexer type for {type(self).__name__}: {k!r}" + ) + new_key.append(k) + + super().__init__(tuple(new_key)) + + +class VectorizedIndexer(ExplicitIndexer): + """Tuple for vectorized indexing. + + All elements should be slice or N-dimensional np.ndarray objects with an + integer dtype and the same number of dimensions. Indexing follows proposed + rules for np.ndarray.vindex, which matches NumPy's advanced indexing rules + (including broadcasting) except sliced axes are always moved to the end: + https://github.com/numpy/numpy/pull/6256 + """ + + __slots__ = () + + def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...]): + if not isinstance(key, tuple): + raise TypeError(f"key must be a tuple: {key!r}") + + new_key = [] + ndim = None + for k in key: + if isinstance(k, slice): + k = as_integer_slice(k) + elif is_duck_dask_array(k): + raise ValueError( + "Vectorized indexing with Dask arrays is not supported. " + "Please pass a numpy array by calling ``.compute``. " + "See https://github.com/dask/dask/issues/8958." + ) + elif is_duck_array(k): + if not np.issubdtype(k.dtype, np.integer): + raise TypeError( + f"invalid indexer array, does not have integer dtype: {k!r}" + ) + if ndim is None: + ndim = k.ndim # type: ignore[union-attr] + elif ndim != k.ndim: + ndims = [k.ndim for k in key if isinstance(k, np.ndarray)] + raise ValueError( + "invalid indexer key: ndarray arguments " + f"have different numbers of dimensions: {ndims}" + ) + k = k.astype(np.int64) # type: ignore[union-attr] + else: + raise TypeError( + f"unexpected indexer type for {type(self).__name__}: {k!r}" + ) + new_key.append(k) + + super().__init__(tuple(new_key)) + + +class ExplicitlyIndexed: + """Mixin to mark support for Indexer subclasses in indexing.""" + + __slots__ = () + + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # Leave casting to an array up to the underlying array type. + return np.asarray(self.get_duck_array(), dtype=dtype) + + def get_duck_array(self): + return self.array + + +class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): + __slots__ = () + + def get_duck_array(self): + key = BasicIndexer((slice(None),) * self.ndim) + return self[key] + + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # This is necessary because we apply the indexing key in self.get_duck_array() + # Note this is the base class for all lazy indexing classes + return np.asarray(self.get_duck_array(), dtype=dtype) + + def _oindex_get(self, indexer: OuterIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_get method should be overridden" + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_get method should be overridden" + ) + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_set method should be overridden" + ) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_set method should be overridden" + ) + + def _check_and_raise_if_non_basic_indexer(self, indexer: ExplicitIndexer) -> None: + if isinstance(indexer, (VectorizedIndexer, OuterIndexer)): + raise TypeError( + "Vectorized indexing with vectorized or outer indexers is not supported. " + "Please use .vindex and .oindex properties to index the array." + ) + + @property + def oindex(self) -> IndexCallable: + return IndexCallable(self._oindex_get, self._oindex_set) + + @property + def vindex(self) -> IndexCallable: + return IndexCallable(self._vindex_get, self._vindex_set) + + +class ImplicitToExplicitIndexingAdapter(NDArrayMixin): + """Wrap an array, converting tuples into the indicated explicit indexer.""" + + __slots__ = ("array", "indexer_cls") + + def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): + self.array = as_indexable(array) + self.indexer_cls = indexer_cls + + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + return np.asarray(self.get_duck_array(), dtype=dtype) + + def get_duck_array(self): + return self.array.get_duck_array() + + def __getitem__(self, key: Any): + key = expanded_indexer(key, self.ndim) + indexer = self.indexer_cls(key) + + result = apply_indexer(self.array, indexer) + + if isinstance(result, ExplicitlyIndexed): + return type(self)(result, self.indexer_cls) + else: + # Sometimes explicitly indexed arrays return NumPy arrays or + # scalars. + return result + + +class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): + """Wrap an array to make basic and outer indexing lazy.""" + + __slots__ = ("array", "key", "_shape") + + def __init__(self, array: Any, key: ExplicitIndexer | None = None): + """ + Parameters + ---------- + array : array_like + Array like object to index. + key : ExplicitIndexer, optional + Array indexer. If provided, it is assumed to already be in + canonical expanded form. + """ + if isinstance(array, type(self)) and key is None: + # unwrap + key = array.key # type: ignore[has-type] + array = array.array # type: ignore[has-type] + + if key is None: + key = BasicIndexer((slice(None),) * array.ndim) + + self.array = as_indexable(array) + self.key = key + + shape: _Shape = () + for size, k in zip(self.array.shape, self.key.tuple): + if isinstance(k, slice): + shape += (len(range(*k.indices(size))),) + elif isinstance(k, np.ndarray): + shape += (k.size,) + self._shape = shape + + def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: + iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) + full_key = [] + for size, k in zip(self.array.shape, self.key.tuple): + if isinstance(k, integer_types): + full_key.append(k) + else: + full_key.append(_index_indexer_1d(k, next(iter_new_key), size)) + full_key_tuple = tuple(full_key) + + if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple): + return BasicIndexer(full_key_tuple) + return OuterIndexer(full_key_tuple) + + @property + def shape(self) -> _Shape: + return self._shape + + def get_duck_array(self): + if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + array = apply_indexer(self.array, self.key) + else: + # If the array is not an ExplicitlyIndexedNDArrayMixin, + # it may wrap a BackendArray so use its __getitem__ + array = self.array[self.key] + + # self.array[self.key] is now a numpy array when + # self.array is a BackendArray subclass + # and self.key is BasicIndexer((slice(None, None, None),)) + # so we need the explicit check for ExplicitlyIndexed + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + def transpose(self, order): + return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order) + + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(self.array, self._updated_key(indexer)) + + def _vindex_get(self, indexer: VectorizedIndexer): + array = LazilyVectorizedIndexedArray(self.array, self.key) + return array.vindex[indexer] + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(self.array, self._updated_key(indexer)) + + def _vindex_set(self, key: VectorizedIndexer, value: Any) -> None: + raise NotImplementedError( + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) + + def _oindex_set(self, key: OuterIndexer, value: Any) -> None: + full_key = self._updated_key(key) + self.array.oindex[full_key] = value + + def __setitem__(self, key: BasicIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(key) + full_key = self._updated_key(key) + self.array[full_key] = value + + def __repr__(self) -> str: + return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" + + +# keep an alias to the old name for external backends pydata/xarray#5111 +LazilyOuterIndexedArray = LazilyIndexedArray + + +class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): + """Wrap an array to make vectorized indexing lazy.""" + + __slots__ = ("array", "key") + + def __init__(self, array: duckarray[Any, Any], key: ExplicitIndexer): + """ + Parameters + ---------- + array : array_like + Array like object to index. + key : VectorizedIndexer + """ + if isinstance(key, (BasicIndexer, OuterIndexer)): + self.key = _outer_to_vectorized_indexer(key, array.shape) + elif isinstance(key, VectorizedIndexer): + self.key = _arrayize_vectorized_indexer(key, array.shape) + self.array = as_indexable(array) + + @property + def shape(self) -> _Shape: + return np.broadcast(*self.key.tuple).shape + + def get_duck_array(self): + if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + array = apply_indexer(self.array, self.key) + else: + # If the array is not an ExplicitlyIndexedNDArrayMixin, + # it may wrap a BackendArray so use its __getitem__ + array = self.array[self.key] + # self.array[self.key] is now a numpy array when + # self.array is a BackendArray subclass + # and self.key is BasicIndexer((slice(None, None, None),)) + # so we need the explicit check for ExplicitlyIndexed + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + def _updated_key(self, new_key: ExplicitIndexer): + return _combine_indexers(self.key, self.shape, new_key) + + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(self.array, self._updated_key(indexer)) + + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(self.array, self._updated_key(indexer)) + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + # If the indexed array becomes a scalar, return LazilyIndexedArray + if all(isinstance(ind, integer_types) for ind in indexer.tuple): + key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple)) + return LazilyIndexedArray(self.array, key) + return type(self)(self.array, self._updated_key(indexer)) + + def transpose(self, order): + key = VectorizedIndexer(tuple(k.transpose(order) for k in self.key.tuple)) + return type(self)(self.array, key) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + raise NotImplementedError( + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) + + def __repr__(self) -> str: + return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" + + +def _wrap_numpy_scalars(array): + """Wrap NumPy scalars in 0d arrays.""" + if np.isscalar(array): + return np.array(array) + else: + return array + + +class CopyOnWriteArray(ExplicitlyIndexedNDArrayMixin): + __slots__ = ("array", "_copied") + + def __init__(self, array: duckarray[Any, Any]): + self.array = as_indexable(array) + self._copied = False + + def _ensure_copied(self): + if not self._copied: + self.array = as_indexable(np.array(self.array)) + self._copied = True + + def get_duck_array(self): + return self.array.get_duck_array() + + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) + + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(_wrap_numpy_scalars(self.array[indexer])) + + def transpose(self, order): + return self.array.transpose(order) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self._ensure_copied() + self.array.vindex[indexer] = value + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self._ensure_copied() + self.array.oindex[indexer] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self._ensure_copied() + + self.array[indexer] = value + + def __deepcopy__(self, memo): + # CopyOnWriteArray is used to wrap backend array objects, which might + # point to files on disk, so we can't rely on the default deepcopy + # implementation. + return type(self)(self.array) + + +class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin): + __slots__ = ("array",) + + def __init__(self, array): + self.array = _wrap_numpy_scalars(as_indexable(array)) + + def _ensure_cached(self): + self.array = as_indexable(self.array.get_duck_array()) + + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + return np.asarray(self.get_duck_array(), dtype=dtype) + + def get_duck_array(self): + self._ensure_cached() + return self.array.get_duck_array() + + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) + + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(_wrap_numpy_scalars(self.array[indexer])) + + def transpose(self, order): + return self.array.transpose(order) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self.array.vindex[indexer] = value + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self.array.oindex[indexer] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer] = value + + +def as_indexable(array): + """ + This function always returns a ExplicitlyIndexed subclass, + so that the vectorized indexing is always possible with the returned + object. + """ + if isinstance(array, ExplicitlyIndexed): + return array + if isinstance(array, np.ndarray): + return NumpyIndexingAdapter(array) + if isinstance(array, pd.Index): + return PandasIndexingAdapter(array) + if is_duck_dask_array(array): + return DaskIndexingAdapter(array) + if hasattr(array, "__array_function__"): + return NdArrayLikeIndexingAdapter(array) + if hasattr(array, "__array_namespace__"): + return ArrayApiIndexingAdapter(array) + + raise TypeError(f"Invalid array type: {type(array)}") + + +def _outer_to_vectorized_indexer( + indexer: BasicIndexer | OuterIndexer, shape: _Shape +) -> VectorizedIndexer: + """Convert an OuterIndexer into an vectorized indexer. + + Parameters + ---------- + indexer : Outer/Basic Indexer + An indexer to convert. + shape : tuple + Shape of the array subject to the indexing. + + Returns + ------- + VectorizedIndexer + Tuple suitable for use to index a NumPy array with vectorized indexing. + Each element is an array: broadcasting them together gives the shape + of the result. + """ + key = indexer.tuple + + n_dim = len([k for k in key if not isinstance(k, integer_types)]) + i_dim = 0 + new_key = [] + for k, size in zip(key, shape): + if isinstance(k, integer_types): + new_key.append(np.array(k).reshape((1,) * n_dim)) + else: # np.ndarray or slice + if isinstance(k, slice): + k = np.arange(*k.indices(size)) + assert k.dtype.kind in {"i", "u"} + new_shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)] + new_key.append(k.reshape(*new_shape)) + i_dim += 1 + return VectorizedIndexer(tuple(new_key)) + + +def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: _Shape): + """Convert an OuterIndexer into an indexer for NumPy. + + Parameters + ---------- + indexer : Basic/OuterIndexer + An indexer to convert. + shape : tuple + Shape of the array subject to the indexing. + + Returns + ------- + tuple + Tuple suitable for use to index a NumPy array. + """ + if len([k for k in indexer.tuple if not isinstance(k, slice)]) <= 1: + # If there is only one vector and all others are slice, + # it can be safely used in mixed basic/advanced indexing. + # Boolean index should already be converted to integer array. + return indexer.tuple + else: + return _outer_to_vectorized_indexer(indexer, shape).tuple + + +def _combine_indexers(old_key, shape: _Shape, new_key) -> VectorizedIndexer: + """Combine two indexers. + + Parameters + ---------- + old_key : ExplicitIndexer + The first indexer for the original array + shape : tuple of ints + Shape of the original array to be indexed by old_key + new_key + The second indexer for indexing original[old_key] + """ + if not isinstance(old_key, VectorizedIndexer): + old_key = _outer_to_vectorized_indexer(old_key, shape) + if len(old_key.tuple) == 0: + return new_key + + new_shape = np.broadcast(*old_key.tuple).shape + if isinstance(new_key, VectorizedIndexer): + new_key = _arrayize_vectorized_indexer(new_key, new_shape) + else: + new_key = _outer_to_vectorized_indexer(new_key, new_shape) + + return VectorizedIndexer( + tuple(o[new_key.tuple] for o in np.broadcast_arrays(*old_key.tuple)) + ) + + +@enum.unique +class IndexingSupport(enum.Enum): + # for backends that support only basic indexer + BASIC = 0 + # for backends that support basic / outer indexer + OUTER = 1 + # for backends that support outer indexer including at most 1 vector. + OUTER_1VECTOR = 2 + # for backends that support full vectorized indexer. + VECTORIZED = 3 + + +def explicit_indexing_adapter( + key: ExplicitIndexer, + shape: _Shape, + indexing_support: IndexingSupport, + raw_indexing_method: Callable[..., Any], +) -> Any: + """Support explicit indexing by delegating to a raw indexing method. + + Outer and/or vectorized indexers are supported by indexing a second time + with a NumPy array. + + Parameters + ---------- + key : ExplicitIndexer + Explicit indexing object. + shape : Tuple[int, ...] + Shape of the indexed array. + indexing_support : IndexingSupport enum + Form of indexing supported by raw_indexing_method. + raw_indexing_method : callable + Function (like ndarray.__getitem__) that when called with indexing key + in the form of a tuple returns an indexed array. + + Returns + ------- + Indexing result, in the form of a duck numpy-array. + """ + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded np.ndarray + indexable = NumpyIndexingAdapter(result) + result = apply_indexer(indexable, numpy_indices) + return result + + +def apply_indexer(indexable, indexer: ExplicitIndexer): + """Apply an indexer to an indexable object.""" + if isinstance(indexer, VectorizedIndexer): + return indexable.vindex[indexer] + elif isinstance(indexer, OuterIndexer): + return indexable.oindex[indexer] + else: + return indexable[indexer] + + +def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) -> None: + """Set values in an indexable object using an indexer.""" + if isinstance(indexer, VectorizedIndexer): + indexable.vindex[indexer] = value + elif isinstance(indexer, OuterIndexer): + indexable.oindex[indexer] = value + else: + indexable[indexer] = value + + +def decompose_indexer( + indexer: ExplicitIndexer, shape: _Shape, indexing_support: IndexingSupport +) -> tuple[ExplicitIndexer, ExplicitIndexer]: + if isinstance(indexer, VectorizedIndexer): + return _decompose_vectorized_indexer(indexer, shape, indexing_support) + if isinstance(indexer, (BasicIndexer, OuterIndexer)): + return _decompose_outer_indexer(indexer, shape, indexing_support) + raise TypeError(f"unexpected key type: {indexer}") + + +def _decompose_slice(key: slice, size: int) -> tuple[slice, slice]: + """convert a slice to successive two slices. The first slice always has + a positive step. + + >>> _decompose_slice(slice(2, 98, 2), 99) + (slice(2, 98, 2), slice(None, None, None)) + + >>> _decompose_slice(slice(98, 2, -2), 99) + (slice(4, 99, 2), slice(None, None, -1)) + + >>> _decompose_slice(slice(98, 2, -2), 98) + (slice(3, 98, 2), slice(None, None, -1)) + + >>> _decompose_slice(slice(360, None, -10), 361) + (slice(0, 361, 10), slice(None, None, -1)) + """ + start, stop, step = key.indices(size) + if step > 0: + # If key already has a positive step, use it as is in the backend + return key, slice(None) + else: + # determine stop precisely for step > 1 case + # Use the range object to do the calculation + # e.g. [98:2:-2] -> [98:3:-2] + exact_stop = range(start, stop, step)[-1] + return slice(exact_stop, start + 1, -step), slice(None, None, -1) + + +def _decompose_vectorized_indexer( + indexer: VectorizedIndexer, + shape: _Shape, + indexing_support: IndexingSupport, +) -> tuple[ExplicitIndexer, ExplicitIndexer]: + """ + Decompose vectorized indexer to the successive two indexers, where the + first indexer will be used to index backend arrays, while the second one + is used to index loaded on-memory np.ndarray. + + Parameters + ---------- + indexer : VectorizedIndexer + indexing_support : one of IndexerSupport entries + + Returns + ------- + backend_indexer: OuterIndexer or BasicIndexer + np_indexers: an ExplicitIndexer (VectorizedIndexer / BasicIndexer) + + Notes + ----- + This function is used to realize the vectorized indexing for the backend + arrays that only support basic or outer indexing. + + As an example, let us consider to index a few elements from a backend array + with a vectorized indexer ([0, 3, 1], [2, 3, 2]). + Even if the backend array only supports outer indexing, it is more + efficient to load a subslice of the array than loading the entire array, + + >>> array = np.arange(36).reshape(6, 6) + >>> backend_indexer = OuterIndexer((np.array([0, 1, 3]), np.array([2, 3]))) + >>> # load subslice of the array + ... array = NumpyIndexingAdapter(array).oindex[backend_indexer] + >>> np_indexer = VectorizedIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) + >>> # vectorized indexing for on-memory np.ndarray. + ... NumpyIndexingAdapter(array).vindex[np_indexer] + array([ 2, 21, 8]) + """ + assert isinstance(indexer, VectorizedIndexer) + + if indexing_support is IndexingSupport.VECTORIZED: + return indexer, BasicIndexer(()) + + backend_indexer_elems = [] + np_indexer_elems = [] + # convert negative indices + indexer_elems = [ + np.where(k < 0, k + s, k) if isinstance(k, np.ndarray) else k + for k, s in zip(indexer.tuple, shape) + ] + + for k, s in zip(indexer_elems, shape): + if isinstance(k, slice): + # If it is a slice, then we will slice it as-is + # (but make its step positive) in the backend, + # and then use all of it (slice(None)) for the in-memory portion. + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer_elems.append(bk_slice) + np_indexer_elems.append(np_slice) + else: + # If it is a (multidimensional) np.ndarray, just pickup the used + # keys without duplication and store them as a 1d-np.ndarray. + oind, vind = np.unique(k, return_inverse=True) + backend_indexer_elems.append(oind) + np_indexer_elems.append(vind.reshape(*k.shape)) + + backend_indexer = OuterIndexer(tuple(backend_indexer_elems)) + np_indexer = VectorizedIndexer(tuple(np_indexer_elems)) + + if indexing_support is IndexingSupport.OUTER: + return backend_indexer, np_indexer + + # If the backend does not support outer indexing, + # backend_indexer (OuterIndexer) is also decomposed. + backend_indexer1, np_indexer1 = _decompose_outer_indexer( + backend_indexer, shape, indexing_support + ) + np_indexer = _combine_indexers(np_indexer1, shape, np_indexer) + return backend_indexer1, np_indexer + + +def _decompose_outer_indexer( + indexer: BasicIndexer | OuterIndexer, + shape: _Shape, + indexing_support: IndexingSupport, +) -> tuple[ExplicitIndexer, ExplicitIndexer]: + """ + Decompose outer indexer to the successive two indexers, where the + first indexer will be used to index backend arrays, while the second one + is used to index the loaded on-memory np.ndarray. + + Parameters + ---------- + indexer : OuterIndexer or BasicIndexer + indexing_support : One of the entries of IndexingSupport + + Returns + ------- + backend_indexer: OuterIndexer or BasicIndexer + np_indexers: an ExplicitIndexer (OuterIndexer / BasicIndexer) + + Notes + ----- + This function is used to realize the vectorized indexing for the backend + arrays that only support basic or outer indexing. + + As an example, let us consider to index a few elements from a backend array + with a orthogonal indexer ([0, 3, 1], [2, 3, 2]). + Even if the backend array only supports basic indexing, it is more + efficient to load a subslice of the array than loading the entire array, + + >>> array = np.arange(36).reshape(6, 6) + >>> backend_indexer = BasicIndexer((slice(0, 3), slice(2, 4))) + >>> # load subslice of the array + ... array = NumpyIndexingAdapter(array)[backend_indexer] + >>> np_indexer = OuterIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) + >>> # outer indexing for on-memory np.ndarray. + ... NumpyIndexingAdapter(array).oindex[np_indexer] + array([[ 2, 3, 2], + [14, 15, 14], + [ 8, 9, 8]]) + """ + backend_indexer: list[Any] = [] + np_indexer: list[Any] = [] + + assert isinstance(indexer, (OuterIndexer, BasicIndexer)) + + if indexing_support == IndexingSupport.VECTORIZED: + for k, s in zip(indexer.tuple, shape): + if isinstance(k, slice): + # If it is a slice, then we will slice it as-is + # (but make its step positive) in the backend, + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + else: + backend_indexer.append(k) + if not is_scalar(k): + np_indexer.append(slice(None)) + return type(indexer)(tuple(backend_indexer)), BasicIndexer(tuple(np_indexer)) + + # make indexer positive + pos_indexer: list[np.ndarray | int | np.number] = [] + for k, s in zip(indexer.tuple, shape): + if isinstance(k, np.ndarray): + pos_indexer.append(np.where(k < 0, k + s, k)) + elif isinstance(k, integer_types) and k < 0: + pos_indexer.append(k + s) + else: + pos_indexer.append(k) + indexer_elems = pos_indexer + + if indexing_support is IndexingSupport.OUTER_1VECTOR: + # some backends such as h5py supports only 1 vector in indexers + # We choose the most efficient axis + gains = [ + ( + (np.max(k) - np.min(k) + 1.0) / len(np.unique(k)) + if isinstance(k, np.ndarray) + else 0 + ) + for k in indexer_elems + ] + array_index = np.argmax(np.array(gains)) if len(gains) > 0 else None + + for i, (k, s) in enumerate(zip(indexer_elems, shape)): + if isinstance(k, np.ndarray) and i != array_index: + # np.ndarray key is converted to slice that covers the entire + # entries of this key. + backend_indexer.append(slice(np.min(k), np.max(k) + 1)) + np_indexer.append(k - np.min(k)) + elif isinstance(k, np.ndarray): + # Remove duplicates and sort them in the increasing order + pkey, ekey = np.unique(k, return_inverse=True) + backend_indexer.append(pkey) + np_indexer.append(ekey) + elif isinstance(k, integer_types): + backend_indexer.append(k) + else: # slice: convert positive step slice for backend + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + + return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) + + if indexing_support == IndexingSupport.OUTER: + for k, s in zip(indexer_elems, shape): + if isinstance(k, slice): + # slice: convert positive step slice for backend + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + elif isinstance(k, integer_types): + backend_indexer.append(k) + elif isinstance(k, np.ndarray) and (np.diff(k) >= 0).all(): + backend_indexer.append(k) + np_indexer.append(slice(None)) + else: + # Remove duplicates and sort them in the increasing order + oind, vind = np.unique(k, return_inverse=True) + backend_indexer.append(oind) + np_indexer.append(vind.reshape(*k.shape)) + + return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) + + # basic indexer + assert indexing_support == IndexingSupport.BASIC + + for k, s in zip(indexer_elems, shape): + if isinstance(k, np.ndarray): + # np.ndarray key is converted to slice that covers the entire + # entries of this key. + backend_indexer.append(slice(np.min(k), np.max(k) + 1)) + np_indexer.append(k - np.min(k)) + elif isinstance(k, integer_types): + backend_indexer.append(k) + else: # slice: convert positive step slice for backend + bk_slice, np_slice = _decompose_slice(k, s) + backend_indexer.append(bk_slice) + np_indexer.append(np_slice) + + return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) + + +def _arrayize_vectorized_indexer( + indexer: VectorizedIndexer, shape: _Shape +) -> VectorizedIndexer: + """Return an identical vindex but slices are replaced by arrays""" + slices = [v for v in indexer.tuple if isinstance(v, slice)] + if len(slices) == 0: + return indexer + + arrays = [v for v in indexer.tuple if isinstance(v, np.ndarray)] + n_dim = arrays[0].ndim if len(arrays) > 0 else 0 + i_dim = 0 + new_key = [] + for v, size in zip(indexer.tuple, shape): + if isinstance(v, np.ndarray): + new_key.append(np.reshape(v, v.shape + (1,) * len(slices))) + else: # slice + shape = (1,) * (n_dim + i_dim) + (-1,) + (1,) * (len(slices) - i_dim - 1) + new_key.append(np.arange(*v.indices(size)).reshape(shape)) + i_dim += 1 + return VectorizedIndexer(tuple(new_key)) + + +def _chunked_array_with_chunks_hint( + array, chunks, chunkmanager: ChunkManagerEntrypoint[Any] +): + """Create a chunked array using the chunks hint for dimensions of size > 1.""" + + if len(chunks) < array.ndim: + raise ValueError("not enough chunks in hint") + new_chunks = [] + for chunk, size in zip(chunks, array.shape): + new_chunks.append(chunk if size > 1 else (1,)) + return chunkmanager.from_array(array, new_chunks) # type: ignore[arg-type] + + +def _logical_any(args): + return functools.reduce(operator.or_, args) + + +def _masked_result_drop_slice(key, data: duckarray[Any, Any] | None = None): + key = (k for k in key if not isinstance(k, slice)) + chunks_hint = getattr(data, "chunks", None) + + new_keys = [] + for k in key: + if isinstance(k, np.ndarray): + if is_chunked_array(data): # type: ignore[arg-type] + chunkmanager = get_chunked_array_type(data) + new_keys.append( + _chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager) + ) + elif isinstance(data, array_type("sparse")): + import sparse + + new_keys.append(sparse.COO.from_numpy(k)) + else: + new_keys.append(k) + else: + new_keys.append(k) + + mask = _logical_any(k == -1 for k in new_keys) + return mask + + +def create_mask( + indexer: ExplicitIndexer, shape: _Shape, data: duckarray[Any, Any] | None = None +): + """Create a mask for indexing with a fill-value. + + Parameters + ---------- + indexer : ExplicitIndexer + Indexer with -1 in integer or ndarray value to indicate locations in + the result that should be masked. + shape : tuple + Shape of the array being indexed. + data : optional + Data for which mask is being created. If data is a dask arrays, its chunks + are used as a hint for chunks on the resulting mask. If data is a sparse + array, the returned mask is also a sparse array. + + Returns + ------- + mask : bool, np.ndarray, SparseArray or dask.array.Array with dtype=bool + Same type as data. Has the same shape as the indexing result. + """ + if isinstance(indexer, OuterIndexer): + key = _outer_to_vectorized_indexer(indexer, shape).tuple + assert not any(isinstance(k, slice) for k in key) + mask = _masked_result_drop_slice(key, data) + + elif isinstance(indexer, VectorizedIndexer): + key = indexer.tuple + base_mask = _masked_result_drop_slice(key, data) + slice_shape = tuple( + np.arange(*k.indices(size)).size + for k, size in zip(key, shape) + if isinstance(k, slice) + ) + expanded_mask = base_mask[(Ellipsis,) + (np.newaxis,) * len(slice_shape)] + mask = duck_array_ops.broadcast_to(expanded_mask, base_mask.shape + slice_shape) + + elif isinstance(indexer, BasicIndexer): + mask = any(k == -1 for k in indexer.tuple) + + else: + raise TypeError(f"unexpected key type: {type(indexer)}") + + return mask + + +def _posify_mask_subindexer( + index: np.ndarray[Any, np.dtype[np.generic]], +) -> np.ndarray[Any, np.dtype[np.generic]]: + """Convert masked indices in a flat array to the nearest unmasked index. + + Parameters + ---------- + index : np.ndarray + One dimensional ndarray with dtype=int. + + Returns + ------- + np.ndarray + One dimensional ndarray with all values equal to -1 replaced by an + adjacent non-masked element. + """ + masked = index == -1 + unmasked_locs = np.flatnonzero(~masked) + if not unmasked_locs.size: + # indexing unmasked_locs is invalid + return np.zeros_like(index) + masked_locs = np.flatnonzero(masked) + prev_value = np.maximum(0, np.searchsorted(unmasked_locs, masked_locs) - 1) + new_index = index.copy() + new_index[masked_locs] = index[unmasked_locs[prev_value]] + return new_index + + +def posify_mask_indexer(indexer: ExplicitIndexer) -> ExplicitIndexer: + """Convert masked values (-1) in an indexer to nearest unmasked values. + + This routine is useful for dask, where it can be much faster to index + adjacent points than arbitrary points from the end of an array. + + Parameters + ---------- + indexer : ExplicitIndexer + Input indexer. + + Returns + ------- + ExplicitIndexer + Same type of input, with all values in ndarray keys equal to -1 + replaced by an adjacent non-masked element. + """ + key = tuple( + ( + _posify_mask_subindexer(k.ravel()).reshape(k.shape) + if isinstance(k, np.ndarray) + else k + ) + for k in indexer.tuple + ) + return type(indexer)(key) + + +def is_fancy_indexer(indexer: Any) -> bool: + """Return False if indexer is a int, slice, a 1-dimensional list, or a 0 or + 1-dimensional ndarray; in all other cases return True + """ + if isinstance(indexer, (int, slice)): + return False + if isinstance(indexer, np.ndarray): + return indexer.ndim > 1 + if isinstance(indexer, list): + return bool(indexer) and not isinstance(indexer[0], int) + return True + + +class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a NumPy array to use explicit indexing.""" + + __slots__ = ("array",) + + def __init__(self, array): + # In NumpyIndexingAdapter we only allow to store bare np.ndarray + if not isinstance(array, np.ndarray): + raise TypeError( + "NumpyIndexingAdapter only wraps np.ndarray. " + f"Trying to wrap {type(array)}" + ) + self.array = array + + def transpose(self, order): + return self.array.transpose(order) + + def _oindex_get(self, indexer: OuterIndexer): + key = _outer_to_numpy_indexer(indexer, self.array.shape) + return self.array[key] + + def _vindex_get(self, indexer: VectorizedIndexer): + array = NumpyVIndexAdapter(self.array) + return array[indexer.tuple] + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = indexer.tuple + (Ellipsis,) + return array[key] + + def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None: + try: + array[key] = value + except ValueError as exc: + # More informative exception if read-only view + if not array.flags.writeable and not array.flags.owndata: + raise ValueError( + "Assignment destination is a view. " + "Do you want to .copy() array first?" + ) + else: + raise exc + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + key = _outer_to_numpy_indexer(indexer, self.array.shape) + self._safe_setitem(self.array, key, value) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + array = NumpyVIndexAdapter(self.array) + self._safe_setitem(array, indexer.tuple, value) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = indexer.tuple + (Ellipsis,) + self._safe_setitem(array, key, value) + + +class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): + __slots__ = ("array",) + + def __init__(self, array): + if not hasattr(array, "__array_function__"): + raise TypeError( + "NdArrayLikeIndexingAdapter must wrap an object that " + "implements the __array_function__ protocol" + ) + self.array = array + + +class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap an array API array to use explicit indexing.""" + + __slots__ = ("array",) + + def __init__(self, array): + if not hasattr(array, "__array_namespace__"): + raise TypeError( + "ArrayApiIndexingAdapter must wrap an object that " + "implements the __array_namespace__ protocol" + ) + self.array = array + + def _oindex_get(self, indexer: OuterIndexer): + # manual orthogonal indexing (implemented like DaskIndexingAdapter) + key = indexer.tuple + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + return value + + def _vindex_get(self, indexer: VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return self.array[indexer.tuple] + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self.array[indexer.tuple] = value + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError("Vectorized indexing is not supported") + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer.tuple] = value + + def transpose(self, order): + xp = self.array.__array_namespace__() + return xp.permute_dims(self.array, order) + + +class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a dask array to support explicit indexing.""" + + __slots__ = ("array",) + + def __init__(self, array): + """This adapter is created in Variable.__getitem__ in + Variable._broadcast_indexes. + """ + self.array = array + + def _oindex_get(self, indexer: OuterIndexer): + key = indexer.tuple + try: + return self.array[key] + except NotImplementedError: + # manual orthogonal indexing + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey,)] + return value + + def _vindex_get(self, indexer: VectorizedIndexer): + return self.array.vindex[indexer.tuple] + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return self.array[indexer.tuple] + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in indexer.tuple) + if num_non_slices > 1: + raise NotImplementedError( + "xarray can't set arrays with multiple " "array indices to dask yet." + ) + self.array[indexer.tuple] = value + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self.array.vindex[indexer.tuple] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer.tuple] = value + + def transpose(self, order): + return self.array.transpose(order) + + +class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" + + __slots__ = ("array", "_dtype") + + def __init__(self, array: pd.Index, dtype: DTypeLike = None): + from xarray.core.indexes import safe_cast_to_index + + self.array = safe_cast_to_index(array) + + if dtype is None: + self._dtype = get_valid_numpy_dtype(array) + else: + self._dtype = np.dtype(dtype) + + @property + def dtype(self) -> np.dtype: + return self._dtype + + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + if dtype is None: + dtype = self.dtype + array = self.array + if isinstance(array, pd.PeriodIndex): + with suppress(AttributeError): + # this might not be public API + array = array.astype("object") + return np.asarray(array.values, dtype=dtype) + + def get_duck_array(self) -> np.ndarray: + return np.asarray(self) + + @property + def shape(self) -> _Shape: + return (len(self.array),) + + def _convert_scalar(self, item): + if item is pd.NaT: + # work around the impossibility of casting NaT with asarray + # note: it probably would be better in general to return + # pd.Timestamp rather np.than datetime64 but this is easier + # (for now) + item = np.datetime64("NaT", "ns") + elif isinstance(item, timedelta): + item = np.timedelta64(getattr(item, "value", item), "ns") + elif isinstance(item, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + item = np.asarray(item.to_datetime64()) + elif self.dtype != object: + item = np.asarray(item, dtype=self.dtype) + + # as for numpy.ndarray indexing, we always want the result to be + # a NumPy array. + return to_0d_array(item) + + def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: + if isinstance(key, tuple) and len(key) == 1: + # unpack key so it can index a pandas.Index object (pandas.Index + # objects don't like tuples) + (key,) = key + + return key + + def _handle_result( + self, result: Any + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + if isinstance(result, pd.Index): + return type(self)(result, dtype=self.dtype) + else: + return self._convert_scalar(result) + + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable.oindex[indexer] + + result = self.array[key] + + return self._handle_result(result) + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable.vindex[indexer] + + result = self.array[key] + + return self._handle_result(result) + + def __getitem__( + self, indexer: ExplicitIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable[indexer] + + result = self.array[key] + + return self._handle_result(result) + + def transpose(self, order) -> pd.Index: + return self.array # self.array should be always one-dimensional + + def __repr__(self) -> str: + return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" + + def copy(self, deep: bool = True) -> PandasIndexingAdapter: + # Not the same as just writing `self.array.copy(deep=deep)`, as + # shallow copies of the underlying numpy.ndarrays become deep ones + # upon pickling + # >>> len(pickle.dumps((self.array, self.array))) + # 4000281 + # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) + # 8000341 + array = self.array.copy(deep=True) if deep else self.array + return type(self)(array, self._dtype) + + +class PandasMultiIndexingAdapter(PandasIndexingAdapter): + """Handles explicit indexing for a pandas.MultiIndex. + + This allows creating one instance for each multi-index level while + preserving indexing efficiency (memoized + might reuse another instance with + the same multi-index). + + """ + + __slots__ = ("array", "_dtype", "level", "adapter") + + def __init__( + self, + array: pd.MultiIndex, + dtype: DTypeLike = None, + level: str | None = None, + ): + super().__init__(array, dtype) + self.level = level + + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + if dtype is None: + dtype = self.dtype + if self.level is not None: + return np.asarray( + self.array.get_level_values(self.level).values, dtype=dtype + ) + else: + return super().__array__(dtype) + + def _convert_scalar(self, item): + if isinstance(item, tuple) and self.level is not None: + idx = tuple(self.array.names).index(self.level) + item = item[idx] + return super()._convert_scalar(item) + + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._oindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._vindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + + def __getitem__(self, indexer: ExplicitIndexer): + result = super().__getitem__(indexer) + if isinstance(result, type(self)): + result.level = self.level + + return result + + def __repr__(self) -> str: + if self.level is None: + return super().__repr__() + else: + props = ( + f"(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + ) + return f"{type(self).__name__}{props}" + + def _get_array_subset(self) -> np.ndarray: + # used to speed-up the repr for big multi-indexes + threshold = max(100, OPTIONS["display_values_threshold"] + 2) + if self.size > threshold: + pos = threshold // 2 + indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) + subset = self[OuterIndexer((indices,))] + else: + subset = self + + return np.asarray(subset) + + def _repr_inline_(self, max_width: int) -> str: + from xarray.core.formatting import format_array_flat + + if self.level is None: + return "MultiIndex" + else: + return format_array_flat(self._get_array_subset(), max_width) + + def _repr_html_(self) -> str: + from xarray.core.formatting import short_array_repr + + array_repr = short_array_repr(self._get_array_subset()) + return f"
    {escape(array_repr)}
    " + + def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter: + # see PandasIndexingAdapter.copy + array = self.array.copy(deep=True) if deep else self.array + return type(self)(array, self._dtype, self.level) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/iterators.py b/test/fixtures/whole_applications/xarray/xarray/core/iterators.py new file mode 100644 index 0000000..dd5fa7e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/iterators.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Callable + +from xarray.core.treenode import Tree + +"""These iterators are copied from anytree.iterators, with minor modifications.""" + + +class LevelOrderIter(Iterator): + """Iterate over tree applying level-order strategy starting at `node`. + This is the iterator used by `DataTree` to traverse nodes. + + Parameters + ---------- + node : Tree + Node in a tree to begin iteration at. + filter_ : Callable, optional + Function called with every `node` as argument, `node` is returned if `True`. + Default is to iterate through all ``node`` objects in the tree. + stop : Callable, optional + Function that will cause iteration to stop if ``stop`` returns ``True`` + for ``node``. + maxlevel : int, optional + Maximum level to descend in the node hierarchy. + + Examples + -------- + >>> from xarray.core.datatree import DataTree + >>> from xarray.core.iterators import LevelOrderIter + >>> f = DataTree(name="f") + >>> b = DataTree(name="b", parent=f) + >>> a = DataTree(name="a", parent=b) + >>> d = DataTree(name="d", parent=b) + >>> c = DataTree(name="c", parent=d) + >>> e = DataTree(name="e", parent=d) + >>> g = DataTree(name="g", parent=f) + >>> i = DataTree(name="i", parent=g) + >>> h = DataTree(name="h", parent=i) + >>> print(f) + DataTree('f', parent=None) + ├── DataTree('b') + │ ├── DataTree('a') + │ └── DataTree('d') + │ ├── DataTree('c') + │ └── DataTree('e') + └── DataTree('g') + └── DataTree('i') + └── DataTree('h') + >>> [node.name for node in LevelOrderIter(f)] + ['f', 'b', 'g', 'a', 'd', 'i', 'c', 'e', 'h'] + >>> [node.name for node in LevelOrderIter(f, maxlevel=3)] + ['f', 'b', 'g', 'a', 'd', 'i'] + >>> [ + ... node.name + ... for node in LevelOrderIter(f, filter_=lambda n: n.name not in ("e", "g")) + ... ] + ['f', 'b', 'a', 'd', 'i', 'c', 'h'] + >>> [node.name for node in LevelOrderIter(f, stop=lambda n: n.name == "d")] + ['f', 'b', 'g', 'a', 'i', 'h'] + """ + + def __init__( + self, + node: Tree, + filter_: Callable | None = None, + stop: Callable | None = None, + maxlevel: int | None = None, + ): + self.node = node + self.filter_ = filter_ + self.stop = stop + self.maxlevel = maxlevel + self.__iter = None + + def __init(self): + node = self.node + maxlevel = self.maxlevel + filter_ = self.filter_ or LevelOrderIter.__default_filter + stop = self.stop or LevelOrderIter.__default_stop + children = ( + [] + if LevelOrderIter._abort_at_level(1, maxlevel) + else LevelOrderIter._get_children([node], stop) + ) + return self._iter(children, filter_, stop, maxlevel) + + @staticmethod + def __default_filter(node: Tree) -> bool: + return True + + @staticmethod + def __default_stop(node: Tree) -> bool: + return False + + def __iter__(self) -> Iterator[Tree]: + return self + + def __next__(self) -> Iterator[Tree]: + if self.__iter is None: + self.__iter = self.__init() + item = next(self.__iter) # type: ignore[call-overload] + return item + + @staticmethod + def _abort_at_level(level: int, maxlevel: int | None) -> bool: + return maxlevel is not None and level > maxlevel + + @staticmethod + def _get_children(children: list[Tree], stop: Callable) -> list[Tree]: + return [child for child in children if not stop(child)] + + @staticmethod + def _iter( + children: list[Tree], filter_: Callable, stop: Callable, maxlevel: int | None + ) -> Iterator[Tree]: + level = 1 + while children: + next_children = [] + for child in children: + if filter_(child): + yield child + next_children += LevelOrderIter._get_children( + list(child.children.values()), stop + ) + children = next_children + level += 1 + if LevelOrderIter._abort_at_level(level, maxlevel): + break diff --git a/test/fixtures/whole_applications/xarray/xarray/core/merge.py b/test/fixtures/whole_applications/xarray/xarray/core/merge.py new file mode 100644 index 0000000..a90e59e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/merge.py @@ -0,0 +1,1060 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Hashable, Iterable, Mapping, Sequence, Set +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union + +import pandas as pd + +from xarray.core import dtypes +from xarray.core.alignment import deep_align +from xarray.core.duck_array_ops import lazy_array_equiv +from xarray.core.indexes import ( + Index, + create_default_index_implicit, + filter_indexes_from_coords, + indexes_equal, +) +from xarray.core.utils import Frozen, compat_dict_union, dict_equiv, equivalent +from xarray.core.variable import Variable, as_variable, calculate_dimensions + +if TYPE_CHECKING: + from xarray.core.coordinates import Coordinates + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import CombineAttrsOptions, CompatOptions, JoinOptions + + DimsLike = Union[Hashable, Sequence[Hashable]] + ArrayLike = Any + VariableLike = Union[ + ArrayLike, + tuple[DimsLike, ArrayLike], + tuple[DimsLike, ArrayLike, Mapping], + tuple[DimsLike, ArrayLike, Mapping, Mapping], + ] + XarrayValue = Union[DataArray, Variable, VariableLike] + DatasetLike = Union[Dataset, Coordinates, Mapping[Any, XarrayValue]] + CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame] + CoercibleMapping = Union[Dataset, Mapping[Any, CoercibleValue]] + + +PANDAS_TYPES = (pd.Series, pd.DataFrame) + +_VALID_COMPAT = Frozen( + { + "identical": 0, + "equals": 1, + "broadcast_equals": 2, + "minimal": 3, + "no_conflicts": 4, + "override": 5, + } +) + + +class Context: + """object carrying the information of a call""" + + def __init__(self, func): + self.func = func + + +def broadcast_dimension_size(variables: list[Variable]) -> dict[Hashable, int]: + """Extract dimension sizes from a dictionary of variables. + + Raises ValueError if any dimensions have different sizes. + """ + dims: dict[Hashable, int] = {} + for var in variables: + for dim, size in zip(var.dims, var.shape): + if dim in dims and size != dims[dim]: + raise ValueError(f"index {dim!r} not aligned") + dims[dim] = size + return dims + + +class MergeError(ValueError): + """Error class for merge failures due to incompatible arguments.""" + + # inherits from ValueError for backward compatibility + # TODO: move this to an xarray.exceptions module? + + +def unique_variable( + name: Hashable, + variables: list[Variable], + compat: CompatOptions = "broadcast_equals", + equals: bool | None = None, +) -> Variable: + """Return the unique variable from a list of variables or raise MergeError. + + Parameters + ---------- + name : hashable + Name for this variable. + variables : list of Variable + List of Variable objects, all of which go by the same name in different + inputs. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional + Type of equality check to use. + equals : None or bool, optional + corresponding to result of compat test + + Returns + ------- + Variable to use in the result. + + Raises + ------ + MergeError: if any of the variables are not equal. + """ + out = variables[0] + + if len(variables) == 1 or compat == "override": + return out + + combine_method = None + + if compat == "minimal": + compat = "broadcast_equals" + + if compat == "broadcast_equals": + dim_lengths = broadcast_dimension_size(variables) + out = out.set_dims(dim_lengths) + + if compat == "no_conflicts": + combine_method = "fillna" + + if equals is None: + # first check without comparing values i.e. no computes + for var in variables[1:]: + equals = getattr(out, compat)(var, equiv=lazy_array_equiv) + if equals is not True: + break + + if equals is None: + # now compare values with minimum number of computes + out = out.compute() + for var in variables[1:]: + equals = getattr(out, compat)(var) + if not equals: + break + + if not equals: + raise MergeError( + f"conflicting values for variable {name!r} on objects to be combined. " + "You can skip this check by specifying compat='override'." + ) + + if combine_method: + for var in variables[1:]: + out = getattr(out, combine_method)(var) + + return out + + +def _assert_compat_valid(compat): + if compat not in _VALID_COMPAT: + raise ValueError(f"compat={compat!r} invalid: must be {set(_VALID_COMPAT)}") + + +MergeElement = tuple[Variable, Optional[Index]] + + +def _assert_prioritized_valid( + grouped: dict[Hashable, list[MergeElement]], + prioritized: Mapping[Any, MergeElement], +) -> None: + """Make sure that elements given in prioritized will not corrupt any + index given in grouped. + """ + prioritized_names = set(prioritized) + grouped_by_index: dict[int, list[Hashable]] = defaultdict(list) + indexes: dict[int, Index] = {} + + for name, elements_list in grouped.items(): + for _, index in elements_list: + if index is not None: + grouped_by_index[id(index)].append(name) + indexes[id(index)] = index + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of names given in prioritized + for index_id, index_coord_names in grouped_by_index.items(): + index_names = set(index_coord_names) + common_names = index_names & prioritized_names + if common_names and len(common_names) != len(index_names): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coord_names) + raise ValueError( + f"cannot set or update variable(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{indexes[index_id]!r}" + ) + + +def merge_collected( + grouped: dict[Any, list[MergeElement]], + prioritized: Mapping[Any, MergeElement] | None = None, + compat: CompatOptions = "minimal", + combine_attrs: CombineAttrsOptions = "override", + equals: dict[Any, bool] | None = None, +) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: + """Merge dicts of variables, while resolving conflicts appropriately. + + Parameters + ---------- + grouped : mapping + prioritized : mapping + compat : str + Type of equality check to use when checking for conflicts. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + equals : mapping, optional + corresponding to result of compat test + + Returns + ------- + Dict with keys taken by the union of keys on list_of_mappings, + and Variable values corresponding to those that should be found on the + merged result. + """ + if prioritized is None: + prioritized = {} + if equals is None: + equals = {} + + _assert_compat_valid(compat) + _assert_prioritized_valid(grouped, prioritized) + + merged_vars: dict[Hashable, Variable] = {} + merged_indexes: dict[Hashable, Index] = {} + index_cmp_cache: dict[tuple[int, int], bool | None] = {} + + for name, elements_list in grouped.items(): + if name in prioritized: + variable, index = prioritized[name] + merged_vars[name] = variable + if index is not None: + merged_indexes[name] = index + else: + indexed_elements = [ + (variable, index) + for variable, index in elements_list + if index is not None + ] + if indexed_elements: + # TODO(shoyer): consider adjusting this logic. Are we really + # OK throwing away variable without an index in favor of + # indexed variables, without even checking if values match? + variable, index = indexed_elements[0] + for other_var, other_index in indexed_elements[1:]: + if not indexes_equal( + index, other_index, variable, other_var, index_cmp_cache + ): + raise MergeError( + f"conflicting values/indexes on objects to be combined fo coordinate {name!r}\n" + f"first index: {index!r}\nsecond index: {other_index!r}\n" + f"first variable: {variable!r}\nsecond variable: {other_var!r}\n" + ) + if compat == "identical": + for other_variable, _ in indexed_elements[1:]: + if not dict_equiv(variable.attrs, other_variable.attrs): + raise MergeError( + "conflicting attribute values on combined " + f"variable {name!r}:\nfirst value: {variable.attrs!r}\nsecond value: {other_variable.attrs!r}" + ) + merged_vars[name] = variable + merged_vars[name].attrs = merge_attrs( + [var.attrs for var, _ in indexed_elements], + combine_attrs=combine_attrs, + ) + merged_indexes[name] = index + else: + variables = [variable for variable, _ in elements_list] + try: + merged_vars[name] = unique_variable( + name, variables, compat, equals.get(name, None) + ) + except MergeError: + if compat != "minimal": + # we need more than "minimal" compatibility (for which + # we drop conflicting coordinates) + raise + + if name in merged_vars: + merged_vars[name].attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) + + return merged_vars, merged_indexes + + +def collect_variables_and_indexes( + list_of_mappings: Iterable[DatasetLike], + indexes: Mapping[Any, Any] | None = None, +) -> dict[Hashable, list[MergeElement]]: + """Collect variables and indexes from list of mappings of xarray objects. + + Mappings can be Dataset or Coordinates objects, in which case both + variables and indexes are extracted from it. + + It can also have values of one of the following types: + - an xarray.Variable + - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in + an xarray.Variable + - or an xarray.DataArray + + If a mapping of indexes is given, those indexes are assigned to all variables + with a matching key/name. For dimension variables with no matching index, a + default (pandas) index is assigned. DataArray indexes that don't match mapping + keys are also extracted. + + """ + from xarray.core.coordinates import Coordinates + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + if indexes is None: + indexes = {} + + grouped: dict[Hashable, list[MergeElement]] = defaultdict(list) + + def append(name, variable, index): + grouped[name].append((variable, index)) + + def append_all(variables, indexes): + for name, variable in variables.items(): + append(name, variable, indexes.get(name)) + + for mapping in list_of_mappings: + if isinstance(mapping, (Coordinates, Dataset)): + append_all(mapping.variables, mapping.xindexes) + continue + + for name, variable in mapping.items(): + if isinstance(variable, DataArray): + coords_ = variable._coords.copy() # use private API for speed + indexes_ = dict(variable._indexes) + # explicitly overwritten variables should take precedence + coords_.pop(name, None) + indexes_.pop(name, None) + append_all(coords_, indexes_) + + variable = as_variable(variable, name=name, auto_convert=False) + if name in indexes: + append(name, variable, indexes[name]) + elif variable.dims == (name,): + idx, idx_vars = create_default_index_implicit(variable) + append_all(idx_vars, {k: idx for k in idx_vars}) + else: + append(name, variable, None) + + return grouped + + +def collect_from_coordinates( + list_of_coords: list[Coordinates], +) -> dict[Hashable, list[MergeElement]]: + """Collect variables and indexes to be merged from Coordinate objects.""" + grouped: dict[Hashable, list[MergeElement]] = defaultdict(list) + + for coords in list_of_coords: + variables = coords.variables + indexes = coords.xindexes + for name, variable in variables.items(): + grouped[name].append((variable, indexes.get(name))) + + return grouped + + +def merge_coordinates_without_align( + objects: list[Coordinates], + prioritized: Mapping[Any, MergeElement] | None = None, + exclude_dims: Set = frozenset(), + combine_attrs: CombineAttrsOptions = "override", +) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: + """Merge variables/indexes from coordinates without automatic alignments. + + This function is used for merging coordinate from pre-existing xarray + objects. + """ + collected = collect_from_coordinates(objects) + + if exclude_dims: + filtered: dict[Hashable, list[MergeElement]] = {} + for name, elements in collected.items(): + new_elements = [ + (variable, index) + for variable, index in elements + if exclude_dims.isdisjoint(variable.dims) + ] + if new_elements: + filtered[name] = new_elements + else: + filtered = collected + + # TODO: indexes should probably be filtered in collected elements + # before merging them + merged_coords, merged_indexes = merge_collected( + filtered, prioritized, combine_attrs=combine_attrs + ) + merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords)) + + return merged_coords, merged_indexes + + +def determine_coords( + list_of_mappings: Iterable[DatasetLike], +) -> tuple[set[Hashable], set[Hashable]]: + """Given a list of dicts with xarray object values, identify coordinates. + + Parameters + ---------- + list_of_mappings : list of dict or list of Dataset + Of the same form as the arguments to expand_variable_dicts. + + Returns + ------- + coord_names : set of variable names + noncoord_names : set of variable names + All variable found in the input should appear in either the set of + coordinate or non-coordinate names. + """ + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + coord_names: set[Hashable] = set() + noncoord_names: set[Hashable] = set() + + for mapping in list_of_mappings: + if isinstance(mapping, Dataset): + coord_names.update(mapping.coords) + noncoord_names.update(mapping.data_vars) + else: + for name, var in mapping.items(): + if isinstance(var, DataArray): + coords = set(var._coords) # use private API for speed + # explicitly overwritten variables should take precedence + coords.discard(name) + coord_names.update(coords) + + return coord_names, noncoord_names + + +def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLike]: + """Convert pandas values found in a list of labeled objects. + + Parameters + ---------- + objects : list of Dataset or mapping + The mappings may contain any sort of objects coercible to + xarray.Variables as keys, including pandas objects. + + Returns + ------- + List of Dataset or dictionary objects. Any inputs or values in the inputs + that were pandas objects have been converted into native xarray objects. + """ + from xarray.core.coordinates import Coordinates + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + out: list[DatasetLike] = [] + for obj in objects: + variables: DatasetLike + if isinstance(obj, (Dataset, Coordinates)): + variables = obj + else: + variables = {} + if isinstance(obj, PANDAS_TYPES): + obj = dict(obj.items()) + for k, v in obj.items(): + if isinstance(v, PANDAS_TYPES): + v = DataArray(v) + variables[k] = v + out.append(variables) + return out + + +def _get_priority_vars_and_indexes( + objects: Sequence[DatasetLike], + priority_arg: int | None, + compat: CompatOptions = "equals", +) -> dict[Hashable, MergeElement]: + """Extract the priority variable from a list of mappings. + + We need this method because in some cases the priority argument itself + might have conflicting values (e.g., if it is a dict with two DataArray + values with conflicting coordinate values). + + Parameters + ---------- + objects : sequence of dict-like of Variable + Dictionaries in which to find the priority variables. + priority_arg : int or None + Integer object whose variable should take priority. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional + String indicating how to compare non-concatenated variables of the same name for + potential conflicts. This is passed down to merge. + + - "broadcast_equals": all values must be equal when variables are + broadcast against each other to ensure common dimensions. + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the + same. + - "no_conflicts": only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. + - "override": skip comparing and pick variable from first dataset + + Returns + ------- + A dictionary of variables and associated indexes (if any) to prioritize. + """ + if priority_arg is None: + return {} + + collected = collect_variables_and_indexes([objects[priority_arg]]) + variables, indexes = merge_collected(collected, compat=compat) + grouped: dict[Hashable, MergeElement] = {} + for name, variable in variables.items(): + grouped[name] = (variable, indexes.get(name)) + return grouped + + +def merge_coords( + objects: Iterable[CoercibleMapping], + compat: CompatOptions = "minimal", + join: JoinOptions = "outer", + priority_arg: int | None = None, + indexes: Mapping[Any, Index] | None = None, + fill_value: object = dtypes.NA, +) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: + """Merge coordinate variables. + + See merge_core below for argument descriptions. This works similarly to + merge_core, except everything we don't worry about whether variables are + coordinates or not. + """ + _assert_compat_valid(compat) + coerced = coerce_pandas_values(objects) + aligned = deep_align( + coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value + ) + collected = collect_variables_and_indexes(aligned, indexes=indexes) + prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) + variables, out_indexes = merge_collected(collected, prioritized, compat=compat) + return variables, out_indexes + + +def merge_attrs(variable_attrs, combine_attrs, context=None): + """Combine attributes from different variables according to combine_attrs""" + if not variable_attrs: + # no attributes to merge + return None + + if callable(combine_attrs): + return combine_attrs(variable_attrs, context=context) + elif combine_attrs == "drop": + return {} + elif combine_attrs == "override": + return dict(variable_attrs[0]) + elif combine_attrs == "no_conflicts": + result = dict(variable_attrs[0]) + for attrs in variable_attrs[1:]: + try: + result = compat_dict_union(result, attrs) + except ValueError as e: + raise MergeError( + "combine_attrs='no_conflicts', but some values are not " + f"the same. Merging {str(result)} with {str(attrs)}" + ) from e + return result + elif combine_attrs == "drop_conflicts": + result = {} + dropped_keys = set() + for attrs in variable_attrs: + result.update( + { + key: value + for key, value in attrs.items() + if key not in result and key not in dropped_keys + } + ) + result = { + key: value + for key, value in result.items() + if key not in attrs or equivalent(attrs[key], value) + } + dropped_keys |= {key for key in attrs if key not in result} + return result + elif combine_attrs == "identical": + result = dict(variable_attrs[0]) + for attrs in variable_attrs[1:]: + if not dict_equiv(result, attrs): + raise MergeError( + f"combine_attrs='identical', but attrs differ. First is {str(result)} " + f", other is {str(attrs)}." + ) + return result + else: + raise ValueError(f"Unrecognised value for combine_attrs={combine_attrs}") + + +class _MergeResult(NamedTuple): + variables: dict[Hashable, Variable] + coord_names: set[Hashable] + dims: dict[Hashable, int] + indexes: dict[Hashable, Index] + attrs: dict[Hashable, Any] + + +def merge_core( + objects: Iterable[CoercibleMapping], + compat: CompatOptions = "broadcast_equals", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", + priority_arg: int | None = None, + explicit_coords: Iterable[Hashable] | None = None, + indexes: Mapping[Any, Any] | None = None, + fill_value: object = dtypes.NA, + skip_align_args: list[int] | None = None, +) -> _MergeResult: + """Core logic for merging labeled objects. + + This is not public API. + + Parameters + ---------- + objects : list of mapping + All values must be convertible to labeled arrays. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional + Compatibility checks to use when merging variables. + join : {"outer", "inner", "left", "right"}, optional + How to combine objects with different indexes. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + How to combine attributes of objects + priority_arg : int, optional + Optional argument in `objects` that takes precedence over the others. + explicit_coords : set, optional + An explicit list of variables from `objects` that are coordinates. + indexes : dict, optional + Dictionary with values given by xarray.Index objects or anything that + may be cast to pandas.Index objects. + fill_value : scalar, optional + Value to use for newly missing values + skip_align_args : list of int, optional + Optional arguments in `objects` that are not included in alignment. + + Returns + ------- + variables : dict + Dictionary of Variable objects. + coord_names : set + Set of coordinate names. + dims : dict + Dictionary mapping from dimension names to sizes. + attrs : dict + Dictionary of attributes + + Raises + ------ + MergeError if the merge cannot be done successfully. + """ + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + _assert_compat_valid(compat) + + objects = list(objects) + if skip_align_args is None: + skip_align_args = [] + + skip_align_objs = [(pos, objects.pop(pos)) for pos in skip_align_args] + + coerced = coerce_pandas_values(objects) + aligned = deep_align( + coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value + ) + + for pos, obj in skip_align_objs: + aligned.insert(pos, obj) + + collected = collect_variables_and_indexes(aligned, indexes=indexes) + prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) + variables, out_indexes = merge_collected( + collected, prioritized, compat=compat, combine_attrs=combine_attrs + ) + + dims = calculate_dimensions(variables) + + coord_names, noncoord_names = determine_coords(coerced) + if compat == "minimal": + # coordinates may be dropped in merged results + coord_names.intersection_update(variables) + if explicit_coords is not None: + coord_names.update(explicit_coords) + for dim, size in dims.items(): + if dim in variables: + coord_names.add(dim) + ambiguous_coords = coord_names.intersection(noncoord_names) + if ambiguous_coords: + raise MergeError( + "unable to determine if these variables should be " + f"coordinates or not in the merged result: {ambiguous_coords}" + ) + + attrs = merge_attrs( + [var.attrs for var in coerced if isinstance(var, (Dataset, DataArray))], + combine_attrs, + ) + + return _MergeResult(variables, coord_names, dims, out_indexes, attrs) + + +def merge( + objects: Iterable[DataArray | CoercibleMapping], + compat: CompatOptions = "no_conflicts", + join: JoinOptions = "outer", + fill_value: object = dtypes.NA, + combine_attrs: CombineAttrsOptions = "override", +) -> Dataset: + """Merge any number of xarray objects into a single Dataset as variables. + + Parameters + ---------- + objects : iterable of Dataset or iterable of DataArray or iterable of dict-like + Merge together all variables from these objects. If any of them are + DataArray objects, they must have a name. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \ + "override", "minimal"}, default: "no_conflicts" + String indicating how to compare variables of the same name for + potential conflicts: + + - "identical": all values, dimensions and attributes must be the + same. + - "equals": all values and dimensions must be the same. + - "broadcast_equals": all values must be equal when variables are + broadcast against each other to ensure common dimensions. + - "no_conflicts": only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. + - "override": skip comparing and pick variable from first dataset + - "minimal": drop conflicting coordinates + + join : {"outer", "inner", "left", "right", "exact", "override"}, default: "outer" + String indicating how to combine differing indexes in objects. + + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be + aligned are not equal + - "override": if indexes are of same size, rewrite indexes to be + those of the first object with that dimension. Indexes for the same + dimension must have the same size in all objects. + + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + + Returns + ------- + Dataset + Dataset with combined variables from each object. + + Examples + -------- + >>> x = xr.DataArray( + ... [[1.0, 2.0], [3.0, 5.0]], + ... dims=("lat", "lon"), + ... coords={"lat": [35.0, 40.0], "lon": [100.0, 120.0]}, + ... name="var1", + ... ) + >>> y = xr.DataArray( + ... [[5.0, 6.0], [7.0, 8.0]], + ... dims=("lat", "lon"), + ... coords={"lat": [35.0, 42.0], "lon": [100.0, 150.0]}, + ... name="var2", + ... ) + >>> z = xr.DataArray( + ... [[0.0, 3.0], [4.0, 9.0]], + ... dims=("time", "lon"), + ... coords={"time": [30.0, 60.0], "lon": [100.0, 150.0]}, + ... name="var3", + ... ) + + >>> x + Size: 32B + array([[1., 2.], + [3., 5.]]) + Coordinates: + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 + + >>> y + Size: 32B + array([[5., 6.], + [7., 8.]]) + Coordinates: + * lat (lat) float64 16B 35.0 42.0 + * lon (lon) float64 16B 100.0 150.0 + + >>> z + Size: 32B + array([[0., 3.], + [4., 9.]]) + Coordinates: + * time (time) float64 16B 30.0 60.0 + * lon (lon) float64 16B 100.0 150.0 + + >>> xr.merge([x, y, z]) + Size: 256B + Dimensions: (lat: 3, lon: 3, time: 2) + Coordinates: + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 + + >>> xr.merge([x, y, z], compat="identical") + Size: 256B + Dimensions: (lat: 3, lon: 3, time: 2) + Coordinates: + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 + + >>> xr.merge([x, y, z], compat="equals") + Size: 256B + Dimensions: (lat: 3, lon: 3, time: 2) + Coordinates: + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 + + >>> xr.merge([x, y, z], compat="equals", fill_value=-999.0) + Size: 256B + Dimensions: (lat: 3, lon: 3, time: 2) + Coordinates: + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 72B 1.0 2.0 -999.0 3.0 ... -999.0 -999.0 -999.0 + var2 (lat, lon) float64 72B 5.0 -999.0 6.0 -999.0 ... 7.0 -999.0 8.0 + var3 (time, lon) float64 48B 0.0 -999.0 3.0 4.0 -999.0 9.0 + + >>> xr.merge([x, y, z], join="override") + Size: 144B + Dimensions: (lat: 2, lon: 2, time: 2) + Coordinates: + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 32B 1.0 2.0 3.0 5.0 + var2 (lat, lon) float64 32B 5.0 6.0 7.0 8.0 + var3 (time, lon) float64 32B 0.0 3.0 4.0 9.0 + + >>> xr.merge([x, y, z], join="inner") + Size: 64B + Dimensions: (lat: 1, lon: 1, time: 2) + Coordinates: + * lat (lat) float64 8B 35.0 + * lon (lon) float64 8B 100.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 8B 1.0 + var2 (lat, lon) float64 8B 5.0 + var3 (time, lon) float64 16B 0.0 4.0 + + >>> xr.merge([x, y, z], compat="identical", join="inner") + Size: 64B + Dimensions: (lat: 1, lon: 1, time: 2) + Coordinates: + * lat (lat) float64 8B 35.0 + * lon (lon) float64 8B 100.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 8B 1.0 + var2 (lat, lon) float64 8B 5.0 + var3 (time, lon) float64 16B 0.0 4.0 + + >>> xr.merge([x, y, z], compat="broadcast_equals", join="outer") + Size: 256B + Dimensions: (lat: 3, lon: 3, time: 2) + Coordinates: + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 + Data variables: + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 + + >>> xr.merge([x, y, z], join="exact") + Traceback (most recent call last): + ... + ValueError: cannot align objects with join='exact' where ... + + Raises + ------ + xarray.MergeError + If any variables with the same name have conflicting values. + + See also + -------- + concat + combine_nested + combine_by_coords + """ + + from xarray.core.coordinates import Coordinates + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + dict_like_objects = [] + for obj in objects: + if not isinstance(obj, (DataArray, Dataset, Coordinates, dict)): + raise TypeError( + "objects must be an iterable containing only " + "Dataset(s), DataArray(s), and dictionaries." + ) + + if isinstance(obj, DataArray): + obj = obj.to_dataset(promote_attrs=True) + elif isinstance(obj, Coordinates): + obj = obj.to_dataset() + dict_like_objects.append(obj) + + merge_result = merge_core( + dict_like_objects, + compat, + join, + combine_attrs=combine_attrs, + fill_value=fill_value, + ) + return Dataset._construct_direct(**merge_result._asdict()) + + +def dataset_merge_method( + dataset: Dataset, + other: CoercibleMapping, + overwrite_vars: Hashable | Iterable[Hashable], + compat: CompatOptions, + join: JoinOptions, + fill_value: Any, + combine_attrs: CombineAttrsOptions, +) -> _MergeResult: + """Guts of the Dataset.merge method.""" + # we are locked into supporting overwrite_vars for the Dataset.merge + # method due for backwards compatibility + # TODO: consider deprecating it? + + if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable): + overwrite_vars = set(overwrite_vars) + else: + overwrite_vars = {overwrite_vars} + + if not overwrite_vars: + objs = [dataset, other] + priority_arg = None + elif overwrite_vars == set(other): + objs = [dataset, other] + priority_arg = 1 + else: + other_overwrite: dict[Hashable, CoercibleValue] = {} + other_no_overwrite: dict[Hashable, CoercibleValue] = {} + for k, v in other.items(): + if k in overwrite_vars: + other_overwrite[k] = v + else: + other_no_overwrite[k] = v + objs = [dataset, other_no_overwrite, other_overwrite] + priority_arg = 2 + + return merge_core( + objs, + compat, + join, + priority_arg=priority_arg, + fill_value=fill_value, + combine_attrs=combine_attrs, + ) + + +def dataset_update_method(dataset: Dataset, other: CoercibleMapping) -> _MergeResult: + """Guts of the Dataset.update method. + + This drops a duplicated coordinates from `other` if `other` is not an + `xarray.Dataset`, e.g., if it's a dict with DataArray values (GH2068, + GH2180). + """ + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + if not isinstance(other, Dataset): + other = dict(other) + for key, value in other.items(): + if isinstance(value, DataArray): + # drop conflicting coordinates + coord_names = [ + c + for c in value.coords + if c not in value.dims and c in dataset.coords + ] + if coord_names: + other[key] = value.drop_vars(coord_names) + + return merge_core( + [dataset, other], + priority_arg=1, + indexes=dataset.xindexes, + combine_attrs="override", + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/missing.py b/test/fixtures/whole_applications/xarray/xarray/core/missing.py new file mode 100644 index 0000000..45abc70 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/missing.py @@ -0,0 +1,840 @@ +from __future__ import annotations + +import datetime as dt +import warnings +from collections.abc import Hashable, Sequence +from functools import partial +from numbers import Number +from typing import TYPE_CHECKING, Any, Callable, get_args + +import numpy as np +import pandas as pd + +from xarray.core import utils +from xarray.core.common import _contains_datetime_like_objects, ones_like +from xarray.core.computation import apply_ufunc +from xarray.core.duck_array_ops import ( + datetime_to_numeric, + push, + reshape, + timedelta_to_numeric, +) +from xarray.core.options import _get_keep_attrs +from xarray.core.types import Interp1dOptions, InterpOptions +from xarray.core.utils import OrderedSet, is_scalar +from xarray.core.variable import Variable, broadcast_variables +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + +def _get_nan_block_lengths( + obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable +): + """ + Return an object where each NaN element in 'obj' is replaced by the + length of the gap the element is in. + """ + + # make variable so that we get broadcasting for free + index = Variable([dim], index) + + # algorithm from https://github.com/pydata/xarray/pull/3302#discussion_r324707072 + arange = ones_like(obj) * index + valid = obj.notnull() + valid_arange = arange.where(valid) + cumulative_nans = valid_arange.ffill(dim=dim).fillna(index[0]) + + nan_block_lengths = ( + cumulative_nans.diff(dim=dim, label="upper") + .reindex({dim: obj[dim]}) + .where(valid) + .bfill(dim=dim) + .where(~valid, 0) + .fillna(index[-1] - valid_arange.max(dim=[dim])) + ) + + return nan_block_lengths + + +class BaseInterpolator: + """Generic interpolator class for normalizing interpolation methods""" + + cons_kwargs: dict[str, Any] + call_kwargs: dict[str, Any] + f: Callable + method: str + + def __call__(self, x): + return self.f(x, **self.call_kwargs) + + def __repr__(self): + return f"{self.__class__.__name__}: method={self.method}" + + +class NumpyInterpolator(BaseInterpolator): + """One-dimensional linear interpolation. + + See Also + -------- + numpy.interp + """ + + def __init__(self, xi, yi, method="linear", fill_value=None, period=None): + if method != "linear": + raise ValueError("only method `linear` is valid for the NumpyInterpolator") + + self.method = method + self.f = np.interp + self.cons_kwargs = {} + self.call_kwargs = {"period": period} + + self._xi = xi + self._yi = yi + + nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j + + if fill_value is None: + self._left = nan + self._right = nan + elif isinstance(fill_value, Sequence) and len(fill_value) == 2: + self._left = fill_value[0] + self._right = fill_value[1] + elif is_scalar(fill_value): + self._left = fill_value + self._right = fill_value + else: + raise ValueError(f"{fill_value} is not a valid fill_value") + + def __call__(self, x): + return self.f( + x, + self._xi, + self._yi, + left=self._left, + right=self._right, + **self.call_kwargs, + ) + + +class ScipyInterpolator(BaseInterpolator): + """Interpolate a 1-D function using Scipy interp1d + + See Also + -------- + scipy.interpolate.interp1d + """ + + def __init__( + self, + xi, + yi, + method=None, + fill_value=None, + assume_sorted=True, + copy=False, + bounds_error=False, + order=None, + **kwargs, + ): + from scipy.interpolate import interp1d + + if method is None: + raise ValueError( + "method is a required argument, please supply a " + "valid scipy.inter1d method (kind)" + ) + + if method == "polynomial": + if order is None: + raise ValueError("order is required when method=polynomial") + method = order + + self.method = method + + self.cons_kwargs = kwargs + self.call_kwargs = {} + + nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j + + if fill_value is None and method == "linear": + fill_value = nan, nan + elif fill_value is None: + fill_value = nan + + self.f = interp1d( + xi, + yi, + kind=self.method, + fill_value=fill_value, + bounds_error=bounds_error, + assume_sorted=assume_sorted, + copy=copy, + **self.cons_kwargs, + ) + + +class SplineInterpolator(BaseInterpolator): + """One-dimensional smoothing spline fit to a given set of data points. + + See Also + -------- + scipy.interpolate.UnivariateSpline + """ + + def __init__( + self, + xi, + yi, + method="spline", + fill_value=None, + order=3, + nu=0, + ext=None, + **kwargs, + ): + from scipy.interpolate import UnivariateSpline + + if method != "spline": + raise ValueError("only method `spline` is valid for the SplineInterpolator") + + self.method = method + self.cons_kwargs = kwargs + self.call_kwargs = {"nu": nu, "ext": ext} + + if fill_value is not None: + raise ValueError("SplineInterpolator does not support fill_value") + + self.f = UnivariateSpline(xi, yi, k=order, **self.cons_kwargs) + + +def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): + """Wrapper for datasets""" + ds = type(self)(coords=self.coords, attrs=self.attrs) + + for name, var in self.data_vars.items(): + if dim in var.dims: + ds[name] = func(var, dim=dim, **kwargs) + else: + ds[name] = var + + return ds + + +def get_clean_interp_index( + arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True +): + """Return index to use for x values in interpolation or curve fitting. + + Parameters + ---------- + arr : DataArray + Array to interpolate or fit to a curve. + dim : str + Name of dimension along which to fit. + use_coordinate : str or bool + If use_coordinate is True, the coordinate that shares the name of the + dimension along which interpolation is being performed will be used as the + x values. If False, the x values are set as an equally spaced sequence. + strict : bool + Whether to raise errors if the index is either non-unique or non-monotonic (default). + + Returns + ------- + Variable + Numerical values for the x-coordinates. + + Notes + ----- + If indexing is along the time dimension, datetime coordinates are converted + to time deltas with respect to 1970-01-01. + """ + + # Question: If use_coordinate is a string, what role does `dim` play? + from xarray.coding.cftimeindex import CFTimeIndex + + if use_coordinate is False: + axis = arr.get_axis_num(dim) + return np.arange(arr.shape[axis], dtype=np.float64) + + if use_coordinate is True: + index = arr.get_index(dim) + + else: # string + index = arr.coords[use_coordinate] + if index.ndim != 1: + raise ValueError( + f"Coordinates used for interpolation must be 1D, " + f"{use_coordinate} is {index.ndim}D." + ) + index = index.to_index() + + # TODO: index.name is None for multiindexes + # set name for nice error messages below + if isinstance(index, pd.MultiIndex): + index.name = dim + + if strict: + if not index.is_monotonic_increasing: + raise ValueError(f"Index {index.name!r} must be monotonically increasing") + + if not index.is_unique: + raise ValueError(f"Index {index.name!r} has duplicate values") + + # Special case for non-standard calendar indexes + # Numerical datetime values are defined with respect to 1970-01-01T00:00:00 in units of nanoseconds + if isinstance(index, (CFTimeIndex, pd.DatetimeIndex)): + offset = type(index[0])(1970, 1, 1) + if isinstance(index, CFTimeIndex): + index = index.values + index = Variable( + data=datetime_to_numeric(index, offset=offset, datetime_unit="ns"), + dims=(dim,), + ) + + # raise if index cannot be cast to a float (e.g. MultiIndex) + try: + index = index.values.astype(np.float64) + except (TypeError, ValueError): + # pandas raises a TypeError + # xarray/numpy raise a ValueError + raise TypeError( + f"Index {index.name!r} must be castable to float64 to support " + f"interpolation or curve fitting, got {type(index).__name__}." + ) + + return index + + +def interp_na( + self, + dim: Hashable | None = None, + use_coordinate: bool | str = True, + method: InterpOptions = "linear", + limit: int | None = None, + max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None, + keep_attrs: bool | None = None, + **kwargs, +): + """Interpolate values according to different methods.""" + from xarray.coding.cftimeindex import CFTimeIndex + + if dim is None: + raise NotImplementedError("dim is a required argument") + + if limit is not None: + valids = _get_valid_fill_mask(self, dim, limit) + + if max_gap is not None: + max_type = type(max_gap).__name__ + if not is_scalar(max_gap): + raise ValueError("max_gap must be a scalar.") + + if ( + dim in self._indexes + and isinstance( + self._indexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) + ) + and use_coordinate + ): + # Convert to float + max_gap = timedelta_to_numeric(max_gap) + + if not use_coordinate: + if not isinstance(max_gap, (Number, np.number)): + raise TypeError( + f"Expected integer or floating point max_gap since use_coordinate=False. Received {max_type}." + ) + + # method + index = get_clean_interp_index(self, dim, use_coordinate=use_coordinate) + interp_class, kwargs = _get_interpolator(method, **kwargs) + interpolator = partial(func_interpolate_na, interp_class, **kwargs) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "overflow", RuntimeWarning) + warnings.filterwarnings("ignore", "invalid value", RuntimeWarning) + arr = apply_ufunc( + interpolator, + self, + index, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], + output_dtypes=[self.dtype], + dask="parallelized", + vectorize=True, + keep_attrs=keep_attrs, + ).transpose(*self.dims) + + if limit is not None: + arr = arr.where(valids) + + if max_gap is not None: + if dim not in self.coords: + raise NotImplementedError( + "max_gap not implemented for unlabeled coordinates yet." + ) + nan_block_lengths = _get_nan_block_lengths(self, dim, index) + arr = arr.where(nan_block_lengths <= max_gap) + + return arr + + +def func_interpolate_na(interpolator, y, x, **kwargs): + """helper function to apply interpolation along 1 dimension""" + # reversed arguments are so that attrs are preserved from da, not index + # it would be nice if this wasn't necessary, works around: + # "ValueError: assignment destination is read-only" in assignment below + out = y.copy() + + nans = pd.isnull(y) + nonans = ~nans + + # fast track for no-nans, all nan but one, and all-nans cases + n_nans = nans.sum() + if n_nans == 0 or n_nans >= len(y) - 1: + return y + + f = interpolator(x[nonans], y[nonans], **kwargs) + out[nans] = f(x[nans]) + return out + + +def _bfill(arr, n=None, axis=-1): + """inverse of ffill""" + arr = np.flip(arr, axis=axis) + + # fill + arr = push(arr, axis=axis, n=n) + + # reverse back to original + return np.flip(arr, axis=axis) + + +def ffill(arr, dim=None, limit=None): + """forward fill missing values""" + + axis = arr.get_axis_num(dim) + + # work around for bottleneck 178 + _limit = limit if limit is not None else arr.shape[axis] + + return apply_ufunc( + push, + arr, + dask="allowed", + keep_attrs=True, + output_dtypes=[arr.dtype], + kwargs=dict(n=_limit, axis=axis), + ).transpose(*arr.dims) + + +def bfill(arr, dim=None, limit=None): + """backfill missing values""" + + axis = arr.get_axis_num(dim) + + # work around for bottleneck 178 + _limit = limit if limit is not None else arr.shape[axis] + + return apply_ufunc( + _bfill, + arr, + dask="allowed", + keep_attrs=True, + output_dtypes=[arr.dtype], + kwargs=dict(n=_limit, axis=axis), + ).transpose(*arr.dims) + + +def _import_interpolant(interpolant, method): + """Import interpolant from scipy.interpolate.""" + try: + from scipy import interpolate + + return getattr(interpolate, interpolant) + except ImportError as e: + raise ImportError(f"Interpolation with method {method} requires scipy.") from e + + +def _get_interpolator( + method: InterpOptions, vectorizeable_only: bool = False, **kwargs +): + """helper function to select the appropriate interpolator class + + returns interpolator class and keyword arguments for the class + """ + interp_class: ( + type[NumpyInterpolator] | type[ScipyInterpolator] | type[SplineInterpolator] + ) + + interp1d_methods = get_args(Interp1dOptions) + valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v)) + + # prioritize scipy.interpolate + if ( + method == "linear" + and not kwargs.get("fill_value", None) == "extrapolate" + and not vectorizeable_only + ): + kwargs.update(method=method) + interp_class = NumpyInterpolator + + elif method in valid_methods: + if method in interp1d_methods: + kwargs.update(method=method) + interp_class = ScipyInterpolator + elif vectorizeable_only: + raise ValueError( + f"{method} is not a vectorizeable interpolator. " + f"Available methods are {interp1d_methods}" + ) + elif method == "barycentric": + interp_class = _import_interpolant("BarycentricInterpolator", method) + elif method in ["krogh", "krog"]: + interp_class = _import_interpolant("KroghInterpolator", method) + elif method == "pchip": + interp_class = _import_interpolant("PchipInterpolator", method) + elif method == "spline": + kwargs.update(method=method) + interp_class = SplineInterpolator + elif method == "akima": + interp_class = _import_interpolant("Akima1DInterpolator", method) + else: + raise ValueError(f"{method} is not a valid scipy interpolator") + else: + raise ValueError(f"{method} is not a valid interpolator") + + return interp_class, kwargs + + +def _get_interpolator_nd(method, **kwargs): + """helper function to select the appropriate interpolator class + + returns interpolator class and keyword arguments for the class + """ + valid_methods = ["linear", "nearest"] + + if method in valid_methods: + kwargs.update(method=method) + interp_class = _import_interpolant("interpn", method) + else: + raise ValueError( + f"{method} is not a valid interpolator for interpolating " + "over multiple dimensions." + ) + + return interp_class, kwargs + + +def _get_valid_fill_mask(arr, dim, limit): + """helper function to determine values that can be filled when limit is not + None""" + kw = {dim: limit + 1} + # we explicitly use construct method to avoid copy. + new_dim = utils.get_temp_dimname(arr.dims, "_window") + return ( + arr.isnull() + .rolling(min_periods=1, **kw) + .construct(new_dim, fill_value=False) + .sum(new_dim, skipna=False) + ) <= limit + + +def _localize(var, indexes_coords): + """Speed up for linear and nearest neighbor method. + Only consider a subspace that is needed for the interpolation + """ + indexes = {} + for dim, [x, new_x] in indexes_coords.items(): + new_x_loaded = new_x.values + minval = np.nanmin(new_x_loaded) + maxval = np.nanmax(new_x_loaded) + index = x.to_index() + imin, imax = index.get_indexer([minval, maxval], method="nearest") + indexes[dim] = slice(max(imin - 2, 0), imax + 2) + indexes_coords[dim] = (x[indexes[dim]], new_x) + return var.isel(**indexes), indexes_coords + + +def _floatize_x(x, new_x): + """Make x and new_x float. + This is particularly useful for datetime dtype. + x, new_x: tuple of np.ndarray + """ + x = list(x) + new_x = list(new_x) + for i in range(len(x)): + if _contains_datetime_like_objects(x[i]): + # Scipy casts coordinates to np.float64, which is not accurate + # enough for datetime64 (uses 64bit integer). + # We assume that the most of the bits are used to represent the + # offset (min(x)) and the variation (x - min(x)) can be + # represented by float. + xmin = x[i].values.min() + x[i] = x[i]._to_numeric(offset=xmin, dtype=np.float64) + new_x[i] = new_x[i]._to_numeric(offset=xmin, dtype=np.float64) + return x, new_x + + +def interp(var, indexes_coords, method: InterpOptions, **kwargs): + """Make an interpolation of Variable + + Parameters + ---------- + var : Variable + indexes_coords + Mapping from dimension name to a pair of original and new coordinates. + Original coordinates should be sorted in strictly ascending order. + Note that all the coordinates should be Variable objects. + method : string + One of {'linear', 'nearest', 'zero', 'slinear', 'quadratic', + 'cubic'}. For multidimensional interpolation, only + {'linear', 'nearest'} can be used. + **kwargs + keyword arguments to be passed to scipy.interpolate + + Returns + ------- + Interpolated Variable + + See Also + -------- + DataArray.interp + Dataset.interp + """ + if not indexes_coords: + return var.copy() + + # default behavior + kwargs["bounds_error"] = kwargs.get("bounds_error", False) + + result = var + # decompose the interpolation into a succession of independent interpolation + for indexes_coords in decompose_interp(indexes_coords): + var = result + + # target dimensions + dims = list(indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in dims]) + destination = broadcast_variables(*new_x) + + # transpose to make the interpolated axis to the last position + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + interped = interp_func( + var.transpose(*original_dims).data, x, destination, method, kwargs + ) + + result = Variable(new_dims, interped, attrs=var.attrs, fastpath=True) + + # dimension of the output array + out_dims: OrderedSet = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indexes_coords[d][1].dims) + else: + out_dims.add(d) + if len(out_dims) > 1: + result = result.transpose(*out_dims) + return result + + +def interp_func(var, x, new_x, method: InterpOptions, kwargs): + """ + multi-dimensional interpolation for array-like. Interpolated axes should be + located in the last position. + + Parameters + ---------- + var : np.ndarray or dask.array.Array + Array to be interpolated. The final dimension is interpolated. + x : a list of 1d array. + Original coordinates. Should not contain NaN. + new_x : a list of 1d array + New coordinates. Should not contain NaN. + method : string + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for + 1-dimensional interpolation. + {'linear', 'nearest'} for multidimensional interpolation + **kwargs + Optional keyword arguments to be passed to scipy.interpolator + + Returns + ------- + interpolated: array + Interpolated array + + Notes + ----- + This requires scipy installed. + + See Also + -------- + scipy.interpolate.interp1d + """ + if not x: + return var.copy() + + if len(x) == 1: + func, kwargs = _get_interpolator(method, vectorizeable_only=True, **kwargs) + else: + func, kwargs = _get_interpolator_nd(method, **kwargs) + + if is_chunked_array(var): + chunkmanager = get_chunked_array_type(var) + + ndim = var.ndim + nconst = ndim - len(x) + + out_ind = list(range(nconst)) + list(range(ndim, ndim + new_x[0].ndim)) + + # blockwise args format + x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] + x_arginds = [item for pair in x_arginds for item in pair] + new_x_arginds = [ + [_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x + ] + new_x_arginds = [item for pair in new_x_arginds for item in pair] + + args = (var, range(ndim), *x_arginds, *new_x_arginds) + + _, rechunked = chunkmanager.unify_chunks(*args) + + args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair) + + new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + + new_x0_chunks = new_x[0].chunks + new_x0_shape = new_x[0].shape + new_x0_chunks_is_not_none = new_x0_chunks is not None + new_axes = { + ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] + for i in range(new_x[0].ndim) + } + + # if useful, reuse localize for each chunk of new_x + localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none + + # scipy.interpolate.interp1d always forces to float. + # Use the same check for blockwise as well: + if not issubclass(var.dtype.type, np.inexact): + dtype = float + else: + dtype = var.dtype + + meta = var._meta + + return chunkmanager.blockwise( + _chunked_aware_interpnd, + out_ind, + *args, + interp_func=func, + interp_kwargs=kwargs, + localize=localize, + concatenate=True, + dtype=dtype, + new_axes=new_axes, + meta=meta, + align_arrays=False, + ) + + return _interpnd(var, x, new_x, func, kwargs) + + +def _interp1d(var, x, new_x, func, kwargs): + # x, new_x are tuples of size 1. + x, new_x = x[0], new_x[0] + rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x)) + if new_x.ndim > 1: + return reshape(rslt, (var.shape[:-1] + new_x.shape)) + if new_x.ndim == 0: + return rslt[..., -1] + return rslt + + +def _interpnd(var, x, new_x, func, kwargs): + x, new_x = _floatize_x(x, new_x) + + if len(x) == 1: + return _interp1d(var, x, new_x, func, kwargs) + + # move the interpolation axes to the start position + var = var.transpose(range(-len(x), var.ndim - len(x))) + # stack new_x to 1 vector, with reshape + xi = np.stack([x1.values.ravel() for x1 in new_x], axis=-1) + rslt = func(x, var, xi, **kwargs) + # move back the interpolation axes to the last position + rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) + return reshape(rslt, rslt.shape[:-1] + new_x[0].shape) + + +def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): + """Wrapper for `_interpnd` through `blockwise` for chunked arrays. + + The first half arrays in `coords` are original coordinates, + the other half are destination coordinates + """ + n_x = len(coords) // 2 + nconst = len(var.shape) - n_x + + # _interpnd expect coords to be Variables + x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] + new_x = [ + Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x) + for _x in coords[n_x:] + ] + + if localize: + # _localize expect var to be a Variable + var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) + + indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} + + # simple speed up for the local interpolation + var, indexes_coords = _localize(var, indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in indexes_coords]) + + # put var back as a ndarray + var = var.data + + return _interpnd(var, x, new_x, interp_func, interp_kwargs) + + +def decompose_interp(indexes_coords): + """Decompose the interpolation into a succession of independent interpolation keeping the order""" + + dest_dims = [ + dest[1].dims if dest[1].ndim > 0 else [dim] + for dim, dest in indexes_coords.items() + ] + partial_dest_dims = [] + partial_indexes_coords = {} + for i, index_coords in enumerate(indexes_coords.items()): + partial_indexes_coords.update([index_coords]) + + if i == len(dest_dims) - 1: + break + + partial_dest_dims += [dest_dims[i]] + other_dims = dest_dims[i + 1 :] + + s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims} + s_other_dims = {dim for dims in other_dims for dim in dims} + + if not s_partial_dest_dims.intersection(s_other_dims): + # this interpolation is orthogonal to the rest + + yield partial_indexes_coords + + partial_dest_dims = [] + partial_indexes_coords = {} + + yield partial_indexes_coords diff --git a/test/fixtures/whole_applications/xarray/xarray/core/nanops.py b/test/fixtures/whole_applications/xarray/xarray/core/nanops.py new file mode 100644 index 0000000..fc72401 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/nanops.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import warnings + +import numpy as np + +from xarray.core import dtypes, duck_array_ops, nputils, utils +from xarray.core.duck_array_ops import ( + astype, + count, + fillna, + isnull, + sum_where, + where, + where_method, +) + + +def _maybe_null_out(result, axis, mask, min_count=1): + """ + xarray version of pandas.core.nanops._maybe_null_out + """ + if axis is not None and getattr(result, "ndim", False): + null_mask = ( + np.take(mask.shape, axis).prod() + - duck_array_ops.sum(mask, axis) + - min_count + ) < 0 + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = where(null_mask, fill_value, astype(result, dtype)) + + elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: + null_mask = mask.size - duck_array_ops.sum(mask) + result = where(null_mask < min_count, np.nan, result) + + return result + + +def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): + """In house nanargmin, nanargmax for object arrays. Always return integer + type + """ + valid_count = count(value, axis=axis) + value = fillna(value, fill_value) + data = getattr(np, func)(value, axis=axis, **kwargs) + + # TODO This will evaluate dask arrays and might be costly. + if (valid_count == 0).any(): + raise ValueError("All-NaN slice encountered") + + return data + + +def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): + """In house nanmin and nanmax for object array""" + valid_count = count(value, axis=axis) + filled_value = fillna(value, fill_value) + data = getattr(np, func)(filled_value, axis=axis, **kwargs) + if not hasattr(data, "dtype"): # scalar case + data = fill_value if valid_count == 0 else data + # we've computed a single min, max value of type object. + # don't let np.array turn a tuple back into an array + return utils.to_0d_object_array(data) + return where_method(data, valid_count != 0) + + +def nanmin(a, axis=None, out=None): + if a.dtype.kind == "O": + return _nan_minmax_object("min", dtypes.get_pos_infinity(a.dtype), a, axis) + + return nputils.nanmin(a, axis=axis) + + +def nanmax(a, axis=None, out=None): + if a.dtype.kind == "O": + return _nan_minmax_object("max", dtypes.get_neg_infinity(a.dtype), a, axis) + + return nputils.nanmax(a, axis=axis) + + +def nanargmin(a, axis=None): + if a.dtype.kind == "O": + fill_value = dtypes.get_pos_infinity(a.dtype) + return _nan_argminmax_object("argmin", fill_value, a, axis=axis) + + return nputils.nanargmin(a, axis=axis) + + +def nanargmax(a, axis=None): + if a.dtype.kind == "O": + fill_value = dtypes.get_neg_infinity(a.dtype) + return _nan_argminmax_object("argmax", fill_value, a, axis=axis) + + return nputils.nanargmax(a, axis=axis) + + +def nansum(a, axis=None, dtype=None, out=None, min_count=None): + mask = isnull(a) + result = sum_where(a, axis=axis, dtype=dtype, where=mask) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs): + """In house nanmean. ddof argument will be used in _nanvar method""" + from xarray.core.duck_array_ops import count, fillna, where_method + + valid_count = count(value, axis=axis) + value = fillna(value, 0) + # As dtype inference is impossible for object dtype, we assume float + # https://github.com/dask/dask/issues/3162 + if dtype is None and value.dtype.kind == "O": + dtype = value.dtype if value.dtype.kind in ["cf"] else float + + data = np.sum(value, axis=axis, dtype=dtype, **kwargs) + data = data / (valid_count - ddof) + return where_method(data, valid_count != 0) + + +def nanmean(a, axis=None, dtype=None, out=None): + if a.dtype.kind == "O": + return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + + return np.nanmean(a, axis=axis, dtype=dtype) + + +def nanmedian(a, axis=None, out=None): + # The dask algorithm works by rechunking to one chunk along axis + # Make sure we trigger the dask error when passing all dimensions + # so that we don't rechunk the entire array to one chunk and + # possibly blow memory + if axis is not None and len(np.atleast_1d(axis)) == a.ndim: + axis = None + return nputils.nanmedian(a, axis=axis) + + +def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): + value_mean = _nanmean_ddof_object( + ddof=0, value=value, axis=axis, keepdims=True, **kwargs + ) + squared = (astype(value, value_mean.dtype) - value_mean) ** 2 + return _nanmean_ddof_object(ddof, squared, axis=axis, keepdims=keepdims, **kwargs) + + +def nanvar(a, axis=None, dtype=None, out=None, ddof=0): + if a.dtype.kind == "O": + return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) + + return nputils.nanvar(a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanstd(a, axis=None, dtype=None, out=None, ddof=0): + return nputils.nanstd(a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanprod(a, axis=None, dtype=None, out=None, min_count=None): + mask = isnull(a) + result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def nancumsum(a, axis=None, dtype=None, out=None): + return nputils.nancumsum(a, axis=axis, dtype=dtype) + + +def nancumprod(a, axis=None, dtype=None, out=None): + return nputils.nancumprod(a, axis=axis, dtype=dtype) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/npcompat.py b/test/fixtures/whole_applications/xarray/xarray/core/npcompat.py new file mode 100644 index 0000000..2493c08 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/npcompat.py @@ -0,0 +1,60 @@ +# Copyright (c) 2005-2011, NumPy Developers. +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the NumPy Developers nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +try: + # requires numpy>=2.0 + from numpy import isdtype # type: ignore[attr-defined,unused-ignore] +except ImportError: + import numpy as np + + dtype_kinds = { + "bool": np.bool_, + "signed integer": np.signedinteger, + "unsigned integer": np.unsignedinteger, + "integral": np.integer, + "real floating": np.floating, + "complex floating": np.complexfloating, + "numeric": np.number, + } + + def isdtype(dtype, kind): + kinds = kind if isinstance(kind, tuple) else (kind,) + + unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds] + if unknown_dtypes: + raise ValueError(f"unknown dtype kinds: {unknown_dtypes}") + + # verified the dtypes already, no need to check again + translated_kinds = [dtype_kinds[kind] for kind in kinds] + if isinstance(dtype, np.generic): + return any(isinstance(dtype, kind) for kind in translated_kinds) + else: + return any(np.issubdtype(dtype, kind) for kind in translated_kinds) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/nputils.py b/test/fixtures/whole_applications/xarray/xarray/core/nputils.py new file mode 100644 index 0000000..6970d37 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/nputils.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import warnings +from typing import Callable + +import numpy as np +import pandas as pd +from packaging.version import Version + +from xarray.core.utils import is_duck_array, module_available +from xarray.namedarray import pycompat + +# remove once numpy 2.0 is the oldest supported version +if module_available("numpy", minversion="2.0.0.dev0"): + from numpy.lib.array_utils import ( # type: ignore[import-not-found,unused-ignore] + normalize_axis_index, + ) +else: + from numpy.core.multiarray import ( # type: ignore[attr-defined,no-redef,unused-ignore] + normalize_axis_index, + ) + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning # type: ignore[attr-defined,no-redef,unused-ignore] + +from xarray.core.options import OPTIONS + +try: + import bottleneck as bn + + _BOTTLENECK_AVAILABLE = True +except ImportError: + # use numpy methods instead + bn = np + _BOTTLENECK_AVAILABLE = False + + +def _select_along_axis(values, idx, axis): + other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) + sl = other_ind[:axis] + (idx,) + other_ind[axis:] + return values[sl] + + +def nanfirst(values, axis, keepdims=False): + if isinstance(axis, tuple): + (axis,) = axis + axis = normalize_axis_index(axis, values.ndim) + idx_first = np.argmax(~pd.isnull(values), axis=axis) + result = _select_along_axis(values, idx_first, axis) + if keepdims: + return np.expand_dims(result, axis=axis) + else: + return result + + +def nanlast(values, axis, keepdims=False): + if isinstance(axis, tuple): + (axis,) = axis + axis = normalize_axis_index(axis, values.ndim) + rev = (slice(None),) * axis + (slice(None, None, -1),) + idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis) + result = _select_along_axis(values, idx_last, axis) + if keepdims: + return np.expand_dims(result, axis=axis) + else: + return result + + +def inverse_permutation(indices: np.ndarray, N: int | None = None) -> np.ndarray: + """Return indices for an inverse permutation. + + Parameters + ---------- + indices : 1D np.ndarray with dtype=int + Integer positions to assign elements to. + N : int, optional + Size of the array + + Returns + ------- + inverse_permutation : 1D np.ndarray with dtype=int + Integer indices to take from the original array to create the + permutation. + """ + if N is None: + N = len(indices) + # use intp instead of int64 because of windows :( + inverse_permutation = np.full(N, -1, dtype=np.intp) + inverse_permutation[indices] = np.arange(len(indices), dtype=np.intp) + return inverse_permutation + + +def _ensure_bool_is_ndarray(result, *args): + # numpy will sometimes return a scalar value from binary comparisons if it + # can't handle the comparison instead of broadcasting, e.g., + # In [10]: 1 == np.array(['a', 'b']) + # Out[10]: False + # This function ensures that the result is the appropriate shape in these + # cases + if isinstance(result, bool): + shape = np.broadcast(*args).shape + constructor = np.ones if result else np.zeros + result = constructor(shape, dtype=bool) + return result + + +def array_eq(self, other): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"elementwise comparison failed") + return _ensure_bool_is_ndarray(self == other, self, other) + + +def array_ne(self, other): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"elementwise comparison failed") + return _ensure_bool_is_ndarray(self != other, self, other) + + +def _is_contiguous(positions): + """Given a non-empty list, does it consist of contiguous integers?""" + previous = positions[0] + for current in positions[1:]: + if current != previous + 1: + return False + previous = current + return True + + +def _advanced_indexer_subspaces(key): + """Indices of the advanced indexes subspaces for mixed indexing and vindex.""" + if not isinstance(key, tuple): + key = (key,) + advanced_index_positions = [ + i for i, k in enumerate(key) if not isinstance(k, slice) + ] + + if not advanced_index_positions or not _is_contiguous(advanced_index_positions): + # Nothing to reorder: dimensions on the indexing result are already + # ordered like vindex. See NumPy's rule for "Combining advanced and + # basic indexing": + # https://numpy.org/doc/stable/reference/arrays.indexing.html#combining-advanced-and-basic-indexing + return (), () + + non_slices = [k for k in key if not isinstance(k, slice)] + broadcasted_shape = np.broadcast_shapes( + *[item.shape if is_duck_array(item) else (0,) for item in non_slices] + ) + ndim = len(broadcasted_shape) + mixed_positions = advanced_index_positions[0] + np.arange(ndim) + vindex_positions = np.arange(ndim) + return mixed_positions, vindex_positions + + +class NumpyVIndexAdapter: + """Object that implements indexing like vindex on a np.ndarray. + + This is a pure Python implementation of (some of) the logic in this NumPy + proposal: https://github.com/numpy/numpy/pull/6256 + """ + + def __init__(self, array): + self._array = array + + def __getitem__(self, key): + mixed_positions, vindex_positions = _advanced_indexer_subspaces(key) + return np.moveaxis(self._array[key], mixed_positions, vindex_positions) + + def __setitem__(self, key, value): + """Value must have dimensionality matching the key.""" + mixed_positions, vindex_positions = _advanced_indexer_subspaces(key) + self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) + + +def _create_method(name, npmodule=np) -> Callable: + def f(values, axis=None, **kwargs): + dtype = kwargs.get("dtype", None) + bn_func = getattr(bn, name, None) + + if ( + module_available("numbagg") + and pycompat.mod_version("numbagg") >= Version("0.5.0") + and OPTIONS["use_numbagg"] + and isinstance(values, np.ndarray) + # numbagg<0.7.0 uses ddof=1 only, but numpy uses ddof=0 by default + and ( + pycompat.mod_version("numbagg") >= Version("0.7.0") + or ("var" not in name and "std" not in name) + or kwargs.get("ddof", 0) == 1 + ) + # TODO: bool? + and values.dtype.kind in "uifc" + # and values.dtype.isnative + and (dtype is None or np.dtype(dtype) == values.dtype) + # numbagg.nanquantile only available after 0.8.0 and with linear method + and ( + name != "nanquantile" + or ( + pycompat.mod_version("numbagg") >= Version("0.8.0") + and kwargs.get("method", "linear") == "linear" + ) + ) + ): + import numbagg + + nba_func = getattr(numbagg, name, None) + if nba_func is not None: + # numbagg does not use dtype + kwargs.pop("dtype", None) + # prior to 0.7.0, numbagg did not support ddof; we ensure it's limited + # to ddof=1 above. + if pycompat.mod_version("numbagg") < Version("0.7.0"): + kwargs.pop("ddof", None) + if name == "nanquantile": + kwargs["quantiles"] = kwargs.pop("q") + kwargs.pop("method", None) + return nba_func(values, axis=axis, **kwargs) + if ( + _BOTTLENECK_AVAILABLE + and OPTIONS["use_bottleneck"] + and isinstance(values, np.ndarray) + and bn_func is not None + and not isinstance(axis, tuple) + and values.dtype.kind in "uifc" + and values.dtype.isnative + and (dtype is None or np.dtype(dtype) == values.dtype) + ): + # bottleneck does not take care dtype, min_count + kwargs.pop("dtype", None) + result = bn_func(values, axis=axis, **kwargs) + else: + result = getattr(npmodule, name)(values, axis=axis, **kwargs) + + return result + + f.__name__ = name + return f + + +def _nanpolyfit_1d(arr, x, rcond=None): + out = np.full((x.shape[1] + 1,), np.nan) + mask = np.isnan(arr) + if not np.all(mask): + out[:-1], resid, rank, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond) + out[-1] = resid[0] if resid.size > 0 else np.nan + warn_on_deficient_rank(rank, x.shape[1]) + return out + + +def warn_on_deficient_rank(rank, order): + if rank != order: + warnings.warn("Polyfit may be poorly conditioned", RankWarning, stacklevel=2) + + +def least_squares(lhs, rhs, rcond=None, skipna=False): + if skipna: + added_dim = rhs.ndim == 1 + if added_dim: + rhs = rhs.reshape(rhs.shape[0], 1) + nan_cols = np.any(np.isnan(rhs), axis=0) + out = np.empty((lhs.shape[1] + 1, rhs.shape[1])) + if np.any(nan_cols): + out[:, nan_cols] = np.apply_along_axis( + _nanpolyfit_1d, 0, rhs[:, nan_cols], lhs + ) + if np.any(~nan_cols): + out[:-1, ~nan_cols], resids, rank, _ = np.linalg.lstsq( + lhs, rhs[:, ~nan_cols], rcond=rcond + ) + out[-1, ~nan_cols] = resids if resids.size > 0 else np.nan + warn_on_deficient_rank(rank, lhs.shape[1]) + coeffs = out[:-1, :] + residuals = out[-1, :] + if added_dim: + coeffs = coeffs.reshape(coeffs.shape[0]) + residuals = residuals.reshape(residuals.shape[0]) + else: + coeffs, residuals, rank, _ = np.linalg.lstsq(lhs, rhs, rcond=rcond) + if residuals.size == 0: + residuals = coeffs[0] * np.nan + warn_on_deficient_rank(rank, lhs.shape[1]) + return coeffs, residuals + + +nanmin = _create_method("nanmin") +nanmax = _create_method("nanmax") +nanmean = _create_method("nanmean") +nanmedian = _create_method("nanmedian") +nanvar = _create_method("nanvar") +nanstd = _create_method("nanstd") +nanprod = _create_method("nanprod") +nancumsum = _create_method("nancumsum") +nancumprod = _create_method("nancumprod") +nanargmin = _create_method("nanargmin") +nanargmax = _create_method("nanargmax") +nanquantile = _create_method("nanquantile") diff --git a/test/fixtures/whole_applications/xarray/xarray/core/ops.py b/test/fixtures/whole_applications/xarray/xarray/core/ops.py new file mode 100644 index 0000000..c67b466 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/ops.py @@ -0,0 +1,311 @@ +"""Define core operations for xarray objects. + +TODO(shoyer): rewrite this module, making use of xarray.core.computation, +NumPy's __array_ufunc__ and mixin classes instead of the unintuitive "inject" +functions. +""" + +from __future__ import annotations + +import operator + +import numpy as np + +from xarray.core import dtypes, duck_array_ops + +try: + import bottleneck as bn + + has_bottleneck = True +except ImportError: + # use numpy methods instead + bn = np + has_bottleneck = False + + +NUM_BINARY_OPS = [ + "add", + "sub", + "mul", + "truediv", + "floordiv", + "mod", + "pow", + "and", + "xor", + "or", + "lshift", + "rshift", +] + +# methods which pass on the numpy return value unchanged +# be careful not to list methods that we would want to wrap later +NUMPY_SAME_METHODS = ["item", "searchsorted"] + +# methods which remove an axis +REDUCE_METHODS = ["all", "any"] +NAN_REDUCE_METHODS = [ + "max", + "min", + "mean", + "prod", + "sum", + "std", + "var", + "median", +] +# TODO: wrap take, dot, sort + + +_CUM_DOCSTRING_TEMPLATE = """\ +Apply `{name}` along some dimension of {cls}. + +Parameters +---------- +{extra_args} +skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). +keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. +**kwargs : dict + Additional keyword arguments passed on to `{name}`. + +Returns +------- +cumvalue : {cls} + New {cls} object with `{name}` applied to its data along the + indicated dimension. +""" + +_REDUCE_DOCSTRING_TEMPLATE = """\ +Reduce this {cls}'s data by applying `{name}` along some dimension(s). + +Parameters +---------- +{extra_args}{skip_na_docs}{min_count_docs} +keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. +**kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `{name}` on this object's data. + +Returns +------- +reduced : {cls} + New {cls} object with `{name}` applied to its data and the + indicated dimension(s) removed. +""" + +_SKIPNA_DOCSTRING = """ +skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64).""" + +_MINCOUNT_DOCSTRING = """ +min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. New in version 0.10.8: Added with the default being + None. Changed in version 0.17.0: if specified on an integer array + and skipna=True, the result will be a float array.""" + + +def fillna(data, other, join="left", dataset_join="left"): + """Fill missing values in this object with data from the other object. + Follows normal broadcasting and alignment rules. + + Parameters + ---------- + join : {"outer", "inner", "left", "right"}, optional + Method for joining the indexes of the passed objects along each + dimension + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": raise `ValueError` instead of aligning when indexes to be + aligned are not equal + dataset_join : {"outer", "inner", "left", "right"}, optional + Method for joining variables of Dataset objects with mismatched + data variables. + - "outer": take variables from both Dataset objects + - "inner": take only overlapped variables + - "left": take only variables from the first object + - "right": take only variables from the last object + """ + from xarray.core.computation import apply_ufunc + + return apply_ufunc( + duck_array_ops.fillna, + data, + other, + join=join, + dask="allowed", + dataset_join=dataset_join, + dataset_fill_value=np.nan, + keep_attrs=True, + ) + + +def where_method(self, cond, other=dtypes.NA): + """Return elements from `self` or `other` depending on `cond`. + + Parameters + ---------- + cond : DataArray or Dataset with boolean dtype + Locations at which to preserve this objects values. + other : scalar, DataArray or Dataset, optional + Value to use for locations in this object where ``cond`` is False. + By default, inserts missing values. + + Returns + ------- + Same type as caller. + """ + from xarray.core.computation import apply_ufunc + + # alignment for three arguments is complicated, so don't support it yet + join = "inner" if other is dtypes.NA else "exact" + return apply_ufunc( + duck_array_ops.where_method, + self, + cond, + other, + join=join, + dataset_join=join, + dask="allowed", + keep_attrs=True, + ) + + +def _call_possibly_missing_method(arg, name, args, kwargs): + try: + method = getattr(arg, name) + except AttributeError: + duck_array_ops.fail_on_dask_array_input(arg, func_name=name) + if hasattr(arg, "data"): + duck_array_ops.fail_on_dask_array_input(arg.data, func_name=name) + raise + else: + return method(*args, **kwargs) + + +def _values_method_wrapper(name): + def func(self, *args, **kwargs): + return _call_possibly_missing_method(self.data, name, args, kwargs) + + func.__name__ = name + func.__doc__ = getattr(np.ndarray, name).__doc__ + return func + + +def _method_wrapper(name): + def func(self, *args, **kwargs): + return _call_possibly_missing_method(self, name, args, kwargs) + + func.__name__ = name + func.__doc__ = getattr(np.ndarray, name).__doc__ + return func + + +def _func_slash_method_wrapper(f, name=None): + # try to wrap a method, but if not found use the function + # this is useful when patching in a function as both a DataArray and + # Dataset method + if name is None: + name = f.__name__ + + def func(self, *args, **kwargs): + try: + return getattr(self, name)(*args, **kwargs) + except AttributeError: + return f(self, *args, **kwargs) + + func.__name__ = name + func.__doc__ = f.__doc__ + return func + + +def inject_reduce_methods(cls): + methods = ( + [ + (name, getattr(duck_array_ops, f"array_{name}"), False) + for name in REDUCE_METHODS + ] + + [(name, getattr(duck_array_ops, name), True) for name in NAN_REDUCE_METHODS] + + [("count", duck_array_ops.count, False)] + ) + for name, f, include_skipna in methods: + numeric_only = getattr(f, "numeric_only", False) + available_min_count = getattr(f, "available_min_count", False) + skip_na_docs = _SKIPNA_DOCSTRING if include_skipna else "" + min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else "" + + func = cls._reduce_method(f, include_skipna, numeric_only) + func.__name__ = name + func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format( + name=name, + cls=cls.__name__, + extra_args=cls._reduce_extra_args_docstring.format(name=name), + skip_na_docs=skip_na_docs, + min_count_docs=min_count_docs, + ) + setattr(cls, name, func) + + +def op_str(name): + return f"__{name}__" + + +def get_op(name): + return getattr(operator, op_str(name)) + + +NON_INPLACE_OP = {get_op("i" + name): get_op(name) for name in NUM_BINARY_OPS} + + +def inplace_to_noninplace_op(f): + return NON_INPLACE_OP[f] + + +# _typed_ops.py uses the following wrapped functions as a kind of unary operator +argsort = _method_wrapper("argsort") +conj = _method_wrapper("conj") +conjugate = _method_wrapper("conjugate") +round_ = _func_slash_method_wrapper(duck_array_ops.around, name="round") + + +def inject_numpy_same(cls): + # these methods don't return arrays of the same shape as the input, so + # don't try to patch these in for Dataset objects + for name in NUMPY_SAME_METHODS: + setattr(cls, name, _values_method_wrapper(name)) + + +class IncludeReduceMethods: + __slots__ = () + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + if getattr(cls, "_reduce_method", None): + inject_reduce_methods(cls) + + +class IncludeNumpySameMethods: + __slots__ = () + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + inject_numpy_same(cls) # some methods not applicable to Dataset objects diff --git a/test/fixtures/whole_applications/xarray/xarray/core/options.py b/test/fixtures/whole_applications/xarray/xarray/core/options.py new file mode 100644 index 0000000..f561410 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/options.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Literal, TypedDict + +from xarray.core.utils import FrozenDict + +if TYPE_CHECKING: + from matplotlib.colors import Colormap + + Options = Literal[ + "arithmetic_join", + "cmap_divergent", + "cmap_sequential", + "display_max_rows", + "display_values_threshold", + "display_style", + "display_width", + "display_expand_attrs", + "display_expand_coords", + "display_expand_data_vars", + "display_expand_data", + "display_expand_groups", + "display_expand_indexes", + "display_default_indexes", + "enable_cftimeindex", + "file_cache_maxsize", + "keep_attrs", + "warn_for_unclosed_files", + "use_bottleneck", + "use_numbagg", + "use_opt_einsum", + "use_flox", + ] + + class T_Options(TypedDict): + arithmetic_broadcast: bool + arithmetic_join: Literal["inner", "outer", "left", "right", "exact"] + cmap_divergent: str | Colormap + cmap_sequential: str | Colormap + display_max_rows: int + display_values_threshold: int + display_style: Literal["text", "html"] + display_width: int + display_expand_attrs: Literal["default", True, False] + display_expand_coords: Literal["default", True, False] + display_expand_data_vars: Literal["default", True, False] + display_expand_data: Literal["default", True, False] + display_expand_groups: Literal["default", True, False] + display_expand_indexes: Literal["default", True, False] + display_default_indexes: Literal["default", True, False] + enable_cftimeindex: bool + file_cache_maxsize: int + keep_attrs: Literal["default", True, False] + warn_for_unclosed_files: bool + use_bottleneck: bool + use_flox: bool + use_numbagg: bool + use_opt_einsum: bool + + +OPTIONS: T_Options = { + "arithmetic_broadcast": True, + "arithmetic_join": "inner", + "cmap_divergent": "RdBu_r", + "cmap_sequential": "viridis", + "display_max_rows": 12, + "display_values_threshold": 200, + "display_style": "html", + "display_width": 80, + "display_expand_attrs": "default", + "display_expand_coords": "default", + "display_expand_data_vars": "default", + "display_expand_data": "default", + "display_expand_groups": "default", + "display_expand_indexes": "default", + "display_default_indexes": False, + "enable_cftimeindex": True, + "file_cache_maxsize": 128, + "keep_attrs": "default", + "warn_for_unclosed_files": False, + "use_bottleneck": True, + "use_flox": True, + "use_numbagg": True, + "use_opt_einsum": True, +} + +_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) +_DISPLAY_OPTIONS = frozenset(["text", "html"]) + + +def _positive_integer(value: int) -> bool: + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + "arithmetic_broadcast": lambda value: isinstance(value, bool), + "arithmetic_join": _JOIN_OPTIONS.__contains__, + "display_max_rows": _positive_integer, + "display_values_threshold": _positive_integer, + "display_style": _DISPLAY_OPTIONS.__contains__, + "display_width": _positive_integer, + "display_expand_attrs": lambda choice: choice in [True, False, "default"], + "display_expand_coords": lambda choice: choice in [True, False, "default"], + "display_expand_data_vars": lambda choice: choice in [True, False, "default"], + "display_expand_data": lambda choice: choice in [True, False, "default"], + "display_expand_indexes": lambda choice: choice in [True, False, "default"], + "display_default_indexes": lambda choice: choice in [True, False, "default"], + "enable_cftimeindex": lambda value: isinstance(value, bool), + "file_cache_maxsize": _positive_integer, + "keep_attrs": lambda choice: choice in [True, False, "default"], + "use_bottleneck": lambda value: isinstance(value, bool), + "use_numbagg": lambda value: isinstance(value, bool), + "use_opt_einsum": lambda value: isinstance(value, bool), + "use_flox": lambda value: isinstance(value, bool), + "warn_for_unclosed_files": lambda value: isinstance(value, bool), +} + + +def _set_file_cache_maxsize(value) -> None: + from xarray.backends.file_manager import FILE_CACHE + + FILE_CACHE.maxsize = value + + +def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): + warnings.warn( + "The enable_cftimeindex option is now a no-op " + "and will be removed in a future version of xarray.", + FutureWarning, + ) + + +_SETTERS = { + "enable_cftimeindex": _warn_on_setting_enable_cftimeindex, + "file_cache_maxsize": _set_file_cache_maxsize, +} + + +def _get_boolean_with_default(option: Options, default: bool) -> bool: + global_choice = OPTIONS[option] + + if global_choice == "default": + return default + elif isinstance(global_choice, bool): + return global_choice + else: + raise ValueError( + f"The global option {option} must be one of True, False or 'default'." + ) + + +def _get_keep_attrs(default: bool) -> bool: + return _get_boolean_with_default("keep_attrs", default) + + +class set_options: + """ + Set options for xarray in a controlled context. + + Parameters + ---------- + arithmetic_join : {"inner", "outer", "left", "right", "exact"}, default: "inner" + DataArray/Dataset alignment in binary operations: + + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be + aligned are not equal + - "override": if indexes are of same size, rewrite indexes to be + those of the first object with that dimension. Indexes for the same + dimension must have the same size in all objects. + + cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r" + Colormap to use for divergent data plots. If string, must be + matplotlib built-in colormap. Can also be a Colormap object + (e.g. mpl.colormaps["magma"]) + cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis" + Colormap to use for nondivergent data plots. If string, must be + matplotlib built-in colormap. Can also be a Colormap object + (e.g. mpl.colormaps["magma"]) + display_expand_attrs : {"default", True, False} + Whether to expand the attributes section for display of + ``DataArray`` or ``Dataset`` objects. Can be + + * ``True`` : to always expand attrs + * ``False`` : to always collapse attrs + * ``default`` : to expand unless over a pre-defined limit + display_expand_coords : {"default", True, False} + Whether to expand the coordinates section for display of + ``DataArray`` or ``Dataset`` objects. Can be + + * ``True`` : to always expand coordinates + * ``False`` : to always collapse coordinates + * ``default`` : to expand unless over a pre-defined limit + display_expand_data : {"default", True, False} + Whether to expand the data section for display of ``DataArray`` + objects. Can be + + * ``True`` : to always expand data + * ``False`` : to always collapse data + * ``default`` : to expand unless over a pre-defined limit + display_expand_data_vars : {"default", True, False} + Whether to expand the data variables section for display of + ``Dataset`` objects. Can be + + * ``True`` : to always expand data variables + * ``False`` : to always collapse data variables + * ``default`` : to expand unless over a pre-defined limit + display_expand_indexes : {"default", True, False} + Whether to expand the indexes section for display of + ``DataArray`` or ``Dataset``. Can be + + * ``True`` : to always expand indexes + * ``False`` : to always collapse indexes + * ``default`` : to expand unless over a pre-defined limit (always collapse for html style) + display_max_rows : int, default: 12 + Maximum display rows. + display_values_threshold : int, default: 200 + Total number of array elements which trigger summarization rather + than full repr for variable data views (numpy arrays). + display_style : {"text", "html"}, default: "html" + Display style to use in jupyter for xarray objects. + display_width : int, default: 80 + Maximum display width for ``repr`` on xarray objects. + file_cache_maxsize : int, default: 128 + Maximum number of open files to hold in xarray's + global least-recently-usage cached. This should be smaller than + your system's per-process file descriptor limit, e.g., + ``ulimit -n`` on Linux. + keep_attrs : {"default", True, False} + Whether to keep attributes on xarray Datasets/dataarrays after + operations. Can be + + * ``True`` : to always keep attrs + * ``False`` : to always discard attrs + * ``default`` : to use original logic that attrs should only + be kept in unambiguous circumstances + use_bottleneck : bool, default: True + Whether to use ``bottleneck`` to accelerate 1D reductions and + 1D rolling reduction operations. + use_flox : bool, default: True + Whether to use ``numpy_groupies`` and `flox`` to + accelerate groupby and resampling reductions. + use_numbagg : bool, default: True + Whether to use ``numbagg`` to accelerate reductions. + Takes precedence over ``use_bottleneck`` when both are True. + use_opt_einsum : bool, default: True + Whether to use ``opt_einsum`` to accelerate dot products. + warn_for_unclosed_files : bool, default: False + Whether or not to issue a warning when unclosed files are + deallocated. This is mostly useful for debugging. + + Examples + -------- + It is possible to use ``set_options`` either as a context manager: + + >>> ds = xr.Dataset({"x": np.arange(1000)}) + >>> with xr.set_options(display_width=40): + ... print(ds) + ... + Size: 8kB + Dimensions: (x: 1000) + Coordinates: + * x (x) int64 8kB 0 1 ... 999 + Data variables: + *empty* + + Or to set global options: + + >>> xr.set_options(display_width=80) # doctest: +ELLIPSIS + + """ + + def __init__(self, **kwargs): + self.old = {} + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}" + ) + if k in _VALIDATORS and not _VALIDATORS[k](v): + if k == "arithmetic_join": + expected = f"Expected one of {_JOIN_OPTIONS!r}" + elif k == "display_style": + expected = f"Expected one of {_DISPLAY_OPTIONS!r}" + else: + expected = "" + raise ValueError( + f"option {k!r} given an invalid value: {v!r}. " + expected + ) + self.old[k] = OPTIONS[k] + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) + + def __enter__(self): + return + + def __exit__(self, type, value, traceback): + self._apply_update(self.old) + + +def get_options(): + """ + Get options for xarray. + + See Also + ---------- + set_options + + """ + return FrozenDict(OPTIONS) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/parallel.py b/test/fixtures/whole_applications/xarray/xarray/core/parallel.py new file mode 100644 index 0000000..4131149 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/parallel.py @@ -0,0 +1,639 @@ +from __future__ import annotations + +import collections +import itertools +import operator +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict + +import numpy as np + +from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.indexes import Index +from xarray.core.merge import merge +from xarray.core.utils import is_dask_collection +from xarray.core.variable import Variable + +if TYPE_CHECKING: + from xarray.core.types import T_Xarray + + +class ExpectedDict(TypedDict): + shapes: dict[Hashable, int] + coords: set[Hashable] + data_vars: set[Hashable] + indexes: dict[Hashable, Index] + + +def unzip(iterable): + return zip(*iterable) + + +def assert_chunks_compatible(a: Dataset, b: Dataset): + a = a.unify_chunks() + b = b.unify_chunks() + + for dim in set(a.chunks).intersection(set(b.chunks)): + if a.chunks[dim] != b.chunks[dim]: + raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.") + + +def check_result_variables( + result: DataArray | Dataset, + expected: ExpectedDict, + kind: Literal["coords", "data_vars"], +): + if kind == "coords": + nice_str = "coordinate" + elif kind == "data_vars": + nice_str = "data" + + # check that coords and data variables are as expected + missing = expected[kind] - set(getattr(result, kind)) + if missing: + raise ValueError( + "Result from applying user function does not contain " + f"{nice_str} variables {missing}." + ) + extra = set(getattr(result, kind)) - expected[kind] + if extra: + raise ValueError( + "Result from applying user function has unexpected " + f"{nice_str} variables {extra}." + ) + + +def dataset_to_dataarray(obj: Dataset) -> DataArray: + if not isinstance(obj, Dataset): + raise TypeError(f"Expected Dataset, got {type(obj)}") + + if len(obj.data_vars) > 1: + raise TypeError( + "Trying to convert Dataset with more than one data variable to DataArray" + ) + + return next(iter(obj.data_vars.values())) + + +def dataarray_to_dataset(obj: DataArray) -> Dataset: + # only using _to_temp_dataset would break + # func = lambda x: x.to_dataset() + # since that relies on preserving name. + if obj.name is None: + dataset = obj._to_temp_dataset() + else: + dataset = obj.to_dataset() + return dataset + + +def make_meta(obj): + """If obj is a DataArray or Dataset, return a new object of the same type and with + the same variables and dtypes, but where all variables have size 0 and numpy + backend. + If obj is neither a DataArray nor Dataset, return it unaltered. + """ + if isinstance(obj, DataArray): + obj_array = obj + obj = dataarray_to_dataset(obj) + elif isinstance(obj, Dataset): + obj_array = None + else: + return obj + + from dask.array.utils import meta_from_array + + meta = Dataset() + for name, variable in obj.variables.items(): + meta_obj = meta_from_array(variable.data, ndim=variable.ndim) + meta[name] = (variable.dims, meta_obj, variable.attrs) + meta.attrs = obj.attrs + meta = meta.set_coords(obj.coords) + + if obj_array is not None: + return dataset_to_dataarray(meta) + return meta + + +def infer_template( + func: Callable[..., T_Xarray], obj: DataArray | Dataset, *args, **kwargs +) -> T_Xarray: + """Infer return object by running the function on meta objects.""" + meta_args = [make_meta(arg) for arg in (obj,) + args] + + try: + template = func(*meta_args, **kwargs) + except Exception as e: + raise Exception( + "Cannot infer object returned from running user provided function. " + "Please supply the 'template' kwarg to map_blocks." + ) from e + + if not isinstance(template, (Dataset, DataArray)): + raise TypeError( + "Function must return an xarray DataArray or Dataset. Instead it returned " + f"{type(template)}" + ) + + return template + + +def make_dict(x: DataArray | Dataset) -> dict[Hashable, Any]: + """Map variable name to numpy(-like) data + (Dataset.to_dict() is too complicated). + """ + if isinstance(x, DataArray): + x = x._to_temp_dataset() + + return {k: v.data for k, v in x.variables.items()} + + +def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): + if dim in chunk_index: + which_chunk = chunk_index[dim] + return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1]) + return slice(None) + + +def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index +): + """ + Creates a task that subsets an xarray dataset to a block determined by chunk_index. + Block extents are determined by input_chunk_bounds. + Also subtasks that subset the constituent variables of a dataset. + """ + import dask + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] + # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + chunk_tuple = tuple(chunk_index.values()) + chunk_dims_set = set(chunk_index) + variable: Variable + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + # get task name for chunk + chunk = ( + variable.data.name, + *tuple(chunk_index[dim] for dim in variable.dims), + ) + + chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) + else: + assert name in dataset.dims or variable.ndim == 0 + + # non-dask array possibly with dimensions chunked on other variables + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + if set(variable.dims) < chunk_dims_set: + this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims) + else: + this_var_chunk_tuple = chunk_tuple + + chunk_variable_task = ( + f"{name}-{gname}-{dask.base.tokenize(subsetter)}", + ) + this_var_chunk_tuple + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. + if variable.ndim == 0 or chunk_variable_task not in graph: + subset = variable.isel(subsetter) + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset._data, subset.attrs], + ) + + # this task creates dict mapping variable name to above tuple + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) + + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + + +def map_blocks( + func: Callable[..., T_Xarray], + obj: DataArray | Dataset, + args: Sequence[Any] = (), + kwargs: Mapping[str, Any] | None = None, + template: DataArray | Dataset | None = None, +) -> T_Xarray: + """Apply a function to each block of a DataArray or Dataset. + + .. warning:: + This function is experimental and its signature may change. + + Parameters + ---------- + func : callable + User-provided function that accepts a DataArray or Dataset as its first + parameter ``obj``. The function will receive a subset or 'block' of ``obj`` (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_obj, *subset_args, **kwargs)``. + + This function must return either a single DataArray or a single Dataset. + + This function cannot add a new chunked dimension. + obj : DataArray, Dataset + Passed to the function as its first argument, one block at a time. + args : sequence + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. + kwargs : mapping + Passed verbatim to func after unpacking. xarray objects, if any, will not be + subset to blocks. Passing dask collections in kwargs is not allowed. + template : DataArray or Dataset, optional + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like ``obj`` but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + + Returns + ------- + obj : same as obj + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. + + Notes + ----- + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. Each block is loaded into memory. In the more common case where + ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. + + If none of the variables in ``obj`` is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. + + See Also + -------- + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks + xarray.DataArray.map_blocks + + Examples + -------- + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + ... + >>> time = xr.cftime_range("1990-01", "1992-01", freq="ME") + >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), + ... dims=["time"], + ... coords={"time": time, "month": month}, + ... ).chunk() + >>> array.map_blocks(calculate_anomaly, template=array).compute() + Size: 192B + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B 1 2 3 4 5 6 7 8 9 10 ... 3 4 5 6 7 8 9 10 11 12 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> array.map_blocks( + ... calculate_anomaly, + ... kwargs={"groupby_type": "time.year"}, + ... template=array, + ... ) # doctest: +ELLIPSIS + Size: 192B + dask.array<-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> + Coordinates: + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B dask.array + """ + + def _wrapper( + func: Callable, + args: list, + kwargs: dict, + arg_is_array: Iterable[bool], + expected: ExpectedDict, + ): + """ + Wrapper function that receives datasets in args; converts to dataarrays when necessary; + passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. + """ + + converted_args = [ + dataset_to_dataarray(arg) if is_array else arg + for is_array, arg in zip(arg_is_array, args) + ] + + result = func(*converted_args, **kwargs) + + merged_coordinates = merge( + [arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))] + ).coords + + # check all dims are present + missing_dimensions = set(expected["shapes"]) - set(result.sizes) + if missing_dimensions: + raise ValueError( + f"Dimensions {missing_dimensions} missing on returned object." + ) + + # check that index lengths and values are as expected + for name, index in result._indexes.items(): + if name in expected["shapes"]: + if result.sizes[name] != expected["shapes"][name]: + raise ValueError( + f"Received dimension {name!r} of length {result.sizes[name]}. " + f"Expected length {expected['shapes'][name]}." + ) + + # ChainMap wants MutableMapping, but xindexes is Mapping + merged_indexes = collections.ChainMap( + expected["indexes"], merged_coordinates.xindexes # type: ignore[arg-type] + ) + expected_index = merged_indexes.get(name, None) + if expected_index is not None and not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) + + # check that all expected variables were returned + check_result_variables(result, expected, "coords") + if isinstance(result, Dataset): + check_result_variables(result, expected, "data_vars") + + return make_dict(result) + + if template is not None and not isinstance(template, (DataArray, Dataset)): + raise TypeError( + f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." + ) + if not isinstance(args, Sequence): + raise TypeError("args must be a sequence (for example, a list or tuple).") + if kwargs is None: + kwargs = {} + elif not isinstance(kwargs, Mapping): + raise TypeError("kwargs must be a mapping (for example, a dict)") + + for value in kwargs.values(): + if is_dask_collection(value): + raise TypeError( + "Cannot pass dask collections in kwargs yet. Please compute or " + "load values before passing to map_blocks." + ) + + if not is_dask_collection(obj): + return func(obj, *args, **kwargs) + + try: + import dask + import dask.array + from dask.highlevelgraph import HighLevelGraph + + except ImportError: + pass + + all_args = [obj] + list(args) + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args] + is_array = [isinstance(arg, DataArray) for arg in all_args] + + # there should be a better way to group this. partition? + xarray_indices, xarray_objs = unzip( + (index, arg) for index, arg in enumerate(all_args) if is_xarray[index] + ) + others = [ + (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index] + ] + + # all xarray objects must be aligned. This is consistent with apply_ufunc. + aligned = align(*xarray_objs, join="exact") + xarray_objs = tuple( + dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg + for arg in aligned + ) + # rechunk any numpy variables appropriately + xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) + + merged_coordinates = merge([arg.coords for arg in aligned]).coords + + _, npargs = unzip( + sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) + ) + + # check that chunk sizes are compatible + input_chunks = dict(npargs[0].chunks) + for arg in xarray_objs[1:]: + assert_chunks_compatible(npargs[0], arg) + input_chunks.update(arg.chunks) + + coordinates: Coordinates + if template is None: + # infer template by providing zero-shaped arrays + template = infer_template(func, aligned[0], *args, **kwargs) + template_coords = set(template.coords) + preserved_coord_vars = template_coords & set(merged_coordinates) + new_coord_vars = template_coords - set(merged_coordinates) + + preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars] + # preserved_coords contains all coordinates variables that share a dimension + # with any index variable in preserved_indexes + # Drop any unneeded vars in a second pass, this is required for e.g. + # if the mapped function were to drop a non-dimension coordinate variable. + preserved_coords = preserved_coords.drop_vars( + tuple(k for k in preserved_coords.variables if k not in template_coords) + ) + + coordinates = merge( + (preserved_coords, template.coords.to_dataset()[new_coord_vars]) + ).coords + output_chunks: Mapping[Hashable, tuple[int, ...]] = { + dim: input_chunks[dim] for dim in template.dims if dim in input_chunks + } + + else: + # template xarray object has been provided with proper sizes and chunk shapes + coordinates = template.coords + output_chunks = template.chunksizes + if not output_chunks: + raise ValueError( + "Provided template has no dask arrays. " + " Please construct a template with appropriately chunked dask arrays." + ) + + new_indexes = set(template.xindexes) - set(merged_coordinates) + modified_indexes = set( + name + for name, xindex in coordinates.xindexes.items() + if not xindex.equals(merged_coordinates.xindexes.get(name, None)) + ) + + for dim in output_chunks: + if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): + raise ValueError( + "map_blocks requires that one block of the input maps to one block of output. " + f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. " + f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or " + "fix the provided template." + ) + + if isinstance(template, DataArray): + result_is_array = True + template_name = template.name + template = template._to_temp_dataset() + elif isinstance(template, Dataset): + result_is_array = False + else: + raise TypeError( + f"func output must be DataArray or Dataset; got {type(template)}" + ) + + # We're building a new HighLevelGraph hlg. We'll have one new layer + # for each variable in the dataset, which is the result of the + # func applied to the values. + + graph: dict[Any, Any] = {} + new_layers: collections.defaultdict[str, dict[Any, Any]] = collections.defaultdict( + dict + ) + gname = f"{dask.utils.funcname(func)}-{dask.base.tokenize(npargs[0], args, kwargs)}" + + # map dims to list of chunk indexes + ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} + # mapping from chunk index to slice bounds + input_chunk_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() + } + output_chunk_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() + } + + computed_variables = set(template.variables) - set(coordinates.indexes) + # iterate over all possible chunk combinations + for chunk_tuple in itertools.product(*ichunk.values()): + # mapping from dimension name to chunk index + chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) + + blocked_args = [ + ( + subset_dataset_to_block( + graph, gname, arg, input_chunk_bounds, chunk_index + ) + if isxr + else arg + ) + for isxr, arg in zip(is_xarray, npargs) + ] + + # raise nice error messages in _wrapper + expected: ExpectedDict = { + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + "shapes": { + k: output_chunks[k][v] + for k, v in chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + # only include new or modified indexes to minimize duplication of data, and graph size. + "indexes": { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in (new_indexes | modified_indexes) + }, + } + + from_wrapper = (gname,) + chunk_tuple + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) + + # mapping from variable name to dask graph key + var_key_map: dict[Hashable, str] = {} + for name in computed_variables: + variable = template.variables[name] + gname_l = f"{name}-{gname}" + var_key_map[name] = gname_l + + # unchunked dimensions in the input have one chunk in the result + # output can have new dimensions with exactly one chunk + key: tuple[Any, ...] = (gname_l,) + tuple( + chunk_index[dim] if dim in chunk_index else 0 for dim in variable.dims + ) + + # We're adding multiple new layers to the graph: + # The first new layer is the result of the computation on + # the array. + # Then we add one layer per variable, which extracts the + # result for that variable, and depends on just the first new + # layer. + new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) + + hlg = HighLevelGraph.from_collections( + gname, + graph, + dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], + ) + + # This adds in the getitems for each variable in the dataset. + hlg = HighLevelGraph( + {**hlg.layers, **new_layers}, + dependencies={ + **hlg.dependencies, + **{name: {gname} for name in new_layers.keys()}, + }, + ) + + result = Dataset(coords=coordinates, attrs=template.attrs) + + for index in result._indexes: + result[index].attrs = template[index].attrs + result[index].encoding = template[index].encoding + + for name, gname_l in var_key_map.items(): + dims = template[name].dims + var_chunks = [] + for dim in dims: + if dim in output_chunks: + var_chunks.append(output_chunks[dim]) + elif dim in result._indexes: + var_chunks.append((result.sizes[dim],)) + elif dim in template.dims: + # new unindexed dimension + var_chunks.append((template.sizes[dim],)) + + data = dask.array.Array( + hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype + ) + result[name] = (dims, data, template[name].attrs) + result[name].encoding = template[name].encoding + + result = result.set_coords(template._coord_names) + + if result_is_array: + da = dataset_to_dataarray(result) + da.name = template_name + return da # type: ignore[return-value] + return result # type: ignore[return-value] diff --git a/test/fixtures/whole_applications/xarray/xarray/core/pdcompat.py b/test/fixtures/whole_applications/xarray/xarray/core/pdcompat.py new file mode 100644 index 0000000..c09dd82 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/pdcompat.py @@ -0,0 +1,107 @@ +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +from enum import Enum +from typing import Literal + +import pandas as pd +from packaging.version import Version + +from xarray.coding import cftime_offsets + + +def count_not_none(*args) -> int: + """Compute the number of non-None arguments. + + Copied from pandas.core.common.count_not_none (not part of the public API) + """ + return sum(arg is not None for arg in args) + + +class _NoDefault(Enum): + """Used by pandas to specify a default value for a deprecated argument. + Copied from pandas._libs.lib._NoDefault. + + See also: + - pandas-dev/pandas#30788 + - pandas-dev/pandas#40684 + - pandas-dev/pandas#40715 + - pandas-dev/pandas#47045 + """ + + no_default = "NO_DEFAULT" + + def __repr__(self) -> str: + return "" + + +no_default = ( + _NoDefault.no_default +) # Sentinel indicating the default value following pandas +NoDefault = Literal[_NoDefault.no_default] # For typing following pandas + + +def _convert_base_to_offset(base, freq, index): + """Required until we officially deprecate the base argument to resample. This + translates a provided `base` argument to an `offset` argument, following logic + from pandas. + """ + from xarray.coding.cftimeindex import CFTimeIndex + + if isinstance(index, pd.DatetimeIndex): + freq = cftime_offsets._new_to_legacy_freq(freq) + freq = pd.tseries.frequencies.to_offset(freq) + if isinstance(freq, pd.offsets.Tick): + return pd.Timedelta(base * freq.nanos // freq.n) + elif isinstance(index, CFTimeIndex): + freq = cftime_offsets.to_offset(freq) + if isinstance(freq, cftime_offsets.Tick): + return base * freq.as_timedelta() // freq.n + else: + raise ValueError("Can only resample using a DatetimeIndex or CFTimeIndex.") + + +def nanosecond_precision_timestamp(*args, **kwargs) -> pd.Timestamp: + """Return a nanosecond-precision Timestamp object. + + Note this function should no longer be needed after addressing GitHub issue + #7493. + """ + if Version(pd.__version__) >= Version("2.0.0"): + return pd.Timestamp(*args, **kwargs).as_unit("ns") + else: + return pd.Timestamp(*args, **kwargs) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/resample.py b/test/fixtures/whole_applications/xarray/xarray/core/resample.py new file mode 100644 index 0000000..9181bb4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/resample.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import warnings +from collections.abc import Hashable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable + +from xarray.core._aggregations import ( + DataArrayResampleAggregations, + DatasetResampleAggregations, +) +from xarray.core.groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy +from xarray.core.types import Dims, InterpOptions, T_Xarray + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + +RESAMPLE_DIM = "__resample_dim__" + + +class Resample(GroupBy[T_Xarray]): + """An object that extends the `GroupBy` object with additional logic + for handling specialized re-sampling operations. + + You should create a `Resample` object by using the `DataArray.resample` or + `Dataset.resample` methods. The dimension along re-sampling + + See Also + -------- + DataArray.resample + Dataset.resample + + """ + + def __init__( + self, + *args, + dim: Hashable | None = None, + resample_dim: Hashable | None = None, + **kwargs, + ) -> None: + if dim == resample_dim: + raise ValueError( + f"Proxy resampling dimension ('{resample_dim}') " + f"cannot have the same name as actual dimension ('{dim}')!" + ) + self._dim = dim + + super().__init__(*args, **kwargs) + + def _flox_reduce( + self, + dim: Dims, + keep_attrs: bool | None = None, + **kwargs, + ) -> T_Xarray: + result = super()._flox_reduce(dim=dim, keep_attrs=keep_attrs, **kwargs) + result = result.rename({RESAMPLE_DIM: self._group_dim}) + return result + + def _drop_coords(self) -> T_Xarray: + """Drop non-dimension coordinates along the resampled dimension.""" + obj = self._obj + for k, v in obj.coords.items(): + if k != self._dim and self._dim in v.dims: + obj = obj.drop_vars([k]) + return obj + + def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: + """Forward fill new values at up-sampled frequency. + + Parameters + ---------- + tolerance : float | Iterable[float] | None, default: None + Maximum distance between original and new labels to limit + the up-sampling method. + Up-sampled data with indices that satisfy the equation + ``abs(index[indexer] - target) <= tolerance`` are filled by + new values. Data with indices that are outside the given + tolerance are filled with ``NaN`` s. + + Returns + ------- + padded : DataArray or Dataset + """ + obj = self._drop_coords() + (grouper,) = self.groupers + return obj.reindex( + {self._dim: grouper.full_index}, method="pad", tolerance=tolerance + ) + + ffill = pad + + def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: + """Backward fill new values at up-sampled frequency. + + Parameters + ---------- + tolerance : float | Iterable[float] | None, default: None + Maximum distance between original and new labels to limit + the up-sampling method. + Up-sampled data with indices that satisfy the equation + ``abs(index[indexer] - target) <= tolerance`` are filled by + new values. Data with indices that are outside the given + tolerance are filled with ``NaN`` s. + + Returns + ------- + backfilled : DataArray or Dataset + """ + obj = self._drop_coords() + (grouper,) = self.groupers + return obj.reindex( + {self._dim: grouper.full_index}, method="backfill", tolerance=tolerance + ) + + bfill = backfill + + def nearest(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: + """Take new values from nearest original coordinate to up-sampled + frequency coordinates. + + Parameters + ---------- + tolerance : float | Iterable[float] | None, default: None + Maximum distance between original and new labels to limit + the up-sampling method. + Up-sampled data with indices that satisfy the equation + ``abs(index[indexer] - target) <= tolerance`` are filled by + new values. Data with indices that are outside the given + tolerance are filled with ``NaN`` s. + + Returns + ------- + upsampled : DataArray or Dataset + """ + obj = self._drop_coords() + (grouper,) = self.groupers + return obj.reindex( + {self._dim: grouper.full_index}, method="nearest", tolerance=tolerance + ) + + def interpolate(self, kind: InterpOptions = "linear", **kwargs) -> T_Xarray: + """Interpolate up-sampled data using the original data as knots. + + Parameters + ---------- + kind : {"linear", "nearest", "zero", "slinear", \ + "quadratic", "cubic", "polynomial"}, default: "linear" + The method used to interpolate. The method should be supported by + the scipy interpolator: + + - ``interp1d``: {"linear", "nearest", "zero", "slinear", + "quadratic", "cubic", "polynomial"} + - ``interpn``: {"linear", "nearest"} + + If ``"polynomial"`` is passed, the ``order`` keyword argument must + also be provided. + + Returns + ------- + interpolated : DataArray or Dataset + + See Also + -------- + DataArray.interp + Dataset.interp + scipy.interpolate.interp1d + + """ + return self._interpolate(kind=kind, **kwargs) + + def _interpolate(self, kind="linear", **kwargs) -> T_Xarray: + """Apply scipy.interpolate.interp1d along resampling dimension.""" + obj = self._drop_coords() + (grouper,) = self.groupers + kwargs.setdefault("bounds_error", False) + return obj.interp( + coords={self._dim: grouper.full_index}, + assume_sorted=True, + method=kind, + kwargs=kwargs, + ) + + +# https://github.com/python/mypy/issues/9031 +class DataArrayResample(Resample["DataArray"], DataArrayGroupByBase, DataArrayResampleAggregations): # type: ignore[misc] + """DataArrayGroupBy object specialized to time resampling operations over a + specified dimension + """ + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> DataArray: + """Reduce the items in this group by applying `func` along the + pre-defined resampling dimension. + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + Array with summarized data and the indicated dimension(s) + removed. + """ + return super().reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + shortcut=shortcut, + **kwargs, + ) + + def map( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = False, + **kwargs: Any, + ) -> DataArray: + """Apply a function to each array in the group and concatenate them + together into a new array. + + `func` is called like `func(ar, *args, **kwargs)` for each array `ar` + in this group. + + Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how + to stack together the array. The rule is: + + 1. If the dimension along which the group coordinate is defined is + still in the first grouped array after applying `func`, then stack + over this dimension. + 2. Otherwise, stack over the new dimension given by name of this + grouping (the argument to the `groupby` function). + + Parameters + ---------- + func : callable + Callable to apply to each array. + shortcut : bool, optional + Whether or not to shortcut evaluation under the assumptions that: + + (1) The action of `func` does not depend on any of the array + metadata (attributes or coordinates) but only on the data and + dimensions. + (2) The action of `func` creates arrays with homogeneous metadata, + that is, with the same dimensions and attributes. + + If these conditions are satisfied `shortcut` provides significant + speedup. This should be the case for many common groupby operations + (e.g., applying numpy ufuncs). + args : tuple, optional + Positional arguments passed on to `func`. + **kwargs + Used to call `func(ar, **kwargs)` for each array `ar`. + + Returns + ------- + applied : DataArray + The result of splitting, applying and combining this array. + """ + return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) + + def _map_maybe_warn( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = False, + warn_squeeze: bool = True, + **kwargs: Any, + ) -> DataArray: + # TODO: the argument order for Resample doesn't match that for its parent, + # GroupBy + combined = super()._map_maybe_warn( + func, shortcut=shortcut, args=args, warn_squeeze=warn_squeeze, **kwargs + ) + + # If the aggregation function didn't drop the original resampling + # dimension, then we need to do so before we can rename the proxy + # dimension we used. + if self._dim in combined.coords: + combined = combined.drop_vars([self._dim]) + + if RESAMPLE_DIM in combined.dims: + combined = combined.rename({RESAMPLE_DIM: self._dim}) + + return combined + + def apply(self, func, args=(), shortcut=None, **kwargs): + """ + Backward compatible implementation of ``map`` + + See Also + -------- + DataArrayResample.map + """ + warnings.warn( + "Resample.apply may be deprecated in the future. Using Resample.map is encouraged", + PendingDeprecationWarning, + stacklevel=2, + ) + return self.map(func=func, shortcut=shortcut, args=args, **kwargs) + + def asfreq(self) -> DataArray: + """Return values of original object at the new up-sampling frequency; + essentially a re-index with new times set to NaN. + + Returns + ------- + resampled : DataArray + """ + self._obj = self._drop_coords() + return self.mean(None if self._dim is None else [self._dim]) + + +# https://github.com/python/mypy/issues/9031 +class DatasetResample(Resample["Dataset"], DatasetGroupByBase, DatasetResampleAggregations): # type: ignore[misc] + """DatasetGroupBy object specialized to resampling a specified dimension""" + + def map( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> Dataset: + """Apply a function over each Dataset in the groups generated for + resampling and concatenate them together into a new Dataset. + + `func` is called like `func(ds, *args, **kwargs)` for each dataset `ds` + in this group. + + Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how + to stack together the datasets. The rule is: + + 1. If the dimension along which the group coordinate is defined is + still in the first grouped item after applying `func`, then stack + over this dimension. + 2. Otherwise, stack over the new dimension given by name of this + grouping (the argument to the `groupby` function). + + Parameters + ---------- + func : callable + Callable to apply to each sub-dataset. + args : tuple, optional + Positional arguments passed on to `func`. + **kwargs + Used to call `func(ds, **kwargs)` for each sub-dataset `ar`. + + Returns + ------- + applied : Dataset + The result of splitting, applying and combining this dataset. + """ + return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) + + def _map_maybe_warn( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + warn_squeeze: bool = True, + **kwargs: Any, + ) -> Dataset: + # ignore shortcut if set (for now) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped(warn_squeeze)) + combined = self._combine(applied) + + # If the aggregation function didn't drop the original resampling + # dimension, then we need to do so before we can rename the proxy + # dimension we used. + if self._dim in combined.coords: + combined = combined.drop_vars(self._dim) + + if RESAMPLE_DIM in combined.dims: + combined = combined.rename({RESAMPLE_DIM: self._dim}) + + return combined + + def apply(self, func, args=(), shortcut=None, **kwargs): + """ + Backward compatible implementation of ``map`` + + See Also + -------- + DataSetResample.map + """ + + warnings.warn( + "Resample.apply may be deprecated in the future. Using Resample.map is encouraged", + PendingDeprecationWarning, + stacklevel=2, + ) + return self.map(func=func, shortcut=shortcut, args=args, **kwargs) + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + """Reduce the items in this group by applying `func` along the + pre-defined resampling dimension. + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Dataset + Array with summarized data and the indicated dimension(s) + removed. + """ + return super().reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + shortcut=shortcut, + **kwargs, + ) + + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + return super()._reduce_without_squeeze_warn( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + shortcut=shortcut, + **kwargs, + ) + + def asfreq(self) -> Dataset: + """Return values of original object at the new up-sampling frequency; + essentially a re-index with new times set to NaN. + + Returns + ------- + resampled : Dataset + """ + self._obj = self._drop_coords() + return self.mean(None if self._dim is None else [self._dim]) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/resample_cftime.py b/test/fixtures/whole_applications/xarray/xarray/core/resample_cftime.py new file mode 100644 index 0000000..216bd8f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/resample_cftime.py @@ -0,0 +1,510 @@ +"""Resampling for CFTimeIndex. Does not support non-integer freq.""" + +# The mechanisms for resampling CFTimeIndex was copied and adapted from +# the source code defined in pandas.core.resample +# +# For reference, here is a copy of the pandas copyright notice: +# +# BSD 3-Clause License +# +# Copyright (c) 2008-2012, AQR Capital Management, LLC, Lambda Foundry, Inc. +# and PyData Development Team +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import datetime +import typing + +import numpy as np +import pandas as pd + +from xarray.coding.cftime_offsets import ( + BaseCFTimeOffset, + MonthEnd, + QuarterEnd, + Tick, + YearEnd, + cftime_range, + normalize_date, + to_offset, +) +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.core.types import SideOptions + +if typing.TYPE_CHECKING: + from xarray.core.types import CFTimeDatetime + + +class CFTimeGrouper: + """This is a simple container for the grouping parameters that implements a + single method, the only one required for resampling in xarray. It cannot + be used in a call to groupby like a pandas.Grouper object can.""" + + def __init__( + self, + freq: str | BaseCFTimeOffset, + closed: SideOptions | None = None, + label: SideOptions | None = None, + loffset: str | datetime.timedelta | BaseCFTimeOffset | None = None, + origin: str | CFTimeDatetime = "start_day", + offset: str | datetime.timedelta | None = None, + ): + self.offset: datetime.timedelta | None + self.closed: SideOptions + self.label: SideOptions + self.freq = to_offset(freq) + self.loffset = loffset + self.origin = origin + + if isinstance(self.freq, (MonthEnd, QuarterEnd, YearEnd)): + if closed is None: + self.closed = "right" + else: + self.closed = closed + if label is None: + self.label = "right" + else: + self.label = label + else: + # The backward resample sets ``closed`` to ``'right'`` by default + # since the last value should be considered as the edge point for + # the last bin. When origin in "end" or "end_day", the value for a + # specific ``cftime.datetime`` index stands for the resample result + # from the current ``cftime.datetime`` minus ``freq`` to the current + # ``cftime.datetime`` with a right close. + if self.origin in ["end", "end_day"]: + if closed is None: + self.closed = "right" + else: + self.closed = closed + if label is None: + self.label = "right" + else: + self.label = label + else: + if closed is None: + self.closed = "left" + else: + self.closed = closed + if label is None: + self.label = "left" + else: + self.label = label + + if offset is not None: + try: + self.offset = _convert_offset_to_timedelta(offset) + except (ValueError, AttributeError) as error: + raise ValueError( + f"offset must be a datetime.timedelta object or an offset string " + f"that can be converted to a timedelta. Got {offset} instead." + ) from error + else: + self.offset = None + + def first_items(self, index: CFTimeIndex): + """Meant to reproduce the results of the following + + grouper = pandas.Grouper(...) + first_items = pd.Series(np.arange(len(index)), + index).groupby(grouper).first() + + with index being a CFTimeIndex instead of a DatetimeIndex. + """ + + datetime_bins, labels = _get_time_bins( + index, self.freq, self.closed, self.label, self.origin, self.offset + ) + if self.loffset is not None: + if not isinstance( + self.loffset, (str, datetime.timedelta, BaseCFTimeOffset) + ): + # BaseCFTimeOffset is not public API so we do not include it in + # the error message for now. + raise ValueError( + f"`loffset` must be a str or datetime.timedelta object. " + f"Got {self.loffset}." + ) + + if isinstance(self.loffset, datetime.timedelta): + labels = labels + self.loffset + else: + labels = labels + to_offset(self.loffset) + + # check binner fits data + if index[0] < datetime_bins[0]: + raise ValueError("Value falls before first bin") + if index[-1] > datetime_bins[-1]: + raise ValueError("Value falls after last bin") + + integer_bins = np.searchsorted(index, datetime_bins, side=self.closed) + counts = np.diff(integer_bins) + codes = np.repeat(np.arange(len(labels)), counts) + first_items = pd.Series(integer_bins[:-1], labels, copy=False) + + # Mask duplicate values with NaNs, preserving the last values + non_duplicate = ~first_items.duplicated("last") + return first_items.where(non_duplicate), codes + + +def _get_time_bins( + index: CFTimeIndex, + freq: BaseCFTimeOffset, + closed: SideOptions, + label: SideOptions, + origin: str | CFTimeDatetime, + offset: datetime.timedelta | None, +): + """Obtain the bins and their respective labels for resampling operations. + + Parameters + ---------- + index : CFTimeIndex + Index object to be resampled (e.g., CFTimeIndex named 'time'). + freq : xarray.coding.cftime_offsets.BaseCFTimeOffset + The offset object representing target conversion a.k.a. resampling + frequency (e.g., 'MS', '2D', 'H', or '3T' with + coding.cftime_offsets.to_offset() applied to it). + closed : 'left' or 'right' + Which side of bin interval is closed. + The default is 'left' for all frequency offsets except for 'M' and 'A', + which have a default of 'right'. + label : 'left' or 'right' + Which bin edge label to label bucket with. + The default is 'left' for all frequency offsets except for 'M' and 'A', + which have a default of 'right'. + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'} or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : datetime.timedelta, default is None + An offset timedelta added to the origin. + + Returns + ------- + datetime_bins : CFTimeIndex + Defines the edge of resampling bins by which original index values will + be grouped into. + labels : CFTimeIndex + Define what the user actually sees the bins labeled as. + """ + + if not isinstance(index, CFTimeIndex): + raise TypeError( + "index must be a CFTimeIndex, but got " + f"an instance of {type(index).__name__!r}" + ) + if len(index) == 0: + datetime_bins = labels = CFTimeIndex(data=[], name=index.name) + return datetime_bins, labels + + first, last = _get_range_edges( + index.min(), index.max(), freq, closed=closed, origin=origin, offset=offset + ) + datetime_bins = labels = cftime_range( + freq=freq, start=first, end=last, name=index.name + ) + + datetime_bins, labels = _adjust_bin_edges( + datetime_bins, freq, closed, index, labels + ) + + labels = labels[1:] if label == "right" else labels[:-1] + # TODO: when CFTimeIndex supports missing values, if the reference index + # contains missing values, insert the appropriate NaN value at the + # beginning of the datetime_bins and labels indexes. + + return datetime_bins, labels + + +def _adjust_bin_edges( + datetime_bins: np.ndarray, + freq: BaseCFTimeOffset, + closed: SideOptions, + index: CFTimeIndex, + labels: np.ndarray, +): + """This is required for determining the bin edges resampling with + month end, quarter end, and year end frequencies. + + Consider the following example. Let's say you want to downsample the + time series with the following coordinates to month end frequency: + + CFTimeIndex([2000-01-01 12:00:00, 2000-01-31 12:00:00, + 2000-02-01 12:00:00], dtype='object') + + Without this adjustment, _get_time_bins with month-end frequency will + return the following index for the bin edges (default closed='right' and + label='right' in this case): + + CFTimeIndex([1999-12-31 00:00:00, 2000-01-31 00:00:00, + 2000-02-29 00:00:00], dtype='object') + + If 2000-01-31 is used as a bound for a bin, the value on + 2000-01-31T12:00:00 (at noon on January 31st), will not be included in the + month of January. To account for this, pandas adds a day minus one worth + of microseconds to the bin edges generated by cftime range, so that we do + bin the value at noon on January 31st in the January bin. This results in + an index with bin edges like the following: + + CFTimeIndex([1999-12-31 23:59:59, 2000-01-31 23:59:59, + 2000-02-29 23:59:59], dtype='object') + + The labels are still: + + CFTimeIndex([2000-01-31 00:00:00, 2000-02-29 00:00:00], dtype='object') + """ + if isinstance(freq, (MonthEnd, QuarterEnd, YearEnd)): + if closed == "right": + datetime_bins = datetime_bins + datetime.timedelta(days=1, microseconds=-1) + if datetime_bins[-2] > index.max(): + datetime_bins = datetime_bins[:-1] + labels = labels[:-1] + + return datetime_bins, labels + + +def _get_range_edges( + first: CFTimeDatetime, + last: CFTimeDatetime, + freq: BaseCFTimeOffset, + closed: SideOptions = "left", + origin: str | CFTimeDatetime = "start_day", + offset: datetime.timedelta | None = None, +): + """Get the correct starting and ending datetimes for the resampled + CFTimeIndex range. + + Parameters + ---------- + first : cftime.datetime + Uncorrected starting datetime object for resampled CFTimeIndex range. + Usually the min of the original CFTimeIndex. + last : cftime.datetime + Uncorrected ending datetime object for resampled CFTimeIndex range. + Usually the max of the original CFTimeIndex. + freq : xarray.coding.cftime_offsets.BaseCFTimeOffset + The offset object representing target conversion a.k.a. resampling + frequency. Contains information on offset type (e.g. Day or 'D') and + offset magnitude (e.g., n = 3). + closed : 'left' or 'right' + Which side of bin interval is closed. Defaults to 'left'. + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'} or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : datetime.timedelta, default is None + An offset timedelta added to the origin. + + Returns + ------- + first : cftime.datetime + Corrected starting datetime object for resampled CFTimeIndex range. + last : cftime.datetime + Corrected ending datetime object for resampled CFTimeIndex range. + """ + if isinstance(freq, Tick): + first, last = _adjust_dates_anchored( + first, last, freq, closed=closed, origin=origin, offset=offset + ) + return first, last + else: + first = normalize_date(first) + last = normalize_date(last) + + first = freq.rollback(first) if closed == "left" else first - freq + last = last + freq + return first, last + + +def _adjust_dates_anchored( + first: CFTimeDatetime, + last: CFTimeDatetime, + freq: Tick, + closed: SideOptions = "right", + origin: str | CFTimeDatetime = "start_day", + offset: datetime.timedelta | None = None, +): + """First and last offsets should be calculated from the start day to fix + an error cause by resampling across multiple days when a one day period is + not a multiple of the frequency. + See https://github.com/pandas-dev/pandas/issues/8683 + + Parameters + ---------- + first : cftime.datetime + A datetime object representing the start of a CFTimeIndex range. + last : cftime.datetime + A datetime object representing the end of a CFTimeIndex range. + freq : xarray.coding.cftime_offsets.BaseCFTimeOffset + The offset object representing target conversion a.k.a. resampling + frequency. Contains information on offset type (e.g. Day or 'D') and + offset magnitude (e.g., n = 3). + closed : 'left' or 'right' + Which side of bin interval is closed. Defaults to 'right'. + origin : {'epoch', 'start', 'start_day', 'end', 'end_day'} or cftime.datetime, default 'start_day' + The datetime on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + + If a datetime is not used, these values are also supported: + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + offset : datetime.timedelta, default is None + An offset timedelta added to the origin. + + Returns + ------- + fresult : cftime.datetime + A datetime object representing the start of a date range that has been + adjusted to fix resampling errors. + lresult : cftime.datetime + A datetime object representing the end of a date range that has been + adjusted to fix resampling errors. + """ + import cftime + + if origin == "start_day": + origin_date = normalize_date(first) + elif origin == "start": + origin_date = first + elif origin == "epoch": + origin_date = type(first)(1970, 1, 1) + elif origin in ["end", "end_day"]: + origin_last = last if origin == "end" else _ceil_via_cftimeindex(last, "D") + sub_freq_times = (origin_last - first) // freq.as_timedelta() + if closed == "left": + sub_freq_times += 1 + first = origin_last - sub_freq_times * freq + origin_date = first + elif isinstance(origin, cftime.datetime): + origin_date = origin + else: + raise ValueError( + f"origin must be one of {{'epoch', 'start_day', 'start', 'end', 'end_day'}} " + f"or a cftime.datetime object. Got {origin}." + ) + + if offset is not None: + origin_date = origin_date + offset + + foffset = (first - origin_date) % freq.as_timedelta() + loffset = (last - origin_date) % freq.as_timedelta() + + if closed == "right": + if foffset.total_seconds() > 0: + fresult = first - foffset + else: + fresult = first - freq.as_timedelta() + + if loffset.total_seconds() > 0: + lresult = last + (freq.as_timedelta() - loffset) + else: + lresult = last + else: + if foffset.total_seconds() > 0: + fresult = first - foffset + else: + fresult = first + + if loffset.total_seconds() > 0: + lresult = last + (freq.as_timedelta() - loffset) + else: + lresult = last + freq + return fresult, lresult + + +def exact_cftime_datetime_difference(a: CFTimeDatetime, b: CFTimeDatetime): + """Exact computation of b - a + + Assumes: + + a = a_0 + a_m + b = b_0 + b_m + + Here a_0, and b_0 represent the input dates rounded + down to the nearest second, and a_m, and b_m represent + the remaining microseconds associated with date a and + date b. + + We can then express the value of b - a as: + + b - a = (b_0 + b_m) - (a_0 + a_m) = b_0 - a_0 + b_m - a_m + + By construction, we know that b_0 - a_0 must be a round number + of seconds. Therefore we can take the result of b_0 - a_0 using + ordinary cftime.datetime arithmetic and round to the nearest + second. b_m - a_m is the remainder, in microseconds, and we + can simply add this to the rounded timedelta. + + Parameters + ---------- + a : cftime.datetime + Input datetime + b : cftime.datetime + Input datetime + + Returns + ------- + datetime.timedelta + """ + seconds = b.replace(microsecond=0) - a.replace(microsecond=0) + seconds = int(round(seconds.total_seconds())) + microseconds = b.microsecond - a.microsecond + return datetime.timedelta(seconds=seconds, microseconds=microseconds) + + +def _convert_offset_to_timedelta( + offset: datetime.timedelta | str | BaseCFTimeOffset, +) -> datetime.timedelta: + if isinstance(offset, datetime.timedelta): + return offset + elif isinstance(offset, (str, Tick)): + return to_offset(offset).as_timedelta() + else: + raise ValueError + + +def _ceil_via_cftimeindex(date: CFTimeDatetime, freq: str | BaseCFTimeOffset): + index = CFTimeIndex([date]) + return index.ceil(freq).item() diff --git a/test/fixtures/whole_applications/xarray/xarray/core/rolling.py b/test/fixtures/whole_applications/xarray/xarray/core/rolling.py new file mode 100644 index 0000000..6cf49fc --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/rolling.py @@ -0,0 +1,1267 @@ +from __future__ import annotations + +import functools +import itertools +import math +import warnings +from collections.abc import Hashable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar + +import numpy as np +from packaging.version import Version + +from xarray.core import dtypes, duck_array_ops, utils +from xarray.core.arithmetic import CoarsenArithmetic +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray +from xarray.core.utils import ( + either_dict_or_kwargs, + is_duck_dask_array, + module_available, +) +from xarray.namedarray import pycompat + +try: + import bottleneck +except ImportError: + # use numpy methods instead + bottleneck = None + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + RollingKey = Any + _T = TypeVar("_T") + +_ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ +Reduce this object's data windows by applying `{name}` along its dimension. + +Parameters +---------- +keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. +**kwargs : dict + Additional keyword arguments passed on to `{name}`. + +Returns +------- +reduced : same type as caller + New object with `{name}` applied along its rolling dimension. +""" + + +class Rolling(Generic[T_Xarray]): + """A object that implements the moving window pattern. + + See Also + -------- + xarray.Dataset.groupby + xarray.DataArray.groupby + xarray.Dataset.rolling + xarray.DataArray.rolling + """ + + __slots__ = ("obj", "window", "min_periods", "center", "dim") + _attributes = ("window", "min_periods", "center", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int + + def __init__( + self, + obj: T_Xarray, + windows: Mapping[Any, int], + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + ) -> None: + """ + Moving window object. + + Parameters + ---------- + obj : Dataset or DataArray + Object to window. + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + window along (e.g. `time`) to the size of the moving window. + min_periods : int or None, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or dict-like Hashable to bool, default: False + Set the labels at the center of the window. If dict-like, set this + property per rolling dimension. + + Returns + ------- + rolling : type of input argument + """ + self.dim = [] + self.window = [] + for d, w in windows.items(): + self.dim.append(d) + if w <= 0: + raise ValueError("window must be > 0") + self.window.append(w) + + self.center = self._mapping_to_list(center, default=False) + self.obj = obj + + missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) + if missing_dims: + # NOTE: we raise KeyError here but ValueError in Coarsen. + raise KeyError( + f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " + f"dimensions {tuple(self.obj.dims)}" + ) + + # attributes + if min_periods is not None and min_periods <= 0: + raise ValueError("min_periods must be greater than zero or None") + + self.min_periods = ( + math.prod(self.window) if min_periods is None else min_periods + ) + + def __repr__(self) -> str: + """provide a nice str repr of our rolling object""" + + attrs = [ + "{k}->{v}{c}".format(k=k, v=w, c="(center)" if c else "") + for k, w, c in zip(self.dim, self.window, self.center) + ] + return "{klass} [{attrs}]".format( + klass=self.__class__.__name__, attrs=",".join(attrs) + ) + + def __len__(self) -> int: + return math.prod(self.obj.sizes[d] for d in self.dim) + + @property + def ndim(self) -> int: + return len(self.dim) + + def _reduce_method( # type: ignore[misc] + name: str, fillna: Any, rolling_agg_func: Callable | None = None + ) -> Callable[..., T_Xarray]: + """Constructs reduction methods built on a numpy reduction function (e.g. sum), + a numbagg reduction function (e.g. move_sum), a bottleneck reduction function + (e.g. move_sum), or a Rolling reduction (_mean). + + The logic here for which function to run is quite diffuse, across this method & + _array_reduce. Arguably we could refactor this. But one constraint is that we + need context of xarray options, of the functions each library offers, of + the array (e.g. dtype). + """ + if rolling_agg_func: + array_agg_func = None + else: + array_agg_func = getattr(duck_array_ops, name) + + bottleneck_move_func = getattr(bottleneck, "move_" + name, None) + if module_available("numbagg"): + import numbagg + + numbagg_move_func = getattr(numbagg, "move_" + name, None) + else: + numbagg_move_func = None + + def method(self, keep_attrs=None, **kwargs): + keep_attrs = self._get_keep_attrs(keep_attrs) + + return self._array_reduce( + array_agg_func=array_agg_func, + bottleneck_move_func=bottleneck_move_func, + numbagg_move_func=numbagg_move_func, + rolling_agg_func=rolling_agg_func, + keep_attrs=keep_attrs, + fillna=fillna, + **kwargs, + ) + + method.__name__ = name + method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name) + return method + + def _mean(self, keep_attrs, **kwargs): + result = self.sum(keep_attrs=False, **kwargs) / duck_array_ops.astype( + self.count(keep_attrs=False), dtype=self.obj.dtype, copy=False + ) + if keep_attrs: + result.attrs = self.obj.attrs + return result + + _mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean") + + argmax = _reduce_method("argmax", dtypes.NINF) + argmin = _reduce_method("argmin", dtypes.INF) + max = _reduce_method("max", dtypes.NINF) + min = _reduce_method("min", dtypes.INF) + prod = _reduce_method("prod", 1) + sum = _reduce_method("sum", 0) + mean = _reduce_method("mean", None, _mean) + std = _reduce_method("std", None) + var = _reduce_method("var", None) + median = _reduce_method("median", None) + + def _counts(self, keep_attrs: bool | None) -> T_Xarray: + raise NotImplementedError() + + def count(self, keep_attrs: bool | None = None) -> T_Xarray: + keep_attrs = self._get_keep_attrs(keep_attrs) + rolling_count = self._counts(keep_attrs=keep_attrs) + enough_periods = rolling_count >= self.min_periods + return rolling_count.where(enough_periods) + + count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") + + def _mapping_to_list( + self, + arg: _T | Mapping[Any, _T], + default: _T | None = None, + allow_default: bool = True, + allow_allsame: bool = True, + ) -> list[_T]: + if utils.is_dict_like(arg): + if allow_default: + return [arg.get(d, default) for d in self.dim] + for d in self.dim: + if d not in arg: + raise KeyError(f"Argument has no dimension key {d}.") + return [arg[d] for d in self.dim] + if allow_allsame: # for single argument + return [arg] * self.ndim # type: ignore[list-item] # no check for negatives + if self.ndim == 1: + return [arg] # type: ignore[list-item] # no check for negatives + raise ValueError(f"Mapping argument is necessary for {self.ndim}d-rolling.") + + def _get_keep_attrs(self, keep_attrs): + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return keep_attrs + + +class DataArrayRolling(Rolling["DataArray"]): + __slots__ = ("window_labels",) + + def __init__( + self, + obj: DataArray, + windows: Mapping[Any, int], + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + ) -> None: + """ + Moving window object for DataArray. + You should use DataArray.rolling() method to construct this object + instead of the class constructor. + + Parameters + ---------- + obj : DataArray + Object to window. + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool, default: False + Set the labels at the center of the window. + + Returns + ------- + rolling : type of input argument + + See Also + -------- + xarray.DataArray.rolling + xarray.DataArray.groupby + xarray.Dataset.rolling + xarray.Dataset.groupby + """ + super().__init__(obj, windows, min_periods=min_periods, center=center) + + # TODO legacy attribute + self.window_labels = self.obj[self.dim[0]] + + def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]: + if self.ndim > 1: + raise ValueError("__iter__ is only supported for 1d-rolling") + + dim0 = self.dim[0] + window0 = int(self.window[0]) + offset = (window0 + 1) // 2 if self.center[0] else 1 + stops = np.arange(offset, self.obj.sizes[dim0] + offset) + starts = stops - window0 + starts[: window0 - offset] = 0 + + for label, start, stop in zip(self.window_labels, starts, stops): + window = self.obj.isel({dim0: slice(start, stop)}) + + counts = window.count(dim=[dim0]) + window = window.where(counts >= self.min_periods) + + yield (label, window) + + def construct( + self, + window_dim: Hashable | Mapping[Any, Hashable] | None = None, + stride: int | Mapping[Any, int] = 1, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + **window_dim_kwargs: Hashable, + ) -> DataArray: + """ + Convert this rolling object to xr.DataArray, + where the window dimension is stacked as a new dimension + + Parameters + ---------- + window_dim : Hashable or dict-like to Hashable, optional + A mapping from dimension name to the new window dimension names. + stride : int or mapping of int, default: 1 + Size of stride for the rolling window. + fill_value : default: dtypes.NA + Filling value to match the dimension size. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **window_dim_kwargs : Hashable, optional + The keyword arguments form of ``window_dim`` {dim: new_name, ...}. + + Returns + ------- + DataArray that is a view of the original array. The returned array is + not writeable. + + Examples + -------- + >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b")) + + >>> rolling = da.rolling(b=3) + >>> rolling.construct("window_dim") + Size: 192B + array([[[nan, nan, 0.], + [nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.]], + + [[nan, nan, 4.], + [nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.]]]) + Dimensions without coordinates: a, b, window_dim + + >>> rolling = da.rolling(b=3, center=True) + >>> rolling.construct("window_dim") + Size: 192B + array([[[nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.], + [ 2., 3., nan]], + + [[nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.], + [ 6., 7., nan]]]) + Dimensions without coordinates: a, b, window_dim + + """ + + return self._construct( + self.obj, + window_dim=window_dim, + stride=stride, + fill_value=fill_value, + keep_attrs=keep_attrs, + **window_dim_kwargs, + ) + + def _construct( + self, + obj: DataArray, + window_dim: Hashable | Mapping[Any, Hashable] | None = None, + stride: int | Mapping[Any, int] = 1, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + **window_dim_kwargs: Hashable, + ) -> DataArray: + from xarray.core.dataarray import DataArray + + keep_attrs = self._get_keep_attrs(keep_attrs) + + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} + + window_dims = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + strides = self._mapping_to_list(stride, default=1) + + window = obj.variable.rolling_window( + self.dim, self.window, window_dims, self.center, fill_value=fill_value + ) + + attrs = obj.attrs if keep_attrs else {} + + result = DataArray( + window, + dims=obj.dims + tuple(window_dims), + coords=obj.coords, + attrs=attrs, + name=obj.name, + ) + return result.isel({d: slice(None, None, s) for d, s in zip(self.dim, strides)}) + + def reduce( + self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any + ) -> DataArray: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, **kwargs)` to return the result of collapsing an + np.ndarray over an the rolling dimension. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + Array with summarized data. + + Examples + -------- + >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b")) + >>> rolling = da.rolling(b=3) + >>> rolling.construct("window_dim") + Size: 192B + array([[[nan, nan, 0.], + [nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.]], + + [[nan, nan, 4.], + [nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.]]]) + Dimensions without coordinates: a, b, window_dim + + >>> rolling.reduce(np.sum) + Size: 64B + array([[nan, nan, 3., 6.], + [nan, nan, 15., 18.]]) + Dimensions without coordinates: a, b + + >>> rolling = da.rolling(b=3, min_periods=1) + >>> rolling.reduce(np.nansum) + Size: 64B + array([[ 0., 1., 3., 6.], + [ 4., 9., 15., 18.]]) + Dimensions without coordinates: a, b + """ + + keep_attrs = self._get_keep_attrs(keep_attrs) + + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}") + for d in self.dim + } + + # save memory with reductions GH4325 + fillna = kwargs.pop("fillna", dtypes.NA) + if fillna is not dtypes.NA: + obj = self.obj.fillna(fillna) + else: + obj = self.obj + windows = self._construct( + obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna + ) + + dim = list(rolling_dim.values()) + result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs) + + # Find valid windows based on count. + counts = self._counts(keep_attrs=False) + return result.where(counts >= self.min_periods) + + def _counts(self, keep_attrs: bool | None) -> DataArray: + """Number of non-nan entries in each rolling window.""" + + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}") + for d in self.dim + } + # We use False as the fill_value instead of np.nan, since boolean + # array is faster to be reduced than object array. + # The use of skipna==False is also faster since it does not need to + # copy the strided array. + dim = list(rolling_dim.values()) + counts = ( + self.obj.notnull(keep_attrs=keep_attrs) + .rolling( + {d: w for d, w in zip(self.dim, self.window)}, + center={d: self.center[i] for i, d in enumerate(self.dim)}, + ) + .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) + .sum(dim=dim, skipna=False, keep_attrs=keep_attrs) + ) + return counts + + def _numbagg_reduce(self, func, keep_attrs, **kwargs): + # Some of this is copied from `_bottleneck_reduce`, we could reduce this as part + # of a wider refactor. + + axis = self.obj.get_axis_num(self.dim[0]) + + padded = self.obj.variable + if self.center[0]: + if is_duck_dask_array(padded.data): + # workaround to make the padded chunk size larger than + # self.window - 1 + shift = -(self.window[0] + 1) // 2 + offset = (self.window[0] - 1) // 2 + valid = (slice(None),) * axis + ( + slice(offset, offset + self.obj.shape[axis]), + ) + else: + shift = (-self.window[0] // 2) + 1 + valid = (slice(None),) * axis + (slice(-shift, None),) + padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") + + if is_duck_dask_array(padded.data) and False: + raise AssertionError("should not be reachable") + else: + values = func( + padded.data, + window=self.window[0], + min_count=self.min_periods, + axis=axis, + ) + + if self.center[0]: + values = values[valid] + + attrs = self.obj.attrs if keep_attrs else {} + + return self.obj.__class__( + values, self.obj.coords, attrs=attrs, name=self.obj.name + ) + + def _bottleneck_reduce(self, func, keep_attrs, **kwargs): + # bottleneck doesn't allow min_count to be 0, although it should + # work the same as if min_count = 1 + # Note bottleneck only works with 1d-rolling. + if self.min_periods is not None and self.min_periods == 0: + min_count = 1 + else: + min_count = self.min_periods + + axis = self.obj.get_axis_num(self.dim[0]) + + padded = self.obj.variable + if self.center[0]: + if is_duck_dask_array(padded.data): + # workaround to make the padded chunk size larger than + # self.window - 1 + shift = -(self.window[0] + 1) // 2 + offset = (self.window[0] - 1) // 2 + valid = (slice(None),) * axis + ( + slice(offset, offset + self.obj.shape[axis]), + ) + else: + shift = (-self.window[0] // 2) + 1 + valid = (slice(None),) * axis + (slice(-shift, None),) + padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") + + if is_duck_dask_array(padded.data): + raise AssertionError("should not be reachable") + else: + values = func( + padded.data, window=self.window[0], min_count=min_count, axis=axis + ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func in [bottleneck.move_argmin, bottleneck.move_argmax]: + values = self.window[0] - 1 - values + + if self.center[0]: + values = values[valid] + + attrs = self.obj.attrs if keep_attrs else {} + + return self.obj.__class__( + values, self.obj.coords, attrs=attrs, name=self.obj.name + ) + + def _array_reduce( + self, + array_agg_func, + bottleneck_move_func, + numbagg_move_func, + rolling_agg_func, + keep_attrs, + fillna, + **kwargs, + ): + if "dim" in kwargs: + warnings.warn( + f"Reductions are applied along the rolling dimension(s) " + f"'{self.dim}'. Passing the 'dim' kwarg to reduction " + f"operations has no effect.", + DeprecationWarning, + stacklevel=3, + ) + del kwargs["dim"] + + if ( + OPTIONS["use_numbagg"] + and module_available("numbagg") + and pycompat.mod_version("numbagg") >= Version("0.6.3") + and numbagg_move_func is not None + # TODO: we could at least allow this for the equivalent of `apply_ufunc`'s + # "parallelized". `rolling_exp` does this, as an example (but rolling_exp is + # much simpler) + and not is_duck_dask_array(self.obj.data) + # Numbagg doesn't handle object arrays and generally has dtype consistency, + # so doesn't deal well with bool arrays which are expected to change type. + and self.obj.data.dtype.kind not in "ObMm" + # TODO: we could also allow this, probably as part of a refactoring of this + # module, so we can use the machinery in `self.reduce`. + and self.ndim == 1 + ): + import numbagg + + # Numbagg has a default ddof of 1. I (@max-sixty) think we should make + # this the default in xarray too, but until we do, don't use numbagg for + # std and var unless ddof is set to 1. + if ( + numbagg_move_func not in [numbagg.move_std, numbagg.move_var] + or kwargs.get("ddof") == 1 + ): + return self._numbagg_reduce( + numbagg_move_func, keep_attrs=keep_attrs, **kwargs + ) + + if ( + OPTIONS["use_bottleneck"] + and bottleneck_move_func is not None + and not is_duck_dask_array(self.obj.data) + and self.ndim == 1 + ): + # TODO: re-enable bottleneck with dask after the issues + # underlying https://github.com/pydata/xarray/issues/2940 are + # fixed. + return self._bottleneck_reduce( + bottleneck_move_func, keep_attrs=keep_attrs, **kwargs + ) + + if rolling_agg_func: + return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs)) + + if fillna is not None: + if fillna is dtypes.INF: + fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True) + elif fillna is dtypes.NINF: + fillna = dtypes.get_neg_infinity(self.obj.dtype, min_for_int=True) + kwargs.setdefault("skipna", False) + kwargs.setdefault("fillna", fillna) + + return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs) + + +class DatasetRolling(Rolling["Dataset"]): + __slots__ = ("rollings",) + + def __init__( + self, + obj: Dataset, + windows: Mapping[Any, int], + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + ) -> None: + """ + Moving window object for Dataset. + You should use Dataset.rolling() method to construct this object + instead of the class constructor. + + Parameters + ---------- + obj : Dataset + Object to window. + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or mapping of hashable to bool, default: False + Set the labels at the center of the window. + + Returns + ------- + rolling : type of input argument + + See Also + -------- + xarray.Dataset.rolling + xarray.DataArray.rolling + xarray.Dataset.groupby + xarray.DataArray.groupby + """ + super().__init__(obj, windows, min_periods, center) + + # Keep each Rolling object as a dictionary + self.rollings = {} + for key, da in self.obj.data_vars.items(): + # keeps rollings only for the dataset depending on self.dim + dims, center = [], {} + for i, d in enumerate(self.dim): + if d in da.dims: + dims.append(d) + center[d] = self.center[i] + + if dims: + w = {d: windows[d] for d in dims} + self.rollings[key] = DataArrayRolling(da, w, min_periods, center) + + def _dataset_implementation(self, func, keep_attrs, **kwargs): + from xarray.core.dataset import Dataset + + keep_attrs = self._get_keep_attrs(keep_attrs) + + reduced = {} + for key, da in self.obj.data_vars.items(): + if any(d in da.dims for d in self.dim): + reduced[key] = func(self.rollings[key], keep_attrs=keep_attrs, **kwargs) + else: + reduced[key] = self.obj[key].copy() + # we need to delete the attrs of the copied DataArray + if not keep_attrs: + reduced[key].attrs = {} + + attrs = self.obj.attrs if keep_attrs else {} + return Dataset(reduced, coords=self.obj.coords, attrs=attrs) + + def reduce( + self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any + ) -> DataArray: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, **kwargs)` to return the result of collapsing an + np.ndarray over an the rolling dimension. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + Array with summarized data. + """ + return self._dataset_implementation( + functools.partial(DataArrayRolling.reduce, func=func), + keep_attrs=keep_attrs, + **kwargs, + ) + + def _counts(self, keep_attrs: bool | None) -> Dataset: + return self._dataset_implementation( + DataArrayRolling._counts, keep_attrs=keep_attrs + ) + + def _array_reduce( + self, + array_agg_func, + bottleneck_move_func, + rolling_agg_func, + keep_attrs, + **kwargs, + ): + return self._dataset_implementation( + functools.partial( + DataArrayRolling._array_reduce, + array_agg_func=array_agg_func, + bottleneck_move_func=bottleneck_move_func, + rolling_agg_func=rolling_agg_func, + ), + keep_attrs=keep_attrs, + **kwargs, + ) + + def construct( + self, + window_dim: Hashable | Mapping[Any, Hashable] | None = None, + stride: int | Mapping[Any, int] = 1, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + **window_dim_kwargs: Hashable, + ) -> Dataset: + """ + Convert this rolling object to xr.Dataset, + where the window dimension is stacked as a new dimension + + Parameters + ---------- + window_dim : str or mapping, optional + A mapping from dimension name to the new window dimension names. + Just a string can be used for 1d-rolling. + stride : int, optional + size of stride for the rolling window. + fill_value : Any, default: dtypes.NA + Filling value to match the dimension size. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. + + Returns + ------- + Dataset with variables converted from rolling object. + """ + + from xarray.core.dataset import Dataset + + keep_attrs = self._get_keep_attrs(keep_attrs) + + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} + + window_dims = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + strides = self._mapping_to_list(stride, default=1) + + dataset = {} + for key, da in self.obj.data_vars.items(): + # keeps rollings only for the dataset depending on self.dim + dims = [d for d in self.dim if d in da.dims] + if dims: + wi = {d: window_dims[i] for i, d in enumerate(self.dim) if d in da.dims} + st = {d: strides[i] for i, d in enumerate(self.dim) if d in da.dims} + + dataset[key] = self.rollings[key].construct( + window_dim=wi, + fill_value=fill_value, + stride=st, + keep_attrs=keep_attrs, + ) + else: + dataset[key] = da.copy() + + # as the DataArrays can be copied we need to delete the attrs + if not keep_attrs: + dataset[key].attrs = {} + + # Need to stride coords as well. TODO: is there a better way? + coords = self.obj.isel( + {d: slice(None, None, s) for d, s in zip(self.dim, strides)} + ).coords + + attrs = self.obj.attrs if keep_attrs else {} + + return Dataset(dataset, coords=coords, attrs=attrs) + + +class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): + """A object that implements the coarsen. + + See Also + -------- + Dataset.coarsen + DataArray.coarsen + """ + + __slots__ = ( + "obj", + "boundary", + "coord_func", + "windows", + "side", + "trim_excess", + ) + _attributes = ("windows", "side", "trim_excess") + obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] + + def __init__( + self, + obj: T_Xarray, + windows: Mapping[Any, int], + boundary: CoarsenBoundaryOptions, + side: SideOptions | Mapping[Any, SideOptions], + coord_func: str | Callable | Mapping[Any, str | Callable], + ) -> None: + """ + Moving window object. + + Parameters + ---------- + obj : Dataset or DataArray + Object to window. + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + boundary : {"exact", "trim", "pad"} + If 'exact', a ValueError will be raised if dimension size is not a + multiple of window size. If 'trim', the excess indexes are trimmed. + If 'pad', NA will be padded. + side : 'left' or 'right' or mapping from dimension to 'left' or 'right' + coord_func : function (name) or mapping from coordinate name to function (name). + + Returns + ------- + coarsen + + """ + self.obj = obj + self.windows = windows + self.side = side + self.boundary = boundary + + missing_dims = tuple(dim for dim in windows.keys() if dim not in self.obj.dims) + if missing_dims: + raise ValueError( + f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " + f"dimensions {tuple(self.obj.dims)}" + ) + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} + for c in self.obj.coords: + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map + + def _get_keep_attrs(self, keep_attrs): + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return keep_attrs + + def __repr__(self) -> str: + """provide a nice str repr of our coarsen object""" + + attrs = [ + f"{k}->{getattr(self, k)}" + for k in self._attributes + if getattr(self, k, None) is not None + ] + return "{klass} [{attrs}]".format( + klass=self.__class__.__name__, attrs=",".join(attrs) + ) + + def construct( + self, + window_dim=None, + keep_attrs=None, + **window_dim_kwargs, + ) -> T_Xarray: + """ + Convert this Coarsen object to a DataArray or Dataset, + where the coarsening dimension is split or reshaped to two + new dimensions. + + Parameters + ---------- + window_dim: mapping + A mapping from existing dimension name to new dimension names. + The size of the second dimension will be the length of the + coarsening window. + keep_attrs: bool, optional + Preserve attributes if True + **window_dim_kwargs : {dim: new_name, ...} + The keyword arguments form of ``window_dim``. + + Returns + ------- + Dataset or DataArray with reshaped dimensions + + Examples + -------- + >>> da = xr.DataArray(np.arange(24), dims="time") + >>> da.coarsen(time=12).construct(time=("year", "month")) + Size: 192B + array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]) + Dimensions without coordinates: year, month + + See Also + -------- + DataArrayRolling.construct + DatasetRolling.construct + """ + + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + window_dim = either_dict_or_kwargs( + window_dim, window_dim_kwargs, "Coarsen.construct" + ) + if not window_dim: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + + bad_new_dims = tuple( + win + for win, dims in window_dim.items() + if len(dims) != 2 or isinstance(dims, str) + ) + if bad_new_dims: + raise ValueError( + f"Please provide exactly two dimension names for the following coarsening dimensions: {bad_new_dims}" + ) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + missing_dims = set(window_dim) - set(self.windows) + if missing_dims: + raise ValueError( + f"'window_dim' must contain entries for all dimensions to coarsen. Missing {missing_dims}" + ) + extra_windows = set(self.windows) - set(window_dim) + if extra_windows: + raise ValueError( + f"'window_dim' includes dimensions that will not be coarsened: {extra_windows}" + ) + + reshaped = Dataset() + if isinstance(self.obj, DataArray): + obj = self.obj._to_temp_dataset() + else: + obj = self.obj + + reshaped.attrs = obj.attrs if keep_attrs else {} + + for key, var in obj.variables.items(): + reshaped_dims = tuple( + itertools.chain(*[window_dim.get(dim, [dim]) for dim in list(var.dims)]) + ) + if reshaped_dims != var.dims: + windows = {w: self.windows[w] for w in window_dim if w in var.dims} + reshaped_var, _ = var.coarsen_reshape(windows, self.boundary, self.side) + attrs = var.attrs if keep_attrs else {} + reshaped[key] = (reshaped_dims, reshaped_var, attrs) + else: + reshaped[key] = var + + # should handle window_dim being unindexed + should_be_coords = (set(window_dim) & set(self.obj.coords)) | set( + self.obj.coords + ) + result = reshaped.set_coords(should_be_coords) + if isinstance(self.obj, DataArray): + return self.obj._from_temp_dataset(result) + else: + return result + + +class DataArrayCoarsen(Coarsen["DataArray"]): + __slots__ = () + + _reduce_extra_args_docstring = """""" + + @classmethod + def _reduce_method( + cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False + ) -> Callable[..., DataArray]: + """ + Return a wrapped function for injecting reduction methods. + see ops.inject_reduce_methods + """ + kwargs: dict[str, Any] = {} + if include_skipna: + kwargs["skipna"] = None + + def wrapped_func( + self: DataArrayCoarsen, keep_attrs: bool | None = None, **kwargs + ) -> DataArray: + from xarray.core.dataarray import DataArray + + keep_attrs = self._get_keep_attrs(keep_attrs) + + reduced = self.obj.variable.coarsen( + self.windows, func, self.boundary, self.side, keep_attrs, **kwargs + ) + coords = {} + for c, v in self.obj.coords.items(): + if c == self.obj.name: + coords[c] = reduced + else: + if any(d in self.windows for d in v.dims): + coords[c] = v.variable.coarsen( + self.windows, + self.coord_func[c], + self.boundary, + self.side, + keep_attrs, + **kwargs, + ) + else: + coords[c] = v + return DataArray( + reduced, dims=self.obj.dims, coords=coords, name=self.obj.name + ) + + return wrapped_func + + def reduce( + self, func: Callable, keep_attrs: bool | None = None, **kwargs + ) -> DataArray: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, axis, **kwargs)` + to return the result of collapsing an np.ndarray over the coarsening + dimensions. It must be possible to provide the `axis` argument + with a tuple of integers. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + Array with summarized data. + + Examples + -------- + >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b")) + >>> coarsen = da.coarsen(b=2) + >>> coarsen.reduce(np.sum) + Size: 32B + array([[ 1, 5], + [ 9, 13]]) + Dimensions without coordinates: a, b + """ + wrapped_func = self._reduce_method(func) + return wrapped_func(self, keep_attrs=keep_attrs, **kwargs) + + +class DatasetCoarsen(Coarsen["Dataset"]): + __slots__ = () + + _reduce_extra_args_docstring = """""" + + @classmethod + def _reduce_method( + cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False + ) -> Callable[..., Dataset]: + """ + Return a wrapped function for injecting reduction methods. + see ops.inject_reduce_methods + """ + kwargs: dict[str, Any] = {} + if include_skipna: + kwargs["skipna"] = None + + def wrapped_func( + self: DatasetCoarsen, keep_attrs: bool | None = None, **kwargs + ) -> Dataset: + from xarray.core.dataset import Dataset + + keep_attrs = self._get_keep_attrs(keep_attrs) + + if keep_attrs: + attrs = self.obj.attrs + else: + attrs = {} + + reduced = {} + for key, da in self.obj.data_vars.items(): + reduced[key] = da.variable.coarsen( + self.windows, + func, + self.boundary, + self.side, + keep_attrs=keep_attrs, + **kwargs, + ) + + coords = {} + for c, v in self.obj.coords.items(): + # variable.coarsen returns variables not containing the window dims + # unchanged (maybe removes attrs) + coords[c] = v.variable.coarsen( + self.windows, + self.coord_func[c], + self.boundary, + self.side, + keep_attrs=keep_attrs, + **kwargs, + ) + + return Dataset(reduced, coords=coords, attrs=attrs) + + return wrapped_func + + def reduce(self, func: Callable, keep_attrs=None, **kwargs) -> Dataset: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, axis, **kwargs)` + to return the result of collapsing an np.ndarray over the coarsening + dimensions. It must be possible to provide the `axis` argument with + a tuple of integers. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Dataset + Arrays with summarized data. + """ + wrapped_func = self._reduce_method(func) + return wrapped_func(self, keep_attrs=keep_attrs, **kwargs) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/rolling_exp.py b/test/fixtures/whole_applications/xarray/xarray/core/rolling_exp.py new file mode 100644 index 0000000..4e085a0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/rolling_exp.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Generic + +import numpy as np +from packaging.version import Version + +from xarray.core.computation import apply_ufunc +from xarray.core.options import _get_keep_attrs +from xarray.core.pdcompat import count_not_none +from xarray.core.types import T_DataWithCoords +from xarray.core.utils import module_available +from xarray.namedarray import pycompat + + +def _get_alpha( + com: float | None = None, + span: float | None = None, + halflife: float | None = None, + alpha: float | None = None, +) -> float: + """ + Convert com, span, halflife to alpha. + """ + valid_count = count_not_none(com, span, halflife, alpha) + if valid_count > 1: + raise ValueError("com, span, halflife, and alpha are mutually exclusive") + + # Convert to alpha + if com is not None: + if com < 0: + raise ValueError("commust satisfy: com>= 0") + return 1 / (com + 1) + elif span is not None: + if span < 1: + raise ValueError("span must satisfy: span >= 1") + return 2 / (span + 1) + elif halflife is not None: + if halflife <= 0: + raise ValueError("halflife must satisfy: halflife > 0") + return 1 - np.exp(np.log(0.5) / halflife) + elif alpha is not None: + if not 0 < alpha <= 1: + raise ValueError("alpha must satisfy: 0 < alpha <= 1") + return alpha + else: + raise ValueError("Must pass one of comass, span, halflife, or alpha") + + +class RollingExp(Generic[T_DataWithCoords]): + """ + Exponentially-weighted moving window object. + Similar to EWM in pandas + + Parameters + ---------- + obj : Dataset or DataArray + Object to window. + windows : mapping of hashable to int (or float for alpha type) + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + window_type : {"span", "com", "halflife", "alpha"}, default: "span" + The format of the previously supplied window. Each is a simple + numerical transformation of the others. Described in detail: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html + + Returns + ------- + RollingExp : type of input argument + """ + + def __init__( + self, + obj: T_DataWithCoords, + windows: Mapping[Any, int | float], + window_type: str = "span", + min_weight: float = 0.0, + ): + if not module_available("numbagg"): + raise ImportError( + "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed" + ) + elif pycompat.mod_version("numbagg") < Version("0.2.1"): + raise ImportError( + f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed" + ) + elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0: + raise ImportError( + f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed" + ) + + self.obj: T_DataWithCoords = obj + dim, window = next(iter(windows.items())) + self.dim = dim + self.alpha = _get_alpha(**{window_type: window}) + self.min_weight = min_weight + # Don't pass min_weight=0 so we can support older versions of numbagg + kwargs = dict(alpha=self.alpha, axis=-1) + if min_weight > 0: + kwargs["min_weight"] = min_weight + self.kwargs = kwargs + + def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: + """ + Exponentially weighted moving average. + + Parameters + ---------- + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").mean() + Size: 40B + array([1. , 1. , 1.69230769, 1.9 , 1.96694215]) + Dimensions without coordinates: x + """ + + import numbagg + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nanmean, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=keep_attrs, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: + """ + Exponentially weighted moving sum. + + Parameters + ---------- + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").sum() + Size: 40B + array([1. , 1.33333333, 2.44444444, 2.81481481, 2.9382716 ]) + Dimensions without coordinates: x + """ + + import numbagg + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nansum, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=keep_attrs, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def std(self) -> T_DataWithCoords: + """ + Exponentially weighted moving standard deviation. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").std() + Size: 40B + array([ nan, 0. , 0.67936622, 0.42966892, 0.25389527]) + Dimensions without coordinates: x + """ + + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" + ) + import numbagg + + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nanstd, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def var(self) -> T_DataWithCoords: + """ + Exponentially weighted moving variance. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").var() + Size: 40B + array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281]) + Dimensions without coordinates: x + """ + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {pycompat.mod_version('numbagg')} is installed" + ) + dim_order = self.obj.dims + import numbagg + + return apply_ufunc( + numbagg.move_exp_nanvar, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def cov(self, other: T_DataWithCoords) -> T_DataWithCoords: + """ + Exponentially weighted moving covariance. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").cov(da**2) + Size: 40B + array([ nan, 0. , 1.38461538, 0.55384615, 0.19338843]) + Dimensions without coordinates: x + """ + + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {pycompat.mod_version('numbagg')} is installed" + ) + dim_order = self.obj.dims + import numbagg + + return apply_ufunc( + numbagg.move_exp_nancov, + self.obj, + other, + input_core_dims=[[self.dim], [self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def corr(self, other: T_DataWithCoords) -> T_DataWithCoords: + """ + Exponentially weighted moving correlation. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").corr(da.shift(x=1)) + Size: 40B + array([ nan, nan, nan, 0.4330127 , 0.48038446]) + Dimensions without coordinates: x + """ + + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {pycompat.mod_version('numbagg')} is installed" + ) + dim_order = self.obj.dims + import numbagg + + return apply_ufunc( + numbagg.move_exp_nancorr, + self.obj, + other, + input_core_dims=[[self.dim], [self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/treenode.py b/test/fixtures/whole_applications/xarray/xarray/core/treenode.py new file mode 100644 index 0000000..6f51e1f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/treenode.py @@ -0,0 +1,679 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterator, Mapping +from pathlib import PurePosixPath +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, +) + +from xarray.core.utils import Frozen, is_dict_like + +if TYPE_CHECKING: + from xarray.core.types import T_DataArray + + +class InvalidTreeError(Exception): + """Raised when user attempts to create an invalid tree in some way.""" + + +class NotFoundInTreeError(ValueError): + """Raised when operation can't be completed because one node is not part of the expected tree.""" + + +class NodePath(PurePosixPath): + """Represents a path from one node to another within a tree.""" + + def __init__(self, *pathsegments): + if sys.version_info >= (3, 12): + super().__init__(*pathsegments) + else: + super().__new__(PurePosixPath, *pathsegments) + if self.drive: + raise ValueError("NodePaths cannot have drives") + + if self.root not in ["/", ""]: + raise ValueError( + 'Root of NodePath can only be either "/" or "", with "" meaning the path is relative.' + ) + # TODO should we also forbid suffixes to avoid node names with dots in them? + + +Tree = TypeVar("Tree", bound="TreeNode") + + +class TreeNode(Generic[Tree]): + """ + Base class representing a node of a tree, with methods for traversing and altering the tree. + + This class stores no data, it has only parents and children attributes, and various methods. + + Stores child nodes in an dict, ensuring that equality checks between trees + and order of child nodes is preserved (since python 3.7). + + Nodes themselves are intrinsically unnamed (do not possess a ._name attribute), but if the node has a parent you can + find the key it is stored under via the .name property. + + The .parent attribute is read-only: to replace the parent using public API you must set this node as the child of a + new parent using `new_parent.children[name] = child_node`, or to instead detach from the current parent use + `child_node.orphan()`. + + This class is intended to be subclassed by DataTree, which will overwrite some of the inherited behaviour, + in particular to make names an inherent attribute, and allow setting parents directly. The intention is to mirror + the class structure of xarray.Variable & xarray.DataArray, where Variable is unnamed but DataArray is (optionally) + named. + + Also allows access to any other node in the tree via unix-like paths, including upwards referencing via '../'. + + (This class is heavily inspired by the anytree library's NodeMixin class.) + + """ + + _parent: Tree | None + _children: dict[str, Tree] + + def __init__(self, children: Mapping[str, Tree] | None = None): + """Create a parentless node.""" + self._parent = None + self._children = {} + if children is not None: + self.children = children + + @property + def parent(self) -> Tree | None: + """Parent of this node.""" + return self._parent + + def _set_parent( + self, new_parent: Tree | None, child_name: str | None = None + ) -> None: + # TODO is it possible to refactor in a way that removes this private method? + + if new_parent is not None and not isinstance(new_parent, TreeNode): + raise TypeError( + "Parent nodes must be of type DataTree or None, " + f"not type {type(new_parent)}" + ) + + old_parent = self._parent + if new_parent is not old_parent: + self._check_loop(new_parent) + self._detach(old_parent) + self._attach(new_parent, child_name) + + def _check_loop(self, new_parent: Tree | None) -> None: + """Checks that assignment of this new parent will not create a cycle.""" + if new_parent is not None: + if new_parent is self: + raise InvalidTreeError( + f"Cannot set parent, as node {self} cannot be a parent of itself." + ) + + if self._is_descendant_of(new_parent): + raise InvalidTreeError( + "Cannot set parent, as intended parent is already a descendant of this node." + ) + + def _is_descendant_of(self, node: Tree) -> bool: + return any(n is self for n in node.parents) + + def _detach(self, parent: Tree | None) -> None: + if parent is not None: + self._pre_detach(parent) + parents_children = parent.children + parent._children = { + name: child + for name, child in parents_children.items() + if child is not self + } + self._parent = None + self._post_detach(parent) + + def _attach(self, parent: Tree | None, child_name: str | None = None) -> None: + if parent is not None: + if child_name is None: + raise ValueError( + "To directly set parent, child needs a name, but child is unnamed" + ) + + self._pre_attach(parent) + parentchildren = parent._children + assert not any( + child is self for child in parentchildren + ), "Tree is corrupt." + parentchildren[child_name] = self + self._parent = parent + self._post_attach(parent) + else: + self._parent = None + + def orphan(self) -> None: + """Detach this node from its parent.""" + self._set_parent(new_parent=None) + + @property + def children(self: Tree) -> Mapping[str, Tree]: + """Child nodes of this node, stored under a mapping via their names.""" + return Frozen(self._children) + + @children.setter + def children(self: Tree, children: Mapping[str, Tree]) -> None: + self._check_children(children) + children = {**children} + + old_children = self.children + del self.children + try: + self._pre_attach_children(children) + for name, child in children.items(): + child._set_parent(new_parent=self, child_name=name) + self._post_attach_children(children) + assert len(self.children) == len(children) + except Exception: + # if something goes wrong then revert to previous children + self.children = old_children + raise + + @children.deleter + def children(self) -> None: + # TODO this just detaches all the children, it doesn't actually delete them... + children = self.children + self._pre_detach_children(children) + for child in self.children.values(): + child.orphan() + assert len(self.children) == 0 + self._post_detach_children(children) + + @staticmethod + def _check_children(children: Mapping[str, Tree]) -> None: + """Check children for correct types and for any duplicates.""" + if not is_dict_like(children): + raise TypeError( + "children must be a dict-like mapping from names to node objects" + ) + + seen = set() + for name, child in children.items(): + if not isinstance(child, TreeNode): + raise TypeError( + f"Cannot add object {name}. It is of type {type(child)}, " + "but can only add children of type DataTree" + ) + + childid = id(child) + if childid not in seen: + seen.add(childid) + else: + raise InvalidTreeError( + f"Cannot add same node {name} multiple times as different children." + ) + + def __repr__(self) -> str: + return f"TreeNode(children={dict(self._children)})" + + def _pre_detach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call before detaching `children`.""" + pass + + def _post_detach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call after detaching `children`.""" + pass + + def _pre_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call before attaching `children`.""" + pass + + def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call after attaching `children`.""" + pass + + def _iter_parents(self: Tree) -> Iterator[Tree]: + """Iterate up the tree, starting from the current node's parent.""" + node: Tree | None = self.parent + while node is not None: + yield node + node = node.parent + + def iter_lineage(self: Tree) -> tuple[Tree, ...]: + """Iterate up the tree, starting from the current node.""" + from warnings import warn + + warn( + "`iter_lineage` has been deprecated, and in the future will raise an error." + "Please use `parents` from now on.", + DeprecationWarning, + ) + return tuple((self, *self.parents)) + + @property + def lineage(self: Tree) -> tuple[Tree, ...]: + """All parent nodes and their parent nodes, starting with the closest.""" + from warnings import warn + + warn( + "`lineage` has been deprecated, and in the future will raise an error." + "Please use `parents` from now on.", + DeprecationWarning, + ) + return self.iter_lineage() + + @property + def parents(self: Tree) -> tuple[Tree, ...]: + """All parent nodes and their parent nodes, starting with the closest.""" + return tuple(self._iter_parents()) + + @property + def ancestors(self: Tree) -> tuple[Tree, ...]: + """All parent nodes and their parent nodes, starting with the most distant.""" + + from warnings import warn + + warn( + "`ancestors` has been deprecated, and in the future will raise an error." + "Please use `parents`. Example: `tuple(reversed(node.parents))`", + DeprecationWarning, + ) + return tuple((*reversed(self.parents), self)) + + @property + def root(self: Tree) -> Tree: + """Root node of the tree""" + node = self + while node.parent is not None: + node = node.parent + return node + + @property + def is_root(self) -> bool: + """Whether this node is the tree root.""" + return self.parent is None + + @property + def is_leaf(self) -> bool: + """ + Whether this node is a leaf node. + + Leaf nodes are defined as nodes which have no children. + """ + return self.children == {} + + @property + def leaves(self: Tree) -> tuple[Tree, ...]: + """ + All leaf nodes. + + Leaf nodes are defined as nodes which have no children. + """ + return tuple([node for node in self.subtree if node.is_leaf]) + + @property + def siblings(self: Tree) -> dict[str, Tree]: + """ + Nodes with the same parent as this node. + """ + if self.parent: + return { + name: child + for name, child in self.parent.children.items() + if child is not self + } + else: + return {} + + @property + def subtree(self: Tree) -> Iterator[Tree]: + """ + An iterator over all nodes in this tree, including both self and all descendants. + + Iterates depth-first. + + See Also + -------- + DataTree.descendants + """ + from xarray.core.iterators import LevelOrderIter + + return LevelOrderIter(self) + + @property + def descendants(self: Tree) -> tuple[Tree, ...]: + """ + Child nodes and all their child nodes. + + Returned in depth-first order. + + See Also + -------- + DataTree.subtree + """ + all_nodes = tuple(self.subtree) + this_node, *descendants = all_nodes + return tuple(descendants) + + @property + def level(self: Tree) -> int: + """ + Level of this node. + + Level means number of parent nodes above this node before reaching the root. + The root node is at level 0. + + Returns + ------- + level : int + + See Also + -------- + depth + width + """ + return len(self.parents) + + @property + def depth(self: Tree) -> int: + """ + Maximum level of this tree. + + Measured from the root, which has a depth of 0. + + Returns + ------- + depth : int + + See Also + -------- + level + width + """ + return max(node.level for node in self.root.subtree) + + @property + def width(self: Tree) -> int: + """ + Number of nodes at this level in the tree. + + Includes number of immediate siblings, but also "cousins" in other branches and so-on. + + Returns + ------- + depth : int + + See Also + -------- + level + depth + """ + return len([node for node in self.root.subtree if node.level == self.level]) + + def _pre_detach(self: Tree, parent: Tree) -> None: + """Method call before detaching from `parent`.""" + pass + + def _post_detach(self: Tree, parent: Tree) -> None: + """Method call after detaching from `parent`.""" + pass + + def _pre_attach(self: Tree, parent: Tree) -> None: + """Method call before attaching to `parent`.""" + pass + + def _post_attach(self: Tree, parent: Tree) -> None: + """Method call after attaching to `parent`.""" + pass + + def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None: + """ + Return the child node with the specified key. + + Only looks for the node within the immediate children of this node, + not in other nodes of the tree. + """ + if key in self.children: + return self.children[key] + else: + return default + + # TODO `._walk` method to be called by both `_get_item` and `_set_item` + + def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: + """ + Returns the object lying at the given path. + + Raises a KeyError if there is no object at the given path. + """ + if isinstance(path, str): + path = NodePath(path) + + if path.root: + current_node = self.root + root, *parts = list(path.parts) + else: + current_node = self + parts = list(path.parts) + + for part in parts: + if part == "..": + if current_node.parent is None: + raise KeyError(f"Could not find node at {path}") + else: + current_node = current_node.parent + elif part in ("", "."): + pass + else: + if current_node.get(part) is None: + raise KeyError(f"Could not find node at {path}") + else: + current_node = current_node.get(part) + return current_node + + def _set(self: Tree, key: str, val: Tree) -> None: + """ + Set the child node with the specified key to value. + + Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree. + """ + new_children = {**self.children, key: val} + self.children = new_children + + def _set_item( + self: Tree, + path: str | NodePath, + item: Tree | T_DataArray, + new_nodes_along_path: bool = False, + allow_overwrite: bool = True, + ) -> None: + """ + Set a new item in the tree, overwriting anything already present at that path. + + The given value either forms a new node of the tree or overwrites an + existing item at that location. + + Parameters + ---------- + path + item + new_nodes_along_path : bool + If true, then if necessary new nodes will be created along the + given path, until the tree can reach the specified location. + allow_overwrite : bool + Whether or not to overwrite any existing node at the location given + by path. + + Raises + ------ + KeyError + If node cannot be reached, and new_nodes_along_path=False. + Or if a node already exists at the specified path, and allow_overwrite=False. + """ + if isinstance(path, str): + path = NodePath(path) + + if not path.name: + raise ValueError("Can't set an item under a path which has no name") + + if path.root: + # absolute path + current_node = self.root + root, *parts, name = path.parts + else: + # relative path + current_node = self + *parts, name = path.parts + + if parts: + # Walk to location of new node, creating intermediate node objects as we go if necessary + for part in parts: + if part == "..": + if current_node.parent is None: + # We can't create a parent if `new_nodes_along_path=True` as we wouldn't know what to name it + raise KeyError(f"Could not reach node at path {path}") + else: + current_node = current_node.parent + elif part in ("", "."): + pass + else: + if part in current_node.children: + current_node = current_node.children[part] + elif new_nodes_along_path: + # Want child classes (i.e. DataTree) to populate tree with their own types + new_node = type(self)() + current_node._set(part, new_node) + current_node = current_node.children[part] + else: + raise KeyError(f"Could not reach node at path {path}") + + if name in current_node.children: + # Deal with anything already existing at this location + if allow_overwrite: + current_node._set(name, item) + else: + raise KeyError(f"Already a node object at path {path}") + else: + current_node._set(name, item) + + def __delitem__(self: Tree, key: str): + """Remove a child node from this tree object.""" + if key in self.children: + child = self._children[key] + del self._children[key] + child.orphan() + else: + raise KeyError("Cannot delete") + + def same_tree(self, other: Tree) -> bool: + """True if other node is in the same tree as this node.""" + return self.root is other.root + + +class NamedNode(TreeNode, Generic[Tree]): + """ + A TreeNode which knows its own name. + + Implements path-like relationships to other nodes in its tree. + """ + + _name: str | None + _parent: Tree | None + _children: dict[str, Tree] + + def __init__(self, name=None, children=None): + super().__init__(children=children) + self._name = None + self.name = name + + @property + def name(self) -> str | None: + """The name of this node.""" + return self._name + + @name.setter + def name(self, name: str | None) -> None: + if name is not None: + if not isinstance(name, str): + raise TypeError("node name must be a string or None") + if "/" in name: + raise ValueError("node names cannot contain forward slashes") + self._name = name + + def __repr__(self, level=0): + repr_value = "\t" * level + self.__str__() + "\n" + for child in self.children: + repr_value += self.get(child).__repr__(level + 1) + return repr_value + + def __str__(self) -> str: + return f"NamedNode('{self.name}')" if self.name else "NamedNode()" + + def _post_attach(self: NamedNode, parent: NamedNode) -> None: + """Ensures child has name attribute corresponding to key under which it has been stored.""" + key = next(k for k, v in parent.children.items() if v is self) + self.name = key + + @property + def path(self) -> str: + """Return the file-like path from the root to this node.""" + if self.is_root: + return "/" + else: + root, *ancestors = tuple(reversed(self.parents)) + # don't include name of root because (a) root might not have a name & (b) we want path relative to root. + names = [*(node.name for node in ancestors), self.name] + return "/" + "/".join(names) + + def relative_to(self: NamedNode, other: NamedNode) -> str: + """ + Compute the relative path from this node to node `other`. + + If other is not in this tree, or it's otherwise impossible, raise a ValueError. + """ + if not self.same_tree(other): + raise NotFoundInTreeError( + "Cannot find relative path because nodes do not lie within the same tree" + ) + + this_path = NodePath(self.path) + if other.path in list(parent.path for parent in (self, *self.parents)): + return str(this_path.relative_to(other.path)) + else: + common_ancestor = self.find_common_ancestor(other) + path_to_common_ancestor = other._path_to_ancestor(common_ancestor) + return str( + path_to_common_ancestor / this_path.relative_to(common_ancestor.path) + ) + + def find_common_ancestor(self, other: NamedNode) -> NamedNode: + """ + Find the first common ancestor of two nodes in the same tree. + + Raise ValueError if they are not in the same tree. + """ + if self is other: + return self + + other_paths = [op.path for op in other.parents] + for parent in (self, *self.parents): + if parent.path in other_paths: + return parent + + raise NotFoundInTreeError( + "Cannot find common ancestor because nodes do not lie within the same tree" + ) + + def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath: + """Return the relative path from this node to the given ancestor node""" + + if not self.same_tree(ancestor): + raise NotFoundInTreeError( + "Cannot find relative path to ancestor because nodes do not lie within the same tree" + ) + if ancestor.path not in list(a.path for a in (self, *self.parents)): + raise NotFoundInTreeError( + "Cannot find relative path to ancestor because given node is not an ancestor of this node" + ) + + parents_paths = list(parent.path for parent in (self, *self.parents)) + generation_gap = list(parents_paths).index(ancestor.path) + path_upwards = "../" * generation_gap if generation_gap > 0 else "." + return NodePath(path_upwards) diff --git a/test/fixtures/whole_applications/xarray/xarray/core/types.py b/test/fixtures/whole_applications/xarray/xarray/core/types.py new file mode 100644 index 0000000..41078d2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/types.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import datetime +import sys +from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Protocol, + SupportsIndex, + TypeVar, + Union, +) + +import numpy as np +import pandas as pd + +try: + if sys.version_info >= (3, 11): + from typing import Self, TypeAlias + else: + from typing_extensions import Self, TypeAlias +except ImportError: + if TYPE_CHECKING: + raise + else: + Self: Any = None + +if TYPE_CHECKING: + from numpy._typing import _SupportsDType + from numpy.typing import ArrayLike + + from xarray.backends.common import BackendEntrypoint + from xarray.core.alignment import Aligner + from xarray.core.common import AbstractArray, DataWithCoords + from xarray.core.coordinates import Coordinates + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.indexes import Index, Indexes + from xarray.core.utils import Frozen + from xarray.core.variable import Variable + + try: + from dask.array import Array as DaskArray + except ImportError: + DaskArray = np.ndarray # type: ignore + + try: + from cubed import Array as CubedArray + except ImportError: + CubedArray = np.ndarray + + try: + from zarr.core import Array as ZarrArray + except ImportError: + ZarrArray = np.ndarray + + # Anything that can be coerced to a shape tuple + _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] + _DTypeLikeNested = Any # TODO: wait for support for recursive types + + # Xarray requires a Mapping[Hashable, dtype] in many places which + # conflicts with numpys own DTypeLike (with dtypes for fields). + # https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike + # This is a copy of this DTypeLike that allows only non-Mapping dtypes. + DTypeLikeSave = Union[ + np.dtype[Any], + # default data type (float64) + None, + # array-scalar types and generic types + type[Any], + # character codes, type strings or comma-separated fields, e.g., 'float64' + str, + # (flexible_dtype, itemsize) + tuple[_DTypeLikeNested, int], + # (fixed_dtype, shape) + tuple[_DTypeLikeNested, _ShapeLike], + # (base_dtype, new_dtype) + tuple[_DTypeLikeNested, _DTypeLikeNested], + # because numpy does the same? + list[Any], + # anything with a dtype attribute + _SupportsDType[np.dtype[Any]], + ] + try: + from cftime import datetime as CFTimeDatetime + except ImportError: + CFTimeDatetime = Any + DatetimeLike = Union[pd.Timestamp, datetime.datetime, np.datetime64, CFTimeDatetime] +else: + DTypeLikeSave: Any = None + + +class Alignable(Protocol): + """Represents any Xarray type that supports alignment. + + It may be ``Dataset``, ``DataArray`` or ``Coordinates``. This protocol class + is needed since those types do not all have a common base class. + + """ + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: ... + + @property + def sizes(self) -> Mapping[Hashable, int]: ... + + @property + def xindexes(self) -> Indexes[Index]: ... + + def _reindex_callback( + self, + aligner: Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: ... + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + ) -> Self: ... + + def __len__(self) -> int: ... + + def __iter__(self) -> Iterator[Hashable]: ... + + def copy( + self, + deep: bool = False, + ) -> Self: ... + + +T_Alignable = TypeVar("T_Alignable", bound="Alignable") + +T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint") +T_Dataset = TypeVar("T_Dataset", bound="Dataset") +T_DataArray = TypeVar("T_DataArray", bound="DataArray") +T_Variable = TypeVar("T_Variable", bound="Variable") +T_Coordinates = TypeVar("T_Coordinates", bound="Coordinates") +T_Array = TypeVar("T_Array", bound="AbstractArray") +T_Index = TypeVar("T_Index", bound="Index") + +# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used +# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of +# the same concrete type, either "DataArray" or "Dataset". This is generally preferred +# over `T_DataArrayOrSet`, given the type system can determine the exact type. +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") + +# `T_DataArrayOrSet` is a type variable that is bounded to either "DataArray" or +# "Dataset". Use it for functions that might return either type, but where the exact +# type cannot be determined statically using the type system. +T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"]) + +# For working directly with `DataWithCoords`. It will only allow using methods defined +# on `DataWithCoords`. +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + + +# Temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more: +T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True) + +# For typing pandas extension arrays. +T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) + + +ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +VarCompatible = Union["Variable", "ScalarOrArray"] +DaCompatible = Union["DataArray", "VarCompatible"] +DsCompatible = Union["Dataset", "DaCompatible"] +GroupByCompatible = Union["Dataset", "DataArray"] + +# Don't change to Hashable | Collection[Hashable] +# Read: https://github.com/pydata/xarray/issues/6142 +Dims = Union[str, Collection[Hashable], "ellipsis", None] + +# FYI in some cases we don't allow `None`, which this doesn't take account of. +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +# We allow the tuple form of this (though arguably we could transition to named dims only) +T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] +T_NormalizedChunks = tuple[tuple[int, ...], ...] + +DataVars = Mapping[Any, Any] + + +ErrorOptions = Literal["raise", "ignore"] +ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] + +CompatOptions = Literal[ + "identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal" +] +ConcatOptions = Literal["all", "minimal", "different"] +CombineAttrsOptions = Union[ + Literal["drop", "identical", "no_conflicts", "drop_conflicts", "override"], + Callable[..., Any], +] +JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"] + +Interp1dOptions = Literal[ + "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial" +] +InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"] +InterpOptions = Union[Interp1dOptions, InterpolantOptions] + +DatetimeUnitOptions = Literal[ + "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", None +] + +QueryEngineOptions = Literal["python", "numexpr", None] +QueryParserOptions = Literal["pandas", "python"] + +ReindexMethodOptions = Literal["nearest", "pad", "ffill", "backfill", "bfill", None] + +PadModeOptions = Literal[ + "constant", + "edge", + "linear_ramp", + "maximum", + "mean", + "median", + "minimum", + "reflect", + "symmetric", + "wrap", +] +PadReflectOptions = Literal["even", "odd", None] + +CFCalendar = Literal[ + "standard", + "gregorian", + "proleptic_gregorian", + "noleap", + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", +] + +CoarsenBoundaryOptions = Literal["exact", "trim", "pad"] +SideOptions = Literal["left", "right"] +InclusiveOptions = Literal["both", "neither", "left", "right"] + +ScaleOptions = Literal["linear", "symlog", "log", "logit", None] +HueStyleOptions = Literal["continuous", "discrete", None] +AspectOptions = Union[Literal["auto", "equal"], float, None] +ExtendOptions = Literal["neither", "both", "min", "max", None] + +# TODO: Wait until mypy supports recursive objects in combination with typevars +_T = TypeVar("_T") +NestedSequence = Union[ + _T, + Sequence[_T], + Sequence[Sequence[_T]], + Sequence[Sequence[Sequence[_T]]], + Sequence[Sequence[Sequence[Sequence[_T]]]], +] + + +QuantileMethods = Literal[ + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "linear", + "median_unbiased", + "normal_unbiased", + "lower", + "higher", + "midpoint", + "nearest", +] + + +NetcdfWriteModes = Literal["w", "a"] +ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] diff --git a/test/fixtures/whole_applications/xarray/xarray/core/utils.py b/test/fixtures/whole_applications/xarray/xarray/core/utils.py new file mode 100644 index 0000000..5cb52cb --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/utils.py @@ -0,0 +1,1175 @@ +"""Internal utilities; not for external use""" + +# Some functions in this module are derived from functions in pandas. For +# reference, here is a copy of the pandas copyright notice: + +# BSD 3-Clause License + +# Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2011-2022, Open source contributors. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import contextlib +import functools +import inspect +import io +import itertools +import math +import os +import re +import sys +import warnings +from collections.abc import ( + Collection, + Container, + Hashable, + ItemsView, + Iterable, + Iterator, + KeysView, + Mapping, + MutableMapping, + MutableSet, + ValuesView, +) +from enum import Enum +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + TypeVar, + overload, +) + +import numpy as np +import pandas as pd + +from xarray.namedarray.utils import ( # noqa: F401 + ReprObject, + drop_missing_dims, + either_dict_or_kwargs, + infix_dims, + is_dask_collection, + is_dict_like, + is_duck_array, + is_duck_dask_array, + module_available, + to_0d_object_array, +) + +if TYPE_CHECKING: + from xarray.core.types import Dims, ErrorOptionsWithWarn + +K = TypeVar("K") +V = TypeVar("V") +T = TypeVar("T") + + +def alias_message(old_name: str, new_name: str) -> str: + return f"{old_name} has been deprecated. Use {new_name} instead." + + +def alias_warning(old_name: str, new_name: str, stacklevel: int = 3) -> None: + warnings.warn( + alias_message(old_name, new_name), FutureWarning, stacklevel=stacklevel + ) + + +def alias(obj: Callable[..., T], old_name: str) -> Callable[..., T]: + assert isinstance(old_name, str) + + @functools.wraps(obj) + def wrapper(*args, **kwargs): + alias_warning(old_name, obj.__name__) + return obj(*args, **kwargs) + + wrapper.__doc__ = alias_message(old_name, obj.__name__) + return wrapper + + +def get_valid_numpy_dtype(array: np.ndarray | pd.Index): + """Return a numpy compatible dtype from either + a numpy array or a pandas.Index. + + Used for wrapping a pandas.Index as an xarray,Variable. + + """ + if isinstance(array, pd.PeriodIndex): + dtype = np.dtype("O") + elif hasattr(array, "categories"): + # category isn't a real numpy dtype + dtype = array.categories.dtype + if not is_valid_numpy_dtype(dtype): + dtype = np.dtype("O") + elif not is_valid_numpy_dtype(array.dtype): + dtype = np.dtype("O") + else: + dtype = array.dtype + + return dtype + + +def maybe_coerce_to_str(index, original_coords): + """maybe coerce a pandas Index back to a nunpy array of type str + + pd.Index uses object-dtype to store str - try to avoid this for coords + """ + from xarray.core import dtypes + + try: + result_type = dtypes.result_type(*original_coords) + except TypeError: + pass + else: + if result_type.kind in "SU": + index = np.asarray(index, dtype=result_type.type) + + return index + + +def maybe_wrap_array(original, new_array): + """Wrap a transformed array with __array_wrap__ if it can be done safely. + + This lets us treat arbitrary functions that take and return ndarray objects + like ufuncs, as long as they return an array with the same shape. + """ + # in case func lost array's metadata + if isinstance(new_array, np.ndarray) and new_array.shape == original.shape: + return original.__array_wrap__(new_array) + else: + return new_array + + +def equivalent(first: T, second: T) -> bool: + """Compare two objects for equivalence (identity or equality), using + array_equiv if either object is an ndarray. If both objects are lists, + equivalent is sequentially called on all the elements. + """ + # TODO: refactor to avoid circular import + from xarray.core import duck_array_ops + + if first is second: + return True + if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): + return duck_array_ops.array_equiv(first, second) + if isinstance(first, list) or isinstance(second, list): + return list_equiv(first, second) + return (first == second) or (pd.isnull(first) and pd.isnull(second)) + + +def list_equiv(first, second): + equiv = True + if len(first) != len(second): + return False + else: + for f, s in zip(first, second): + equiv = equiv and equivalent(f, s) + return equiv + + +def peek_at(iterable: Iterable[T]) -> tuple[T, Iterator[T]]: + """Returns the first value from iterable, as well as a new iterator with + the same content as the original iterable + """ + gen = iter(iterable) + peek = next(gen) + return peek, itertools.chain([peek], gen) + + +def update_safety_check( + first_dict: Mapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> None: + """Check the safety of updating one dictionary with another. + + Raises ValueError if dictionaries have non-compatible values for any key, + where compatibility is determined by identity (they are the same item) or + the `compat` function. + + Parameters + ---------- + first_dict, second_dict : dict-like + All items in the second dictionary are checked against for conflicts + against items in the first dictionary. + compat : function, optional + Binary operator to determine if two values are compatible. By default, + checks for equivalence. + """ + for k, v in second_dict.items(): + if k in first_dict and not compat(v, first_dict[k]): + raise ValueError( + "unsafe to merge dictionaries without " + f"overriding values; conflicting key {k!r}" + ) + + +def remove_incompatible_items( + first_dict: MutableMapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> None: + """Remove incompatible items from the first dictionary in-place. + + Items are retained if their keys are found in both dictionaries and the + values are compatible. + + Parameters + ---------- + first_dict, second_dict : dict-like + Mappings to merge. + compat : function, optional + Binary operator to determine if two values are compatible. By default, + checks for equivalence. + """ + for k in list(first_dict): + if k not in second_dict or not compat(first_dict[k], second_dict[k]): + del first_dict[k] + + +def is_full_slice(value: Any) -> bool: + return isinstance(value, slice) and value == slice(None) + + +def is_list_like(value: Any) -> TypeGuard[list | tuple]: + return isinstance(value, (list, tuple)) + + +def _is_scalar(value, include_0d): + from xarray.core.variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES + + if include_0d: + include_0d = getattr(value, "ndim", None) == 0 + return ( + include_0d + or isinstance(value, (str, bytes)) + or not ( + isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) + or hasattr(value, "__array_function__") + or hasattr(value, "__array_namespace__") + ) + ) + + +# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without +# requiring typing_extensions as a required dependency to _run_ the code (it is required +# to type-check). +try: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard +except ImportError: + if TYPE_CHECKING: + raise + else: + + def is_scalar(value: Any, include_0d: bool = True) -> bool: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + +else: + + def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + + +def is_valid_numpy_dtype(dtype: Any) -> bool: + try: + np.dtype(dtype) + except (TypeError, ValueError): + return False + else: + return True + + +def to_0d_array(value: Any) -> np.ndarray: + """Given a value, wrap it in a 0-D numpy.ndarray.""" + if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0): + return np.array(value) + else: + return to_0d_object_array(value) + + +def dict_equiv( + first: Mapping[K, V], + second: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> bool: + """Test equivalence of two dict-like objects. If any of the values are + numpy arrays, compare them correctly. + + Parameters + ---------- + first, second : dict-like + Dictionaries to compare for equality + compat : function, optional + Binary operator to determine if two values are compatible. By default, + checks for equivalence. + + Returns + ------- + equals : bool + True if the dictionaries are equal + """ + for k in first: + if k not in second or not compat(first[k], second[k]): + return False + return all(k in first for k in second) + + +def compat_dict_intersection( + first_dict: Mapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> MutableMapping[K, V]: + """Return the intersection of two dictionaries as a new dictionary. + + Items are retained if their keys are found in both dictionaries and the + values are compatible. + + Parameters + ---------- + first_dict, second_dict : dict-like + Mappings to merge. + compat : function, optional + Binary operator to determine if two values are compatible. By default, + checks for equivalence. + + Returns + ------- + intersection : dict + Intersection of the contents. + """ + new_dict = dict(first_dict) + remove_incompatible_items(new_dict, second_dict, compat) + return new_dict + + +def compat_dict_union( + first_dict: Mapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent, +) -> MutableMapping[K, V]: + """Return the union of two dictionaries as a new dictionary. + + An exception is raised if any keys are found in both dictionaries and the + values are not compatible. + + Parameters + ---------- + first_dict, second_dict : dict-like + Mappings to merge. + compat : function, optional + Binary operator to determine if two values are compatible. By default, + checks for equivalence. + + Returns + ------- + union : dict + union of the contents. + """ + new_dict = dict(first_dict) + update_safety_check(first_dict, second_dict, compat) + new_dict.update(second_dict) + return new_dict + + +class Frozen(Mapping[K, V]): + """Wrapper around an object implementing the mapping interface to make it + immutable. If you really want to modify the mapping, the mutable version is + saved under the `mapping` attribute. + """ + + __slots__ = ("mapping",) + + def __init__(self, mapping: Mapping[K, V]): + self.mapping = mapping + + def __getitem__(self, key: K) -> V: + return self.mapping[key] + + def __iter__(self) -> Iterator[K]: + return iter(self.mapping) + + def __len__(self) -> int: + return len(self.mapping) + + def __contains__(self, key: object) -> bool: + return key in self.mapping + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.mapping!r})" + + +def FrozenDict(*args, **kwargs) -> Frozen: + return Frozen(dict(*args, **kwargs)) + + +class FrozenMappingWarningOnValuesAccess(Frozen[K, V]): + """ + Class which behaves like a Mapping but warns if the values are accessed. + + Temporary object to aid in deprecation cycle of `Dataset.dims` (see GH issue #8496). + `Dataset.dims` is being changed from returning a mapping of dimension names to lengths to just + returning a frozen set of dimension names (to increase consistency with `DataArray.dims`). + This class retains backwards compatibility but raises a warning only if the return value + of ds.dims is used like a dictionary (i.e. it doesn't raise a warning if used in a way that + would also be valid for a FrozenSet, e.g. iteration). + """ + + __slots__ = ("mapping",) + + def _warn(self) -> None: + emit_user_level_warning( + "The return type of `Dataset.dims` will be changed to return a set of dimension names in future, " + "in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, " + "please use `Dataset.sizes`.", + FutureWarning, + ) + + def __getitem__(self, key: K) -> V: + self._warn() + return super().__getitem__(key) + + @overload + def get(self, key: K, /) -> V | None: ... + + @overload + def get(self, key: K, /, default: V | T) -> V | T: ... + + def get(self, key: K, default: T | None = None) -> V | T | None: + self._warn() + return super().get(key, default) + + def keys(self) -> KeysView[K]: + self._warn() + return super().keys() + + def items(self) -> ItemsView[K, V]: + self._warn() + return super().items() + + def values(self) -> ValuesView[V]: + self._warn() + return super().values() + + +class HybridMappingProxy(Mapping[K, V]): + """Implements the Mapping interface. Uses the wrapped mapping for item lookup + and a separate wrapped keys collection for iteration. + + Can be used to construct a mapping object from another dict-like object without + eagerly accessing its items or when a mapping object is expected but only + iteration over keys is actually used. + + Note: HybridMappingProxy does not validate consistency of the provided `keys` + and `mapping`. It is the caller's responsibility to ensure that they are + suitable for the task at hand. + """ + + __slots__ = ("_keys", "mapping") + + def __init__(self, keys: Collection[K], mapping: Mapping[K, V]): + self._keys = keys + self.mapping = mapping + + def __getitem__(self, key: K) -> V: + return self.mapping[key] + + def __iter__(self) -> Iterator[K]: + return iter(self._keys) + + def __len__(self) -> int: + return len(self._keys) + + +class OrderedSet(MutableSet[T]): + """A simple ordered set. + + The API matches the builtin set, but it preserves insertion order of elements, like + a dict. Note that, unlike in an OrderedDict, equality tests are not order-sensitive. + """ + + _d: dict[T, None] + + __slots__ = ("_d",) + + def __init__(self, values: Iterable[T] | None = None): + self._d = {} + if values is not None: + self.update(values) + + # Required methods for MutableSet + + def __contains__(self, value: Hashable) -> bool: + return value in self._d + + def __iter__(self) -> Iterator[T]: + return iter(self._d) + + def __len__(self) -> int: + return len(self._d) + + def add(self, value: T) -> None: + self._d[value] = None + + def discard(self, value: T) -> None: + del self._d[value] + + # Additional methods + + def update(self, values: Iterable[T]) -> None: + self._d.update(dict.fromkeys(values)) + + def __repr__(self) -> str: + return f"{type(self).__name__}({list(self)!r})" + + +class NdimSizeLenMixin: + """Mixin class that extends a class that defines a ``shape`` property to + one that also defines ``ndim``, ``size`` and ``__len__``. + """ + + __slots__ = () + + @property + def ndim(self: Any) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def size(self: Any) -> int: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + def __len__(self: Any) -> int: + try: + return self.shape[0] + except IndexError: + raise TypeError("len() of unsized object") + + +class NDArrayMixin(NdimSizeLenMixin): + """Mixin class for making wrappers of N-dimensional arrays that conform to + the ndarray interface required for the data argument to Variable objects. + + A subclass should set the `array` property and override one or more of + `dtype`, `shape` and `__getitem__`. + """ + + __slots__ = () + + @property + def dtype(self: Any) -> np.dtype: + return self.array.dtype + + @property + def shape(self: Any) -> tuple[int, ...]: + return self.array.shape + + def __getitem__(self: Any, key): + return self.array[key] + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(array={self.array!r})" + + +@contextlib.contextmanager +def close_on_error(f): + """Context manager to ensure that a file opened by xarray is closed if an + exception is raised before the user sees the file object. + """ + try: + yield + except Exception: + f.close() + raise + + +def is_remote_uri(path: str) -> bool: + """Finds URLs of the form protocol:// or protocol:: + + This also matches for http[s]://, which were the only remote URLs + supported in <=v0.16.2. + """ + return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path)) + + +def read_magic_number_from_file(filename_or_obj, count=8) -> bytes: + # check byte header to determine file type + if isinstance(filename_or_obj, bytes): + magic_number = filename_or_obj[:count] + elif isinstance(filename_or_obj, io.IOBase): + if filename_or_obj.tell() != 0: + filename_or_obj.seek(0) + magic_number = filename_or_obj.read(count) + filename_or_obj.seek(0) + else: + raise TypeError(f"cannot read the magic number from {type(filename_or_obj)}") + return magic_number + + +def try_read_magic_number_from_path(pathlike, count=8) -> bytes | None: + if isinstance(pathlike, str) or hasattr(pathlike, "__fspath__"): + path = os.fspath(pathlike) + try: + with open(path, "rb") as f: + return read_magic_number_from_file(f, count) + except (FileNotFoundError, TypeError): + pass + return None + + +def try_read_magic_number_from_file_or_path(filename_or_obj, count=8) -> bytes | None: + magic_number = try_read_magic_number_from_path(filename_or_obj, count) + if magic_number is None: + try: + magic_number = read_magic_number_from_file(filename_or_obj, count) + except TypeError: + pass + return magic_number + + +def is_uniform_spaced(arr, **kwargs) -> bool: + """Return True if values of an array are uniformly spaced and sorted. + + >>> is_uniform_spaced(range(5)) + True + >>> is_uniform_spaced([-4, 0, 100]) + False + + kwargs are additional arguments to ``np.isclose`` + """ + arr = np.array(arr, dtype=float) + diffs = np.diff(arr) + return bool(np.isclose(diffs.min(), diffs.max(), **kwargs)) + + +def hashable(v: Any) -> TypeGuard[Hashable]: + """Determine whether `v` can be hashed.""" + try: + hash(v) + except TypeError: + return False + return True + + +def iterable(v: Any) -> TypeGuard[Iterable[Any]]: + """Determine whether `v` is iterable.""" + try: + iter(v) + except TypeError: + return False + return True + + +def iterable_of_hashable(v: Any) -> TypeGuard[Iterable[Hashable]]: + """Determine whether `v` is an Iterable of Hashables.""" + try: + it = iter(v) + except TypeError: + return False + return all(hashable(elm) for elm in it) + + +def decode_numpy_dict_values(attrs: Mapping[K, V]) -> dict[K, V]: + """Convert attribute values from numpy objects to native Python objects, + for use in to_dict + """ + attrs = dict(attrs) + for k, v in attrs.items(): + if isinstance(v, np.ndarray): + attrs[k] = v.tolist() + elif isinstance(v, np.generic): + attrs[k] = v.item() + return attrs + + +def ensure_us_time_resolution(val): + """Convert val out of numpy time, for use in to_dict. + Needed because of numpy bug GH#7619""" + if np.issubdtype(val.dtype, np.datetime64): + val = val.astype("datetime64[us]") + elif np.issubdtype(val.dtype, np.timedelta64): + val = val.astype("timedelta64[us]") + return val + + +class HiddenKeyDict(MutableMapping[K, V]): + """Acts like a normal dictionary, but hides certain keys.""" + + __slots__ = ("_data", "_hidden_keys") + + # ``__init__`` method required to create instance from class. + + def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]): + self._data = data + self._hidden_keys = frozenset(hidden_keys) + + def _raise_if_hidden(self, key: K) -> None: + if key in self._hidden_keys: + raise KeyError(f"Key `{key!r}` is hidden.") + + # The next five methods are requirements of the ABC. + def __setitem__(self, key: K, value: V) -> None: + self._raise_if_hidden(key) + self._data[key] = value + + def __getitem__(self, key: K) -> V: + self._raise_if_hidden(key) + return self._data[key] + + def __delitem__(self, key: K) -> None: + self._raise_if_hidden(key) + del self._data[key] + + def __iter__(self) -> Iterator[K]: + for k in self._data: + if k not in self._hidden_keys: + yield k + + def __len__(self) -> int: + num_hidden = len(self._hidden_keys & self._data.keys()) + return len(self._data) - num_hidden + + +def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: + """Get an new dimension name based on new_dim, that is not used in dims. + If the same name exists, we add an underscore(s) in the head. + + Example1: + dims: ['a', 'b', 'c'] + new_dim: ['_rolling'] + -> ['_rolling'] + Example2: + dims: ['a', 'b', 'c', '_rolling'] + new_dim: ['_rolling'] + -> ['__rolling'] + """ + while new_dim in dims: + new_dim = "_" + str(new_dim) + return new_dim + + +def drop_dims_from_indexers( + indexers: Mapping[Any, Any], + dims: Iterable[Hashable] | Mapping[Any, int], + missing_dims: ErrorOptionsWithWarn, +) -> Mapping[Hashable, Any]: + """Depending on the setting of missing_dims, drop any dimensions from indexers that + are not present in dims. + + Parameters + ---------- + indexers : dict + dims : sequence + missing_dims : {"raise", "warn", "ignore"} + """ + + if missing_dims == "raise": + invalid = indexers.keys() - set(dims) + if invalid: + raise ValueError( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return indexers + + elif missing_dims == "warn": + # don't modify input + indexers = dict(indexers) + + invalid = indexers.keys() - set(dims) + if invalid: + warnings.warn( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + for key in invalid: + indexers.pop(key) + + return indexers + + elif missing_dims == "ignore": + return {key: val for key, val in indexers.items() if key in dims} + + else: + raise ValueError( + f"Unrecognised option {missing_dims} for missing_dims argument" + ) + + +@overload +def parse_dims( + dim: Dims, + all_dims: tuple[Hashable, ...], + *, + check_exists: bool = True, + replace_none: Literal[True] = True, +) -> tuple[Hashable, ...]: ... + + +@overload +def parse_dims( + dim: Dims, + all_dims: tuple[Hashable, ...], + *, + check_exists: bool = True, + replace_none: Literal[False], +) -> tuple[Hashable, ...] | None | ellipsis: ... + + +def parse_dims( + dim: Dims, + all_dims: tuple[Hashable, ...], + *, + check_exists: bool = True, + replace_none: bool = True, +) -> tuple[Hashable, ...] | None | ellipsis: + """Parse one or more dimensions. + + A single dimension must be always a str, multiple dimensions + can be Hashables. This supports e.g. using a tuple as a dimension. + If you supply e.g. a set of dimensions the order cannot be + conserved, but for sequences it will be. + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None + Dimension(s) to parse. + all_dims : tuple of Hashable + All possible dimensions. + check_exists: bool, default: True + if True, check if dim is a subset of all_dims. + replace_none : bool, default: True + If True, return all_dims if dim is None or "...". + + Returns + ------- + parsed_dims : tuple of Hashable + Input dimensions as a tuple. + """ + if dim is None or dim is ...: + if replace_none: + return all_dims + return dim + if isinstance(dim, str): + dim = (dim,) + if check_exists: + _check_dims(set(dim), set(all_dims)) + return tuple(dim) + + +@overload +def parse_ordered_dims( + dim: Dims, + all_dims: tuple[Hashable, ...], + *, + check_exists: bool = True, + replace_none: Literal[True] = True, +) -> tuple[Hashable, ...]: ... + + +@overload +def parse_ordered_dims( + dim: Dims, + all_dims: tuple[Hashable, ...], + *, + check_exists: bool = True, + replace_none: Literal[False], +) -> tuple[Hashable, ...] | None | ellipsis: ... + + +def parse_ordered_dims( + dim: Dims, + all_dims: tuple[Hashable, ...], + *, + check_exists: bool = True, + replace_none: bool = True, +) -> tuple[Hashable, ...] | None | ellipsis: + """Parse one or more dimensions. + + A single dimension must be always a str, multiple dimensions + can be Hashables. This supports e.g. using a tuple as a dimension. + An ellipsis ("...") in a sequence of dimensions will be + replaced with all remaining dimensions. This only makes sense when + the input is a sequence and not e.g. a set. + + Parameters + ---------- + dim : str, Sequence of Hashable or "...", "..." or None + Dimension(s) to parse. If "..." appears in a Sequence + it always gets replaced with all remaining dims + all_dims : tuple of Hashable + All possible dimensions. + check_exists: bool, default: True + if True, check if dim is a subset of all_dims. + replace_none : bool, default: True + If True, return all_dims if dim is None. + + Returns + ------- + parsed_dims : tuple of Hashable + Input dimensions as a tuple. + """ + if dim is not None and dim is not ... and not isinstance(dim, str) and ... in dim: + dims_set: set[Hashable | ellipsis] = set(dim) + all_dims_set = set(all_dims) + if check_exists: + _check_dims(dims_set, all_dims_set) + if len(all_dims_set) != len(all_dims): + raise ValueError("Cannot use ellipsis with repeated dims") + dims = tuple(dim) + if dims.count(...) > 1: + raise ValueError("More than one ellipsis supplied") + other_dims = tuple(d for d in all_dims if d not in dims_set) + idx = dims.index(...) + return dims[:idx] + other_dims + dims[idx + 1 :] + else: + # mypy cannot resolve that the sequence cannot contain "..." + return parse_dims( # type: ignore[call-overload] + dim=dim, + all_dims=all_dims, + check_exists=check_exists, + replace_none=replace_none, + ) + + +def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None: + wrong_dims = (dim - all_dims) - {...} + if wrong_dims: + wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims) + raise ValueError( + f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}" + ) + + +_Accessor = TypeVar("_Accessor") + + +class UncachedAccessor(Generic[_Accessor]): + """Acts like a property, but on both classes and class instances + + This class is necessary because some tools (e.g. pydoc and sphinx) + inspect classes for which property returns itself and not the + accessor. + """ + + def __init__(self, accessor: type[_Accessor]) -> None: + self._accessor = accessor + + @overload + def __get__(self, obj: None, cls) -> type[_Accessor]: ... + + @overload + def __get__(self, obj: object, cls) -> _Accessor: ... + + def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor: + if obj is None: + return self._accessor + + return self._accessor(obj) # type: ignore # assume it is a valid accessor! + + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token = 0 + + +_default = Default.token + + +def iterate_nested(nested_list): + for item in nested_list: + if isinstance(item, list): + yield from iterate_nested(item) + else: + yield item + + +def contains_only_chunked_or_numpy(obj) -> bool: + """Returns True if xarray object contains only numpy arrays or chunked arrays (i.e. pure dask or cubed). + + Expects obj to be Dataset or DataArray""" + from xarray.core.dataarray import DataArray + from xarray.namedarray.pycompat import is_chunked_array + + if isinstance(obj, DataArray): + obj = obj._to_temp_dataset() + + return all( + [ + isinstance(var.data, np.ndarray) or is_chunked_array(var.data) + for var in obj.variables.values() + ] + ) + + +def find_stack_level(test_mode=False) -> int: + """Find the first place in the stack that is not inside xarray or the Python standard library. + + This is unless the code emanates from a test, in which case we would prefer + to see the xarray source. + + This function is taken from pandas and modified to exclude standard library paths. + + Parameters + ---------- + test_mode : bool + Flag used for testing purposes to switch off the detection of test + directories in the stack trace. + + Returns + ------- + stacklevel : int + First level in the stack that is not part of xarray or the Python standard library. + """ + import xarray as xr + + pkg_dir = Path(xr.__file__).parent + test_dir = pkg_dir / "tests" + + std_lib_init = sys.modules["os"].__file__ + # Mostly to appease mypy; I don't think this can happen... + if std_lib_init is None: + return 0 + + std_lib_dir = Path(std_lib_init).parent + + frame = inspect.currentframe() + n = 0 + while frame: + fname = inspect.getfile(frame) + if ( + fname.startswith(str(pkg_dir)) + and (not fname.startswith(str(test_dir)) or test_mode) + ) or ( + fname.startswith(str(std_lib_dir)) + and "site-packages" not in fname + and "dist-packages" not in fname + ): + frame = frame.f_back + n += 1 + else: + break + return n + + +def emit_user_level_warning(message, category=None) -> None: + """Emit a warning at the user level by inspecting the stack trace.""" + stacklevel = find_stack_level() + return warnings.warn(message, category=category, stacklevel=stacklevel) + + +def consolidate_dask_from_array_kwargs( + from_array_kwargs: dict[Any, Any], + name: str | None = None, + lock: bool | None = None, + inline_array: bool | None = None, +) -> dict[Any, Any]: + """ + Merge dask-specific kwargs with arbitrary from_array_kwargs dict. + + Temporary function, to be deleted once explicitly passing dask-specific kwargs to .chunk() is deprecated. + """ + + from_array_kwargs = _resolve_doubly_passed_kwarg( + from_array_kwargs, + kwarg_name="name", + passed_kwarg_value=name, + default=None, + err_msg_dict_name="from_array_kwargs", + ) + from_array_kwargs = _resolve_doubly_passed_kwarg( + from_array_kwargs, + kwarg_name="lock", + passed_kwarg_value=lock, + default=False, + err_msg_dict_name="from_array_kwargs", + ) + from_array_kwargs = _resolve_doubly_passed_kwarg( + from_array_kwargs, + kwarg_name="inline_array", + passed_kwarg_value=inline_array, + default=False, + err_msg_dict_name="from_array_kwargs", + ) + + return from_array_kwargs + + +def _resolve_doubly_passed_kwarg( + kwargs_dict: dict[Any, Any], + kwarg_name: str, + passed_kwarg_value: str | bool | None, + default: bool | None, + err_msg_dict_name: str, +) -> dict[Any, Any]: + # if in kwargs_dict but not passed explicitly then just pass kwargs_dict through unaltered + if kwarg_name in kwargs_dict and passed_kwarg_value is None: + pass + # if passed explicitly but not in kwargs_dict then use that + elif kwarg_name not in kwargs_dict and passed_kwarg_value is not None: + kwargs_dict[kwarg_name] = passed_kwarg_value + # if in neither then use default + elif kwarg_name not in kwargs_dict and passed_kwarg_value is None: + kwargs_dict[kwarg_name] = default + # if in both then raise + else: + raise ValueError( + f"argument {kwarg_name} cannot be passed both as a keyword argument and within " + f"the {err_msg_dict_name} dictionary" + ) + + return kwargs_dict diff --git a/test/fixtures/whole_applications/xarray/xarray/core/variable.py b/test/fixtures/whole_applications/xarray/xarray/core/variable.py new file mode 100644 index 0000000..f068588 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/variable.py @@ -0,0 +1,3012 @@ +from __future__ import annotations + +import copy +import itertools +import math +import numbers +import warnings +from collections.abc import Hashable, Mapping, Sequence +from datetime import timedelta +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast + +import numpy as np +import pandas as pd +from numpy.typing import ArrayLike +from pandas.api.types import is_extension_array_dtype + +import xarray as xr # only for Dataset and DataArray +from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils +from xarray.core.arithmetic import VariableArithmetic +from xarray.core.common import AbstractArray +from xarray.core.extension_array import PandasExtensionArray +from xarray.core.indexing import ( + BasicIndexer, + OuterIndexer, + PandasIndexingAdapter, + VectorizedIndexer, + as_indexable, +) +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.utils import ( + OrderedSet, + _default, + consolidate_dask_from_array_kwargs, + decode_numpy_dict_values, + drop_dims_from_indexers, + either_dict_or_kwargs, + emit_user_level_warning, + ensure_us_time_resolution, + infix_dims, + is_dict_like, + is_duck_array, + is_duck_dask_array, + maybe_coerce_to_str, +) +from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions +from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array +from xarray.util.deprecation_helpers import deprecate_dims + +NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( + indexing.ExplicitlyIndexed, + pd.Index, + pd.api.extensions.ExtensionArray, +) +# https://github.com/python/mypy/issues/224 +BASIC_INDEXING_TYPES = integer_types + (slice,) + +if TYPE_CHECKING: + from xarray.core.types import ( + Dims, + ErrorOptionsWithWarn, + PadModeOptions, + PadReflectOptions, + QuantileMethods, + Self, + T_DuckArray, + ) + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + +NON_NANOSECOND_WARNING = ( + "Converting non-nanosecond precision {case} values to nanosecond precision. " + "This behavior can eventually be relaxed in xarray, as it is an artifact from " + "pandas which is now beginning to support non-nanosecond precision values. " + "This warning is caused by passing non-nanosecond np.datetime64 or " + "np.timedelta64 values to the DataArray or Variable constructor; it can be " + "silenced by converting the values to nanosecond precision ahead of time." +) + + +class MissingDimensionsError(ValueError): + """Error class used when we can't safely guess a dimension name.""" + + # inherits from ValueError for backward compatibility + # TODO: move this to an xarray.exceptions module? + + +def as_variable( + obj: T_DuckArray | Any, name=None, auto_convert: bool = True +) -> Variable | IndexVariable: + """Convert an object into a Variable. + + Parameters + ---------- + obj : object + Object to convert into a Variable. + + - If the object is already a Variable, return a shallow copy. + - Otherwise, if the object has 'dims' and 'data' attributes, convert + it into a new Variable. + - If all else fails, attempt to convert the object into a Variable by + unpacking it into the arguments for creating a new Variable. + name : str, optional + If provided: + + - `obj` can be a 1D array, which is assumed to label coordinate values + along a dimension of this given name. + - Variables with name matching one of their dimensions are converted + into `IndexVariable` objects. + auto_convert : bool, optional + For internal use only! If True, convert a "dimension" variable into + an IndexVariable object (deprecated). + + Returns + ------- + var : Variable + The newly created variable. + + """ + from xarray.core.dataarray import DataArray + + # TODO: consider extending this method to automatically handle Iris and + if isinstance(obj, DataArray): + # extract the primary Variable from DataArrays + obj = obj.variable + + if isinstance(obj, Variable): + obj = obj.copy(deep=False) + elif isinstance(obj, tuple): + try: + dims_, data_, *attrs = obj + except ValueError: + raise ValueError(f"Tuple {obj} is not in the form (dims, data[, attrs])") + + if isinstance(data_, DataArray): + raise TypeError( + f"Variable {name!r}: Using a DataArray object to construct a variable is" + " ambiguous, please extract the data using the .data property." + ) + try: + obj = Variable(dims_, data_, *attrs) + except (TypeError, ValueError) as error: + raise error.__class__( + f"Variable {name!r}: Could not convert tuple of form " + f"(dims, data[, attrs, encoding]): {obj} to Variable." + ) + elif utils.is_scalar(obj): + obj = Variable([], obj) + elif isinstance(obj, (pd.Index, IndexVariable)) and obj.name is not None: + obj = Variable(obj.name, obj) + elif isinstance(obj, (set, dict)): + raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}") + elif name is not None: + data: T_DuckArray = as_compatible_data(obj) + if data.ndim != 1: + raise MissingDimensionsError( + f"cannot set variable {name!r} with {data.ndim!r}-dimensional data " + "without explicit dimension names. Pass a tuple of " + "(dims, data) instead." + ) + obj = Variable(name, data, fastpath=True) + else: + raise TypeError( + f"Variable {name!r}: unable to convert object into a variable without an " + f"explicit list of dimensions: {obj!r}" + ) + + if auto_convert: + if name is not None and name in obj.dims and obj.ndim == 1: + # automatically convert the Variable into an Index + emit_user_level_warning( + f"variable {name!r} with name matching its dimension will not be " + "automatically converted into an `IndexVariable` object in the future.", + FutureWarning, + ) + obj = obj.to_index_variable() + + return obj + + +def _maybe_wrap_data(data): + """ + Put pandas.Index and numpy.ndarray arguments in adapter objects to ensure + they can be indexed properly. + + NumpyArrayAdapter, PandasIndexingAdapter and LazilyIndexedArray should + all pass through unmodified. + """ + if isinstance(data, pd.Index): + return PandasIndexingAdapter(data) + if isinstance(data, pd.api.extensions.ExtensionArray): + return PandasExtensionArray[type(data)](data) + return data + + +def _as_nanosecond_precision(data): + dtype = data.dtype + non_ns_datetime64 = ( + dtype.kind == "M" + and isinstance(dtype, np.dtype) + and dtype != np.dtype("datetime64[ns]") + ) + non_ns_datetime_tz_dtype = ( + isinstance(dtype, pd.DatetimeTZDtype) and dtype.unit != "ns" + ) + if non_ns_datetime64 or non_ns_datetime_tz_dtype: + utils.emit_user_level_warning(NON_NANOSECOND_WARNING.format(case="datetime")) + if isinstance(dtype, pd.DatetimeTZDtype): + nanosecond_precision_dtype = pd.DatetimeTZDtype("ns", dtype.tz) + else: + nanosecond_precision_dtype = "datetime64[ns]" + return duck_array_ops.astype(data, nanosecond_precision_dtype) + elif dtype.kind == "m" and dtype != np.dtype("timedelta64[ns]"): + utils.emit_user_level_warning(NON_NANOSECOND_WARNING.format(case="timedelta")) + return duck_array_ops.astype(data, "timedelta64[ns]") + else: + return data + + +def _possibly_convert_objects(values): + """Convert arrays of datetime.datetime and datetime.timedelta objects into + datetime64 and timedelta64, according to the pandas convention. For the time + being, convert any non-nanosecond precision DatetimeIndex or TimedeltaIndex + objects to nanosecond precision. While pandas is relaxing this in version + 2.0.0, in xarray we will need to make sure we are ready to handle + non-nanosecond precision datetimes or timedeltas in our code before allowing + such values to pass through unchanged. Converting to nanosecond precision + through pandas.Series objects ensures that datetimes and timedeltas are + within the valid date range for ns precision, as pandas will raise an error + if they are not. + """ + as_series = pd.Series(values.ravel(), copy=False) + if as_series.dtype.kind in "mM": + as_series = _as_nanosecond_precision(as_series) + result = np.asarray(as_series).reshape(values.shape) + if not result.flags.writeable: + # GH8843, pandas copy-on-write mode creates read-only arrays by default + try: + result.flags.writeable = True + except ValueError: + result = result.copy() + return result + + +def _possibly_convert_datetime_or_timedelta_index(data): + """For the time being, convert any non-nanosecond precision DatetimeIndex or + TimedeltaIndex objects to nanosecond precision. While pandas is relaxing + this in version 2.0.0, in xarray we will need to make sure we are ready to + handle non-nanosecond precision datetimes or timedeltas in our code + before allowing such values to pass through unchanged.""" + if isinstance(data, PandasIndexingAdapter): + if isinstance(data.array, (pd.DatetimeIndex, pd.TimedeltaIndex)): + data = PandasIndexingAdapter(_as_nanosecond_precision(data.array)) + elif isinstance(data, (pd.DatetimeIndex, pd.TimedeltaIndex)): + data = _as_nanosecond_precision(data) + return data + + +def as_compatible_data( + data: T_DuckArray | ArrayLike, fastpath: bool = False +) -> T_DuckArray: + """Prepare and wrap data to put in a Variable. + + - If data does not have the necessary attributes, convert it to ndarray. + - If data has dtype=datetime64, ensure that it has ns precision. If it's a + pandas.Timestamp, convert it to datetime64. + - If data is already a pandas or xarray object (other than an Index), just + use the values. + + Finally, wrap it up with an adapter if necessary. + """ + if fastpath and getattr(data, "ndim", None) is not None: + return cast("T_DuckArray", data) + + from xarray.core.dataarray import DataArray + + # TODO: do this uwrapping in the Variable/NamedArray constructor instead. + if isinstance(data, Variable): + return cast("T_DuckArray", data._data) + + # TODO: do this uwrapping in the DataArray constructor instead. + if isinstance(data, DataArray): + return cast("T_DuckArray", data._variable._data) + + if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): + data = _possibly_convert_datetime_or_timedelta_index(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) + + if isinstance(data, tuple): + data = utils.to_0d_object_array(data) + + if isinstance(data, pd.Timestamp): + # TODO: convert, handle datetime objects, too + data = np.datetime64(data.value, "ns") + + if isinstance(data, timedelta): + data = np.timedelta64(getattr(data, "value", data), "ns") + + # we don't want nested self-described arrays + if isinstance(data, (pd.Series, pd.DataFrame)): + data = data.values + + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) + if mask.any(): + dtype, fill_value = dtypes.maybe_promote(data.dtype) + data = duck_array_ops.where_method(data, ~mask, fill_value) + else: + data = np.asarray(data) + + if not isinstance(data, np.ndarray) and ( + hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") + ): + return cast("T_DuckArray", data) + + # validate whether the data is valid data types. + data = np.asarray(data) + + if isinstance(data, np.ndarray) and data.dtype.kind in "OMm": + data = _possibly_convert_objects(data) + return _maybe_wrap_data(data) + + +def _as_array_or_item(data): + """Return the given values as a numpy array, or as an individual item if + it's a 0d datetime64 or timedelta64 array. + + Importantly, this function does not copy data if it is already an ndarray - + otherwise, it will not be possible to update Variable values in place. + + This function mostly exists because 0-dimensional ndarrays with + dtype=datetime64 are broken :( + https://github.com/numpy/numpy/issues/4337 + https://github.com/numpy/numpy/issues/7619 + + TODO: remove this (replace with np.asarray) once these issues are fixed + """ + data = np.asarray(data) + if data.ndim == 0: + if data.dtype.kind == "M": + data = np.datetime64(data, "ns") + elif data.dtype.kind == "m": + data = np.timedelta64(data, "ns") + return data + + +class Variable(NamedArray, AbstractArray, VariableArithmetic): + """A netcdf-like variable consisting of dimensions, data and attributes + which describe a single Array. A single Variable object is not fully + described outside the context of its parent Dataset (if you want such a + fully described object, use a DataArray instead). + + The main functional difference between Variables and numpy arrays is that + numerical operations on Variables implement array broadcasting by dimension + name. For example, adding an Variable with dimensions `('time',)` to + another Variable with dimensions `('space',)` results in a new Variable + with dimensions `('time', 'space')`. Furthermore, numpy reduce operations + like ``mean`` or ``sum`` are overwritten to take a "dimension" argument + instead of an "axis". + + Variables are light-weight objects used as the building block for datasets. + They are more primitive objects, so operations with them provide marginally + higher performance than using DataArrays. However, manipulating data in the + form of a Dataset or DataArray should almost always be preferred, because + they can use more complete metadata in context of coordinate labels. + """ + + __slots__ = ("_dims", "_data", "_attrs", "_encoding") + + def __init__( + self, + dims, + data: T_DuckArray | ArrayLike, + attrs=None, + encoding=None, + fastpath=False, + ): + """ + Parameters + ---------- + dims : str or sequence of str + Name(s) of the the data dimension(s). Must be either a string (only + for 1D data) or a sequence of strings with length equal to the + number of dimensions. + data : array_like + Data array which supports numpy-like data access. + attrs : dict_like or None, optional + Attributes to assign to the new variable. If None (default), an + empty attribute dictionary is initialized. + encoding : dict_like or None, optional + Dictionary specifying how to encode this array's data into a + serialized format like netCDF4. Currently used keys (for netCDF) + include '_FillValue', 'scale_factor', 'add_offset' and 'dtype'. + Well-behaved code to serialize a Variable should ignore + unrecognized encoding items. + """ + super().__init__( + dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs + ) + + self._encoding = None + if encoding is not None: + self.encoding = encoding + + def _new( + self, + dims=_default, + data=_default, + attrs=_default, + ): + dims_ = copy.copy(self._dims) if dims is _default else dims + + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + else: + cls_ = type(self) + return cls_(dims_, data, attrs_) + + @property + def _in_memory(self): + return isinstance( + self._data, (np.ndarray, np.number, PandasIndexingAdapter) + ) or ( + isinstance(self._data, indexing.MemoryCachedArray) + and isinstance(self._data.array, indexing.NumpyIndexingAdapter) + ) + + @property + def data(self): + """ + The Variable's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + See Also + -------- + Variable.to_numpy + Variable.as_numpy + Variable.values + """ + if is_duck_array(self._data): + return self._data + elif isinstance(self._data, indexing.ExplicitlyIndexed): + return self._data.get_duck_array() + else: + return self.values + + @data.setter + def data(self, data: T_DuckArray | ArrayLike) -> None: + data = as_compatible_data(data) + self._check_shape(data) + self._data = data + + def astype( + self, + dtype, + *, + order=None, + casting=None, + subok=None, + copy=None, + keep_attrs=True, + ) -> Self: + """ + Copy of the Variable object, with data cast to a specified type. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + order : {'C', 'F', 'A', 'K'}, optional + Controls the memory layout order of the result. ‘C’ means C order, + ‘F’ means Fortran order, ‘A’ means ‘F’ order if all the arrays are + Fortran contiguous, ‘C’ order otherwise, and ‘K’ means as close to + the order the array elements appear in memory as possible. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + + * 'no' means the data types should not be cast at all. + * 'equiv' means only byte-order changes are allowed. + * 'safe' means only casts which can preserve values are allowed. + * 'same_kind' means only safe casts or casts within a kind, + like float64 to float32, are allowed. + * 'unsafe' means any data conversions may be done. + subok : bool, optional + If True, then sub-classes will be passed-through, otherwise the + returned array will be forced to be a base-class array. + copy : bool, optional + By default, astype always returns a newly allocated array. If this + is set to False and the `dtype` requirement is satisfied, the input + array is returned instead of a copy. + keep_attrs : bool, optional + By default, astype keeps attributes. Set to False to remove + attributes in the returned object. + + Returns + ------- + out : same as object + New object with data cast to the specified type. + + Notes + ----- + The ``order``, ``casting``, ``subok`` and ``copy`` arguments are only passed + through to the ``astype`` method of the underlying array when a value + different than ``None`` is supplied. + Make sure to only supply these arguments if the underlying array class + supports them. + + See Also + -------- + numpy.ndarray.astype + dask.array.Array.astype + sparse.COO.astype + """ + from xarray.core.computation import apply_ufunc + + kwargs = dict(order=order, casting=casting, subok=subok, copy=copy) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return apply_ufunc( + duck_array_ops.astype, + self, + dtype, + kwargs=kwargs, + keep_attrs=keep_attrs, + dask="allowed", + ) + + def _dask_finalize(self, results, array_func, *args, **kwargs): + data = array_func(results, *args, **kwargs) + return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) + + @property + def values(self): + """The variable's data as a numpy.ndarray""" + return _as_array_or_item(self._data) + + @values.setter + def values(self, values): + self.data = values + + def to_base_variable(self) -> Variable: + """Return this variable as a base xarray.Variable""" + return Variable( + self._dims, self._data, self._attrs, encoding=self._encoding, fastpath=True + ) + + to_variable = utils.alias(to_base_variable, "to_variable") + + def to_index_variable(self) -> IndexVariable: + """Return this variable as an xarray.IndexVariable""" + return IndexVariable( + self._dims, self._data, self._attrs, encoding=self._encoding, fastpath=True + ) + + to_coord = utils.alias(to_index_variable, "to_coord") + + def _to_index(self) -> pd.Index: + return self.to_index_variable()._to_index() + + def to_index(self) -> pd.Index: + """Convert this variable to a pandas.Index""" + return self.to_index_variable().to_index() + + def to_dict( + self, data: bool | str = "list", encoding: bool = False + ) -> dict[str, Any]: + """Dictionary representation of variable.""" + item: dict[str, Any] = { + "dims": self.dims, + "attrs": decode_numpy_dict_values(self.attrs), + } + if data is not False: + if data in [True, "list"]: + item["data"] = ensure_us_time_resolution(self.to_numpy()).tolist() + elif data == "array": + item["data"] = ensure_us_time_resolution(self.data) + else: + msg = 'data argument must be bool, "list", or "array"' + raise ValueError(msg) + + else: + item.update({"dtype": str(self.dtype), "shape": self.shape}) + + if encoding: + item["encoding"] = dict(self.encoding) + + return item + + def _item_key_to_tuple(self, key): + if is_dict_like(key): + return tuple(key.get(dim, slice(None)) for dim in self.dims) + else: + return key + + def _broadcast_indexes(self, key): + """Prepare an indexing key for an indexing operation. + + Parameters + ---------- + key : int, slice, array-like, dict or tuple of integer, slice and array-like + Any valid input for indexing. + + Returns + ------- + dims : tuple + Dimension of the resultant variable. + indexers : IndexingTuple subclass + Tuple of integer, array-like, or slices to use when indexing + self._data. The type of this argument indicates the type of + indexing to perform, either basic, outer or vectorized. + new_order : Optional[Sequence[int]] + Optional reordering to do on the result of indexing. If not None, + the first len(new_order) indexing should be moved to these + positions. + """ + key = self._item_key_to_tuple(key) # key is a tuple + # key is a tuple of full size + key = indexing.expanded_indexer(key, self.ndim) + # Convert a scalar Variable to a 0d-array + key = tuple( + k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key + ) + # Convert a 0d numpy arrays to an integer + # dask 0d arrays are passed through + key = tuple( + k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key + ) + + if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key): + return self._broadcast_indexes_basic(key) + + self._validate_indexers(key) + # Detect it can be mapped as an outer indexer + # If all key is unlabeled, or + # key can be mapped as an OuterIndexer. + if all(not isinstance(k, Variable) for k in key): + return self._broadcast_indexes_outer(key) + + # If all key is 1-dimensional and there are no duplicate labels, + # key can be mapped as an OuterIndexer. + dims = [] + for k, d in zip(key, self.dims): + if isinstance(k, Variable): + if len(k.dims) > 1: + return self._broadcast_indexes_vectorized(key) + dims.append(k.dims[0]) + elif not isinstance(k, integer_types): + dims.append(d) + if len(set(dims)) == len(dims): + return self._broadcast_indexes_outer(key) + + return self._broadcast_indexes_vectorized(key) + + def _broadcast_indexes_basic(self, key): + dims = tuple( + dim for k, dim in zip(key, self.dims) if not isinstance(k, integer_types) + ) + return dims, BasicIndexer(key), None + + def _validate_indexers(self, key): + """Make sanity checks""" + for dim, k in zip(self.dims, key): + if not isinstance(k, BASIC_INDEXING_TYPES): + if not isinstance(k, Variable): + if not is_duck_array(k): + k = np.asarray(k) + if k.ndim > 1: + raise IndexError( + "Unlabeled multi-dimensional array cannot be " + f"used for indexing: {k}" + ) + if k.dtype.kind == "b": + if self.shape[self.get_axis_num(dim)] != len(k): + raise IndexError( + f"Boolean array size {len(k):d} is used to index array " + f"with shape {str(self.shape):s}." + ) + if k.ndim > 1: + raise IndexError( + f"{k.ndim}-dimensional boolean indexing is " + "not supported. " + ) + if is_duck_dask_array(k.data): + raise KeyError( + "Indexing with a boolean dask array is not allowed. " + "This will result in a dask array of unknown shape. " + "Such arrays are unsupported by Xarray." + "Please compute the indexer first using .compute()" + ) + if getattr(k, "dims", (dim,)) != (dim,): + raise IndexError( + "Boolean indexer should be unlabeled or on the " + "same dimension to the indexed array. Indexer is " + f"on {str(k.dims):s} but the target dimension is {dim:s}." + ) + + def _broadcast_indexes_outer(self, key): + # drop dim if k is integer or if k is a 0d dask array + dims = tuple( + k.dims[0] if isinstance(k, Variable) else dim + for k, dim in zip(key, self.dims) + if (not isinstance(k, integer_types) and not is_0d_dask_array(k)) + ) + + new_key = [] + for k in key: + if isinstance(k, Variable): + k = k.data + if not isinstance(k, BASIC_INDEXING_TYPES): + if not is_duck_array(k): + k = np.asarray(k) + if k.size == 0: + # Slice by empty list; numpy could not infer the dtype + k = k.astype(int) + elif k.dtype.kind == "b": + (k,) = np.nonzero(k) + new_key.append(k) + + return dims, OuterIndexer(tuple(new_key)), None + + def _broadcast_indexes_vectorized(self, key): + variables = [] + out_dims_set = OrderedSet() + for dim, value in zip(self.dims, key): + if isinstance(value, slice): + out_dims_set.add(dim) + else: + variable = ( + value + if isinstance(value, Variable) + else as_variable(value, name=dim, auto_convert=False) + ) + if variable.dims == (dim,): + variable = variable.to_index_variable() + if variable.dtype.kind == "b": # boolean indexing case + (variable,) = variable._nonzero() + + variables.append(variable) + out_dims_set.update(variable.dims) + + variable_dims = set() + for variable in variables: + variable_dims.update(variable.dims) + + slices = [] + for i, (dim, value) in enumerate(zip(self.dims, key)): + if isinstance(value, slice): + if dim in variable_dims: + # We only convert slice objects to variables if they share + # a dimension with at least one other variable. Otherwise, + # we can equivalently leave them as slices aknd transpose + # the result. This is significantly faster/more efficient + # for most array backends. + values = np.arange(*value.indices(self.sizes[dim])) + variables.insert(i - len(slices), Variable((dim,), values)) + else: + slices.append((i, value)) + + try: + variables = _broadcast_compat_variables(*variables) + except ValueError: + raise IndexError(f"Dimensions of indexers mismatch: {key}") + + out_key = [variable.data for variable in variables] + out_dims = tuple(out_dims_set) + slice_positions = set() + for i, value in slices: + out_key.insert(i, value) + new_position = out_dims.index(self.dims[i]) + slice_positions.add(new_position) + + if slice_positions: + new_order = [i for i in range(len(out_dims)) if i not in slice_positions] + else: + new_order = None + + return out_dims, VectorizedIndexer(tuple(out_key)), new_order + + def __getitem__(self, key) -> Self: + """Return a new Variable object whose contents are consistent with + getting the provided key from the underlying data. + + NB. __getitem__ and __setitem__ implement xarray-style indexing, + where if keys are unlabeled arrays, we index the array orthogonally + with them. If keys are labeled array (such as Variables), they are + broadcasted with our usual scheme and then the array is indexed with + the broadcasted key, like numpy's fancy indexing. + + If you really want to do indexing like `x[x > 0]`, manipulate the numpy + array `x.values` directly. + """ + dims, indexer, new_order = self._broadcast_indexes(key) + indexable = as_indexable(self._data) + + data = indexing.apply_indexer(indexable, indexer) + + if new_order: + data = np.moveaxis(data, range(len(new_order)), new_order) + return self._finalize_indexing_result(dims, data) + + def _finalize_indexing_result(self, dims, data) -> Self: + """Used by IndexVariable to return IndexVariable objects when possible.""" + return self._replace(dims=dims, data=data) + + def _getitem_with_mask(self, key, fill_value=dtypes.NA): + """Index this Variable with -1 remapped to fill_value.""" + # TODO(shoyer): expose this method in public API somewhere (isel?) and + # use it for reindex. + # TODO(shoyer): add a sanity check that all other integers are + # non-negative + # TODO(shoyer): add an optimization, remapping -1 to an adjacent value + # that is actually indexed rather than mapping it to the last value + # along each axis. + + if fill_value is dtypes.NA: + fill_value = dtypes.get_fill_value(self.dtype) + + dims, indexer, new_order = self._broadcast_indexes(key) + + if self.size: + + if is_duck_dask_array(self._data): + # dask's indexing is faster this way; also vindex does not + # support negative indices yet: + # https://github.com/dask/dask/pull/2967 + actual_indexer = indexing.posify_mask_indexer(indexer) + else: + actual_indexer = indexer + + indexable = as_indexable(self._data) + data = indexing.apply_indexer(indexable, actual_indexer) + + mask = indexing.create_mask(indexer, self.shape, data) + # we need to invert the mask in order to pass data first. This helps + # pint to choose the correct unit + # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed + data = duck_array_ops.where(np.logical_not(mask), data, fill_value) + else: + # array cannot be indexed along dimensions of size 0, so just + # build the mask directly instead. + mask = indexing.create_mask(indexer, self.shape) + data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) + + if new_order: + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) + return self._finalize_indexing_result(dims, data) + + def __setitem__(self, key, value): + """__setitem__ is overloaded to access the underlying numpy values with + orthogonal indexing. + + See __getitem__ for more details. + """ + dims, index_tuple, new_order = self._broadcast_indexes(key) + + if not isinstance(value, Variable): + value = as_compatible_data(value) + if value.ndim > len(dims): + raise ValueError( + f"shape mismatch: value array of shape {value.shape} could not be " + f"broadcast to indexing result with {len(dims)} dimensions" + ) + if value.ndim == 0: + value = Variable((), value) + else: + value = Variable(dims[-value.ndim :], value) + # broadcast to become assignable + value = value.set_dims(dims).data + + if new_order: + value = duck_array_ops.asarray(value) + value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] + value = np.moveaxis(value, new_order, range(len(new_order))) + + indexable = as_indexable(self._data) + indexing.set_with_indexer(indexable, index_tuple, value) + + @property + def encoding(self) -> dict[Any, Any]: + """Dictionary of encodings on this variable.""" + if self._encoding is None: + self._encoding = {} + return self._encoding + + @encoding.setter + def encoding(self, value): + try: + self._encoding = dict(value) + except ValueError: + raise ValueError("encoding must be castable to a dictionary") + + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: + """Return a new Variable without encoding.""" + return self._replace(encoding={}) + + def _copy( + self, + deep: bool = True, + data: T_DuckArray | ArrayLike | None = None, + memo: dict[int, Any] | None = None, + ) -> Self: + if data is None: + data_old = self._data + + if not isinstance(data_old, indexing.MemoryCachedArray): + ndata = data_old + else: + # don't share caching between copies + # TODO: MemoryCachedArray doesn't match the array api: + ndata = indexing.MemoryCachedArray(data_old.array) # type: ignore[assignment] + + if deep: + ndata = copy.deepcopy(ndata, memo) + + else: + ndata = as_compatible_data(data) + if self.shape != ndata.shape: # type: ignore[attr-defined] + raise ValueError( + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] + ) + + attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) + encoding = ( + copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding) + ) + + # note: dims is already an immutable tuple + return self._replace(data=ndata, attrs=attrs, encoding=encoding) + + def _replace( + self, + dims=_default, + data=_default, + attrs=_default, + encoding=_default, + ) -> Self: + if dims is _default: + dims = copy.copy(self._dims) + if data is _default: + data = copy.copy(self.data) + if attrs is _default: + attrs = copy.copy(self._attrs) + + if encoding is _default: + encoding = copy.copy(self._encoding) + return type(self)(dims, data, attrs, encoding, fastpath=True) + + def load(self, **kwargs): + """Manually trigger loading of this variable's data from disk or a + remote source into memory and return this variable. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute + """ + self._data = to_duck_array(self._data, **kwargs) + return self + + def compute(self, **kwargs): + """Manually trigger loading of this variable's data from disk or a + remote source into memory and return a new variable. The original is + left unaltered. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. + + See Also + -------- + dask.array.compute + """ + new = self.copy(deep=False) + return new.load(**kwargs) + + def isel( + self, + indexers: Mapping[Any, Any] | None = None, + missing_dims: ErrorOptionsWithWarn = "raise", + **indexers_kwargs: Any, + ) -> Self: + """Return a new array indexed along the specified dimension(s). + + Parameters + ---------- + **indexers : {dim: indexer, ...} + Keyword arguments with names matching dimensions and values given + by integers, slice objects or arrays. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + DataArray: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + obj : Array object + A new Array with the selected data and dimensions. In general, + the new variable's data will be a view of this variable's data, + unless numpy fancy indexing was triggered by using an array + indexer, in which case the data will be a copy. + """ + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") + + indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims) + + key = tuple(indexers.get(dim, slice(None)) for dim in self.dims) + return self[key] + + def squeeze(self, dim=None): + """Return a new object with squeezed data. + + Parameters + ---------- + dim : None or str or tuple of str, optional + Selects a subset of the length one dimensions. If a dimension is + selected with length greater than one, an error is raised. If + None, all length one dimensions are squeezed. + + Returns + ------- + squeezed : same type as caller + This object, but with with all or a subset of the dimensions of + length 1 removed. + + See Also + -------- + numpy.squeeze + """ + dims = common.get_squeeze_dims(self, dim) + return self.isel({d: 0 for d in dims}) + + def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): + axis = self.get_axis_num(dim) + + if count > 0: + keep = slice(None, -count) + elif count < 0: + keep = slice(-count, None) + else: + keep = slice(None) + + trimmed_data = self[(slice(None),) * axis + (keep,)].data + + if fill_value is dtypes.NA: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = self.dtype + + width = min(abs(count), self.shape[axis]) + dim_pad = (width, 0) if count >= 0 else (0, width) + pads = [(0, 0) if d != dim else dim_pad for d in self.dims] + + data = np.pad( + duck_array_ops.astype(trimmed_data, dtype), + pads, + mode="constant", + constant_values=fill_value, + ) + + if is_duck_dask_array(data): + # chunked data should come out with the same chunks; this makes + # it feasible to combine shifted and unshifted data + # TODO: remove this once dask.array automatically aligns chunks + data = data.rechunk(self.data.chunks) + + return self._replace(data=data) + + def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): + """ + Return a new Variable with shifted data. + + Parameters + ---------- + shifts : mapping of the form {dim: offset} + Integer offset to shift along each of the given dimensions. + Positive offsets shift to the right; negative offsets shift to the + left. + fill_value : scalar, optional + Value to use for newly missing values + **shifts_kwargs + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. + + Returns + ------- + shifted : Variable + Variable with the same dimensions and attributes but shifted data. + """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") + result = self + for dim, count in shifts.items(): + result = result._shift_one_dim(dim, count, fill_value=fill_value) + return result + + def _pad_options_dim_to_index( + self, + pad_option: Mapping[Any, int | tuple[int, int]], + fill_with_shape=False, + ): + if fill_with_shape: + return [ + (n, n) if d not in pad_option else pad_option[d] + for d, n in zip(self.dims, self.data.shape) + ] + return [(0, 0) if d not in pad_option else pad_option[d] for d in self.dims] + + def pad( + self, + pad_width: Mapping[Any, int | tuple[int, int]] | None = None, + mode: PadModeOptions = "constant", + stat_length: ( + int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None + ) = None, + constant_values: ( + float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None + ) = None, + end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, + reflect_type: PadReflectOptions = None, + keep_attrs: bool | None = None, + **pad_width_kwargs: Any, + ): + """ + Return a new Variable with padded data. + + Parameters + ---------- + pad_width : mapping of hashable to tuple of int + Mapping with the form of {dim: (pad_before, pad_after)} + describing the number of values padded along each dimension. + {dim: pad} is a shortcut for pad_before = pad_after = pad + mode : str, default: "constant" + See numpy / Dask docs + stat_length : int, tuple or mapping of hashable to tuple + Used in 'maximum', 'mean', 'median', and 'minimum'. Number of + values at edge of each axis used to calculate the statistic value. + constant_values : scalar, tuple or mapping of hashable to tuple + Used in 'constant'. The values to set the padded values for each + axis. + end_values : scalar, tuple or mapping of hashable to tuple + Used in 'linear_ramp'. The values used for the ending value of the + linear_ramp and that will form the edge of the padded array. + reflect_type : {"even", "odd"}, optional + Used in "reflect", and "symmetric". The "even" style is the + default with an unaltered reflection around the edge value. For + the "odd" style, the extended part of the array is created by + subtracting the reflected values from two times the edge value. + keep_attrs : bool, optional + If True, the variable's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **pad_width_kwargs + One of pad_width or pad_width_kwargs must be provided. + + Returns + ------- + padded : Variable + Variable with the same dimensions and attributes but padded data. + """ + pad_width = either_dict_or_kwargs(pad_width, pad_width_kwargs, "pad") + + # change default behaviour of pad with mode constant + if mode == "constant" and ( + constant_values is None or constant_values is dtypes.NA + ): + dtype, constant_values = dtypes.maybe_promote(self.dtype) + else: + dtype = self.dtype + + # create pad_options_kwargs, numpy requires only relevant kwargs to be nonempty + if isinstance(stat_length, dict): + stat_length = self._pad_options_dim_to_index( + stat_length, fill_with_shape=True + ) + if isinstance(constant_values, dict): + constant_values = self._pad_options_dim_to_index(constant_values) + if isinstance(end_values, dict): + end_values = self._pad_options_dim_to_index(end_values) + + # workaround for bug in Dask's default value of stat_length https://github.com/dask/dask/issues/5303 + if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]: + stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment] + + # change integer values to a tuple of two of those values and change pad_width to index + for k, v in pad_width.items(): + if isinstance(v, numbers.Number): + pad_width[k] = (v, v) + pad_width_by_index = self._pad_options_dim_to_index(pad_width) + + # create pad_options_kwargs, numpy/dask requires only relevant kwargs to be nonempty + pad_option_kwargs: dict[str, Any] = {} + if stat_length is not None: + pad_option_kwargs["stat_length"] = stat_length + if constant_values is not None: + pad_option_kwargs["constant_values"] = constant_values + if end_values is not None: + pad_option_kwargs["end_values"] = end_values + if reflect_type is not None: + pad_option_kwargs["reflect_type"] = reflect_type + + array = np.pad( + duck_array_ops.astype(self.data, dtype, copy=False), + pad_width_by_index, + mode=mode, + **pad_option_kwargs, + ) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + attrs = self._attrs if keep_attrs else None + + return type(self)(self.dims, array, attrs=attrs) + + def _roll_one_dim(self, dim, count): + axis = self.get_axis_num(dim) + + count %= self.shape[axis] + if count != 0: + indices = [slice(-count, None), slice(None, -count)] + else: + indices = [slice(None)] + + arrays = [self[(slice(None),) * axis + (idx,)].data for idx in indices] + + data = duck_array_ops.concatenate(arrays, axis) + + if is_duck_dask_array(data): + # chunked data should come out with the same chunks; this makes + # it feasible to combine shifted and unshifted data + # TODO: remove this once dask.array automatically aligns chunks + data = data.rechunk(self.data.chunks) + + return self._replace(data=data) + + def roll(self, shifts=None, **shifts_kwargs): + """ + Return a new Variable with rolld data. + + Parameters + ---------- + shifts : mapping of hashable to int + Integer offset to roll along each of the given dimensions. + Positive offsets roll to the right; negative offsets roll to the + left. + **shifts_kwargs + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwargs must be provided. + + Returns + ------- + shifted : Variable + Variable with the same dimensions and attributes but rolled data. + """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll") + + result = self + for dim, count in shifts.items(): + result = result._roll_one_dim(dim, count) + return result + + @deprecate_dims + def transpose( + self, + *dim: Hashable | ellipsis, + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> Self: + """Return a new Variable object with transposed dimensions. + + Parameters + ---------- + *dim : Hashable, optional + By default, reverse the dimensions. Otherwise, reorder the + dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Variable: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + transposed : Variable + The returned object has transposed data and dimensions with the + same attributes as the original. + + Notes + ----- + This operation returns a view of this variable's data. It is + lazy for dask-backed Variables but not for numpy-backed Variables. + + See Also + -------- + numpy.transpose + """ + if len(dim) == 0: + dim = self.dims[::-1] + else: + dim = tuple(infix_dims(dim, self.dims, missing_dims)) + + if len(dim) < 2 or dim == self.dims: + # no need to transpose if only one dimension + # or dims are in same order + return self.copy(deep=False) + + axes = self.get_axis_num(dim) + data = as_indexable(self._data).transpose(axes) + return self._replace(dims=dim, data=data) + + @property + def T(self) -> Self: + return self.transpose() + + @deprecate_dims + def set_dims(self, dim, shape=None): + """Return a new variable with given set of dimensions. + This method might be used to attach new dimension(s) to variable. + + When possible, this operation does not copy this variable's data. + + Parameters + ---------- + dim : str or sequence of str or dict + Dimensions to include on the new variable. If a dict, values are + used to provide the sizes of new dimensions; otherwise, new + dimensions are inserted with length 1. + + Returns + ------- + Variable + """ + if isinstance(dim, str): + dim = [dim] + + if shape is None and is_dict_like(dim): + shape = dim.values() + + missing_dims = set(self.dims) - set(dim) + if missing_dims: + raise ValueError( + f"new dimensions {dim!r} must be a superset of " + f"existing dimensions {self.dims!r}" + ) + + self_dims = set(self.dims) + expanded_dims = tuple(d for d in dim if d not in self_dims) + self.dims + + if self.dims == expanded_dims: + # don't use broadcast_to unless necessary so the result remains + # writeable if possible + expanded_data = self.data + elif shape is not None: + dims_map = dict(zip(dim, shape)) + tmp_shape = tuple(dims_map[d] for d in expanded_dims) + expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) + else: + indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) + expanded_data = self.data[indexer] + + expanded_var = Variable( + expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True + ) + return expanded_var.transpose(*dim) + + def _stack_once(self, dim: list[Hashable], new_dim: Hashable): + if not set(dim) <= set(self.dims): + raise ValueError(f"invalid existing dimensions: {dim}") + + if new_dim in self.dims: + raise ValueError( + "cannot create a new dimension with the same " + "name as an existing dimension" + ) + + if len(dim) == 0: + # don't stack + return self.copy(deep=False) + + other_dims = [d for d in self.dims if d not in dim] + dim_order = other_dims + list(dim) + reordered = self.transpose(*dim_order) + + new_shape = reordered.shape[: len(other_dims)] + (-1,) + new_data = duck_array_ops.reshape(reordered.data, new_shape) + new_dims = reordered.dims[: len(other_dims)] + (new_dim,) + + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) + + @partial(deprecate_dims, old_name="dimensions") + def stack(self, dim=None, **dim_kwargs): + """ + Stack any number of existing dim into a single new dimension. + + New dim will be added at the end, and the order of the data + along each new dimension will be in contiguous (C) order. + + Parameters + ---------- + dim : mapping of hashable to tuple of hashable + Mapping of form new_name=(dim1, dim2, ...) describing the + names of new dim, and the existing dim that + they replace. + **dim_kwargs + The keyword arguments form of ``dim``. + One of dim or dim_kwargs must be provided. + + Returns + ------- + stacked : Variable + Variable with the same attributes but stacked data. + + See Also + -------- + Variable.unstack + """ + dim = either_dict_or_kwargs(dim, dim_kwargs, "stack") + result = self + for new_dim, dims in dim.items(): + result = result._stack_once(dims, new_dim) + return result + + def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self: + """ + Unstacks the variable without needing an index. + + Unlike `_unstack_once`, this function requires the existing dimension to + contain the full product of the new dimensions. + """ + new_dim_names = tuple(dim.keys()) + new_dim_sizes = tuple(dim.values()) + + if old_dim not in self.dims: + raise ValueError(f"invalid existing dimension: {old_dim}") + + if set(new_dim_names).intersection(self.dims): + raise ValueError( + "cannot create a new dimension with the same " + "name as an existing dimension" + ) + + if math.prod(new_dim_sizes) != self.sizes[old_dim]: + raise ValueError( + "the product of the new dimension sizes must " + "equal the size of the old dimension" + ) + + other_dims = [d for d in self.dims if d != old_dim] + dim_order = other_dims + [old_dim] + reordered = self.transpose(*dim_order) + + new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes + new_data = duck_array_ops.reshape(reordered.data, new_shape) + new_dims = reordered.dims[: len(other_dims)] + new_dim_names + + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) + + def _unstack_once( + self, + index: pd.MultiIndex, + dim: Hashable, + fill_value=dtypes.NA, + sparse: bool = False, + ) -> Self: + """ + Unstacks this variable given an index to unstack and the name of the + dimension to which the index refers. + """ + + reordered = self.transpose(..., dim) + + new_dim_sizes = [lev.size for lev in index.levels] + new_dim_names = index.names + indexer = index.codes + + # Potentially we could replace `len(other_dims)` with just `-1` + other_dims = [d for d in self.dims if d != dim] + new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) + new_dims = reordered.dims[: len(other_dims)] + new_dim_names + + create_template: Callable + if fill_value is dtypes.NA: + is_missing_values = math.prod(new_shape) > math.prod(self.shape) + if is_missing_values: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + + create_template = partial(np.full_like, fill_value=fill_value) + else: + dtype = self.dtype + fill_value = dtypes.get_fill_value(dtype) + create_template = np.empty_like + else: + dtype = self.dtype + create_template = partial(np.full_like, fill_value=fill_value) + + if sparse: + # unstacking a dense multitindexed array to a sparse array + from sparse import COO + + codes = zip(*index.codes) + if reordered.ndim == 1: + indexes = codes + else: + sizes = itertools.product(*[range(s) for s in reordered.shape[:-1]]) + tuple_indexes = itertools.product(sizes, codes) + indexes = map(lambda x: list(itertools.chain(*x)), tuple_indexes) # type: ignore + + data = COO( + coords=np.array(list(indexes)).T, + data=self.data.astype(dtype).ravel(), + fill_value=fill_value, + shape=new_shape, + sorted=index.is_monotonic_increasing, + ) + + else: + data = create_template(self.data, shape=new_shape, dtype=dtype) + + # Indexer is a list of lists of locations. Each list is the locations + # on the new dimension. This is robust to the data being sparse; in that + # case the destinations will be NaN / zero. + data[(..., *indexer)] = reordered + + return self._replace(dims=new_dims, data=data) + + @partial(deprecate_dims, old_name="dimensions") + def unstack(self, dim=None, **dim_kwargs): + """ + Unstack an existing dimension into multiple new dimensions. + + New dimensions will be added at the end, and the order of the data + along each new dimension will be in contiguous (C) order. + + Note that unlike ``DataArray.unstack`` and ``Dataset.unstack``, this + method requires the existing dimension to contain the full product of + the new dimensions. + + Parameters + ---------- + dim : mapping of hashable to mapping of hashable to int + Mapping of the form old_dim={dim1: size1, ...} describing the + names of existing dimensions, and the new dimensions and sizes + that they map to. + **dim_kwargs + The keyword arguments form of ``dim``. + One of dim or dim_kwargs must be provided. + + Returns + ------- + unstacked : Variable + Variable with the same attributes but unstacked data. + + See Also + -------- + Variable.stack + DataArray.unstack + Dataset.unstack + """ + dim = either_dict_or_kwargs(dim, dim_kwargs, "unstack") + result = self + for old_dim, dims in dim.items(): + result = result._unstack_once_full(dims, old_dim) + return result + + def fillna(self, value): + return ops.fillna(self, value) + + def where(self, cond, other=dtypes.NA): + return ops.where_method(self, cond, other) + + def clip(self, min=None, max=None): + """ + Return an array whose values are limited to ``[min, max]``. + At least one of max or min must be given. + + Refer to `numpy.clip` for full documentation. + + See Also + -------- + numpy.clip : equivalent function + """ + from xarray.core.computation import apply_ufunc + + return apply_ufunc(np.clip, self, min, max, dask="allowed") + + def reduce( # type: ignore[override] + self, + func: Callable[..., Any], + dim: Dims = None, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs, + ) -> Variable: + """Reduce this array by applying `func` along some dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of reducing an + np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default `func` is + applied over all dimensions. + axis : int or Sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + the reduction is calculated over the flattened array (by calling + `func(x)` without an axis argument). + keep_attrs : bool, optional + If True, the variable's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + keepdims : bool, default: False + If True, the dimensions which are reduced are left in the result + as dimensions of size one + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + keep_attrs_ = ( + _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs + ) + + # Noe that the call order for Variable.mean is + # Variable.mean -> NamedArray.mean -> Variable.reduce + # -> NamedArray.reduce + result = super().reduce( + func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs + ) + + # return Variable always to support IndexVariable + return Variable( + result.dims, result._data, attrs=result._attrs if keep_attrs_ else None + ) + + @classmethod + def concat( + cls, + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", + ): + """Concatenate variables along a new or existing dimension. + + Parameters + ---------- + variables : iterable of Variable + Arrays to stack together. Each variable is expected to have + matching dimensions and shape except for along the stacked + dimension. + dim : str or DataArray, optional + Name of the dimension to stack along. This can either be a new + dimension name, in which case it is added along axis=0, or an + existing dimension name, in which case the location of the + dimension is unchanged. Where to insert the new dimension is + determined by the first variable. + positions : None or list of array-like, optional + List of integer arrays which specifies the integer positions to + which to assign each dataset along the concatenated dimension. + If not supplied, objects are concatenated in the provided order. + shortcut : bool, optional + This option is used internally to speed-up groupby operations. + If `shortcut` is True, some checks of internal consistency between + arrays to concatenate are skipped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + Returns + ------- + stacked : Variable + Concatenated Variable formed by stacking all the supplied variables + along the given dimension. + """ + from xarray.core.merge import merge_attrs + + if not isinstance(dim, str): + (dim,) = dim.dims + + # can't do this lazily: we need to loop through variables at least + # twice + variables = list(variables) + first_var = variables[0] + first_var_dims = first_var.dims + + arrays = [v._data for v in variables] + + if dim in first_var_dims: + axis = first_var.get_axis_num(dim) + dims = first_var_dims + data = duck_array_ops.concatenate(arrays, axis=axis) + if positions is not None: + # TODO: deprecate this option -- we don't need it for groupby + # any more. + indices = nputils.inverse_permutation(np.concatenate(positions)) + data = duck_array_ops.take(data, indices, axis=axis) + else: + axis = 0 + dims = (dim,) + first_var_dims + data = duck_array_ops.stack(arrays, axis=axis) + + attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) + encoding = dict(first_var.encoding) + if not shortcut: + for var in variables: + if var.dims != first_var_dims: + raise ValueError( + f"Variable has dimensions {tuple(var.dims)} but first Variable has dimensions {tuple(first_var_dims)}" + ) + + return cls(dims, data, attrs, encoding, fastpath=True) + + def equals(self, other, equiv=duck_array_ops.array_equiv): + """True if two Variables have the same dimensions and values; + otherwise False. + + Variables can still be equal (like pandas objects) if they have NaN + values in the same locations. + + This method is necessary because `v1 == v2` for Variables + does element-wise comparisons (like numpy.ndarrays). + """ + other = getattr(other, "variable", other) + try: + return self.dims == other.dims and ( + self._data is other._data or equiv(self.data, other.data) + ) + except (TypeError, AttributeError): + return False + + def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv): + """True if two Variables have the values after being broadcast against + each other; otherwise False. + + Variables can still be equal (like pandas objects) if they have NaN + values in the same locations. + """ + try: + self, other = broadcast_variables(self, other) + except (ValueError, AttributeError): + return False + return self.equals(other, equiv=equiv) + + def identical(self, other, equiv=duck_array_ops.array_equiv): + """Like equals, but also checks attributes.""" + try: + return utils.dict_equiv(self.attrs, other.attrs) and self.equals( + other, equiv=equiv + ) + except (TypeError, AttributeError): + return False + + def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv): + """True if the intersection of two Variable's non-null data is + equal; otherwise false. + + Variables can thus still be equal if there are locations where either, + or both, contain NaN values. + """ + return self.broadcast_equals(other, equiv=equiv) + + def quantile( + self, + q: ArrayLike, + dim: str | Sequence[Hashable] | None = None, + method: QuantileMethods = "linear", + keep_attrs: bool | None = None, + skipna: bool | None = None, + interpolation: QuantileMethods | None = None, + ) -> Self: + """Compute the qth quantile of the data along the specified dimension. + + Returns the qth quantiles(s) of the array elements. + + Parameters + ---------- + q : float or sequence of float + Quantile to compute, which must be between 0 and 1 + inclusive. + dim : str or sequence of str, optional + Dimension(s) over which to apply quantile. + method : str, default: "linear" + This optional parameter specifies the interpolation method to use when the + desired quantile lies between two data points. The options sorted by their R + type as summarized in the H&F paper [1]_ are: + + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" + 7. "linear" (default) + 8. "median_unbiased" + 9. "normal_unbiased" + + The first three methods are discontiuous. The following discontinuous + variations of the default "linear" (7.) option are also available: + + * "lower" + * "higher" + * "midpoint" + * "nearest" + + See :py:func:`numpy.quantile` or [1]_ for details. The "method" argument + was previously called "interpolation", renamed in accordance with numpy + version 1.22.0. + + keep_attrs : bool, optional + If True, the variable's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + quantiles : Variable + If `q` is a single quantile, then the result + is a scalar. If multiple percentiles are given, first axis of + the result corresponds to the quantile and a quantile dimension + is added to the return array. The other dimensions are the + dimensions that remain after the reduction of the array. + + See Also + -------- + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile + DataArray.quantile + + References + ---------- + .. [1] R. J. Hyndman and Y. Fan, + "Sample quantiles in statistical packages," + The American Statistician, 50(4), pp. 361-365, 1996 + """ + + from xarray.core.computation import apply_ufunc + + if interpolation is not None: + warnings.warn( + "The `interpolation` argument to quantile was renamed to `method`.", + FutureWarning, + ) + + if method != "linear": + raise TypeError("Cannot pass interpolation and method keywords!") + + method = interpolation + + if skipna or (skipna is None and self.dtype.kind in "cfO"): + _quantile_func = nputils.nanquantile + else: + _quantile_func = np.quantile + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + scalar = utils.is_scalar(q) + q = np.atleast_1d(np.asarray(q, dtype=np.float64)) + + if dim is None: + dim = self.dims + + if utils.is_scalar(dim): + dim = [dim] + + def _wrapper(npa, **kwargs): + # move quantile axis to end. required for apply_ufunc + return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1) + + axis = np.arange(-1, -1 * len(dim) - 1, -1) + + kwargs = {"q": q, "axis": axis, "method": method} + + result = apply_ufunc( + _wrapper, + self, + input_core_dims=[dim], + exclude_dims=set(dim), + output_core_dims=[["quantile"]], + output_dtypes=[np.float64], + dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}), + dask="parallelized", + kwargs=kwargs, + ) + + # for backward compatibility + result = result.transpose("quantile", ...) + if scalar: + result = result.squeeze("quantile") + if keep_attrs: + result.attrs = self._attrs + return result + + def rank(self, dim, pct=False): + """Ranks the data. + + Equal values are assigned a rank that is the average of the ranks that + would have been otherwise assigned to all of the values within that + set. Ranks begin at 1, not 0. If `pct`, computes percentage ranks. + + NaNs in the input array are returned as NaNs. + + The `bottleneck` library is required. + + Parameters + ---------- + dim : str + Dimension over which to compute rank. + pct : bool, optional + If True, compute percentage ranks, otherwise compute integer ranks. + + Returns + ------- + ranked : Variable + + See Also + -------- + Dataset.rank, DataArray.rank + """ + # This could / should arguably be implemented at the DataArray & Dataset level + if not OPTIONS["use_bottleneck"]: + raise RuntimeError( + "rank requires bottleneck to be enabled." + " Call `xr.set_options(use_bottleneck=True)` to enable it." + ) + + import bottleneck as bn + + func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata + ranked = xr.apply_ufunc( + func, + self, + input_core_dims=[[dim]], + output_core_dims=[[dim]], + dask="parallelized", + kwargs=dict(axis=-1), + ).transpose(*self.dims) + + if pct: + count = self.notnull().sum(dim) + ranked /= count + return ranked + + def rolling_window( + self, dim, window, window_dim, center=False, fill_value=dtypes.NA + ): + """ + Make a rolling_window along dim and add a new_dim to the last place. + + Parameters + ---------- + dim : str + Dimension over which to compute rolling_window. + For nd-rolling, should be list of dimensions. + window : int + Window size of the rolling + For nd-rolling, should be list of integers. + window_dim : str + New name of the window dimension. + For nd-rolling, should be list of strings. + center : bool, default: False + If True, pad fill_value for both ends. Otherwise, pad in the head + of the axis. + fill_value + value to be filled. + + Returns + ------- + Variable that is a view of the original array with a added dimension of + size w. + The return dim: self.dims + (window_dim, ) + The return shape: self.shape + (window, ) + + Examples + -------- + >>> v = Variable(("a", "b"), np.arange(8).reshape((2, 4))) + >>> v.rolling_window("b", 3, "window_dim") + Size: 192B + array([[[nan, nan, 0.], + [nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.]], + + [[nan, nan, 4.], + [nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.]]]) + + >>> v.rolling_window("b", 3, "window_dim", center=True) + Size: 192B + array([[[nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.], + [ 2., 3., nan]], + + [[nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.], + [ 6., 7., nan]]]) + """ + if fill_value is dtypes.NA: # np.nan is passed + dtype, fill_value = dtypes.maybe_promote(self.dtype) + var = duck_array_ops.astype(self, dtype, copy=False) + else: + dtype = self.dtype + var = self + + if utils.is_scalar(dim): + for name, arg in zip( + ["window", "window_dim", "center"], [window, window_dim, center] + ): + if not utils.is_scalar(arg): + raise ValueError( + f"Expected {name}={arg!r} to be a scalar like 'dim'." + ) + dim = (dim,) + + # dim is now a list + nroll = len(dim) + if utils.is_scalar(window): + window = [window] * nroll + if utils.is_scalar(window_dim): + window_dim = [window_dim] * nroll + if utils.is_scalar(center): + center = [center] * nroll + if ( + len(dim) != len(window) + or len(dim) != len(window_dim) + or len(dim) != len(center) + ): + raise ValueError( + "'dim', 'window', 'window_dim', and 'center' must be the same length. " + f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r}," + f" and center={center!r}." + ) + + pads = {} + for d, win, cent in zip(dim, window, center): + if cent: + start = win // 2 # 10 -> 5, 9 -> 4 + end = win - 1 - start + pads[d] = (start, end) + else: + pads[d] = (win - 1, 0) + + padded = var.pad(pads, mode="constant", constant_values=fill_value) + axis = self.get_axis_num(dim) + new_dims = self.dims + tuple(window_dim) + return Variable( + new_dims, + duck_array_ops.sliding_window_view( + padded.data, window_shape=window, axis=axis + ), + ) + + def coarsen( + self, windows, func, boundary="exact", side="left", keep_attrs=None, **kwargs + ): + """ + Apply reduction function. + """ + windows = {k: v for k, v in windows.items() if k in self.dims} + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + if keep_attrs: + _attrs = self.attrs + else: + _attrs = None + + if not windows: + return self._replace(attrs=_attrs) + + reshaped, axes = self.coarsen_reshape(windows, boundary, side) + if isinstance(func, str): + name = func + func = getattr(duck_array_ops, name, None) + if func is None: + raise NameError(f"{name} is not a valid method.") + + return self._replace(data=func(reshaped, axis=axes, **kwargs), attrs=_attrs) + + def coarsen_reshape(self, windows, boundary, side): + """ + Construct a reshaped-array for coarsen + """ + if not is_dict_like(boundary): + boundary = {d: boundary for d in windows.keys()} + + if not is_dict_like(side): + side = {d: side for d in windows.keys()} + + # remove unrelated dimensions + boundary = {k: v for k, v in boundary.items() if k in windows} + side = {k: v for k, v in side.items() if k in windows} + + for d, window in windows.items(): + if window <= 0: + raise ValueError( + f"window must be > 0. Given {window} for dimension {d}" + ) + + variable = self + for d, window in windows.items(): + # trim or pad the object + size = variable.shape[self._get_axis_num(d)] + n = int(size / window) + if boundary[d] == "exact": + if n * window != size: + raise ValueError( + f"Could not coarsen a dimension of size {size} with " + f"window {window} and boundary='exact'. Try a different 'boundary' option." + ) + elif boundary[d] == "trim": + if side[d] == "left": + variable = variable.isel({d: slice(0, window * n)}) + else: + excess = size - window * n + variable = variable.isel({d: slice(excess, None)}) + elif boundary[d] == "pad": # pad + pad = window * n - size + if pad < 0: + pad += window + if side[d] == "left": + pad_width = {d: (0, pad)} + else: + pad_width = {d: (pad, 0)} + variable = variable.pad(pad_width, mode="constant") + else: + raise TypeError( + f"{boundary[d]} is invalid for boundary. Valid option is 'exact', " + "'trim' and 'pad'" + ) + + shape = [] + axes = [] + axis_count = 0 + for i, d in enumerate(variable.dims): + if d in windows: + size = variable.shape[i] + shape.append(int(size / windows[d])) + shape.append(windows[d]) + axis_count += 1 + axes.append(i + axis_count) + else: + shape.append(variable.shape[i]) + + return duck_array_ops.reshape(variable.data, shape), tuple(axes) + + def isnull(self, keep_attrs: bool | None = None): + """Test each value in the array for whether it is a missing value. + + Returns + ------- + isnull : Variable + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.isnull + + Examples + -------- + >>> var = xr.Variable("x", [1, np.nan, 3]) + >>> var + Size: 24B + array([ 1., nan, 3.]) + >>> var.isnull() + Size: 3B + array([False, True, False]) + """ + from xarray.core.computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.isnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + + def notnull(self, keep_attrs: bool | None = None): + """Test each value in the array for whether it is not a missing value. + + Returns + ------- + notnull : Variable + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.notnull + + Examples + -------- + >>> var = xr.Variable("x", [1, np.nan, 3]) + >>> var + Size: 24B + array([ 1., nan, 3.]) + >>> var.notnull() + Size: 3B + array([ True, False, True]) + """ + from xarray.core.computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.notnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + + @property + def imag(self) -> Variable: + """ + The imaginary part of the variable. + + See Also + -------- + numpy.ndarray.imag + """ + return self._new(data=self.data.imag) + + @property + def real(self) -> Variable: + """ + The real part of the variable. + + See Also + -------- + numpy.ndarray.real + """ + return self._new(data=self.data.real) + + def __array_wrap__(self, obj, context=None): + return Variable(self.dims, obj) + + def _unary_op(self, f, *args, **kwargs): + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + with np.errstate(all="ignore"): + result = self.__array_wrap__(f(self.data, *args, **kwargs)) + if keep_attrs: + result.attrs = self.attrs + return result + + def _binary_op(self, other, f, reflexive=False): + if isinstance(other, (xr.DataArray, xr.Dataset)): + return NotImplemented + if reflexive and issubclass(type(self), type(other)): + other_data, self_data, dims = _broadcast_compat_data(other, self) + else: + self_data, other_data, dims = _broadcast_compat_data(self, other) + keep_attrs = _get_keep_attrs(default=False) + attrs = self._attrs if keep_attrs else None + with np.errstate(all="ignore"): + new_data = ( + f(self_data, other_data) if not reflexive else f(other_data, self_data) + ) + result = Variable(dims, new_data, attrs=attrs) + return result + + def _inplace_binary_op(self, other, f): + if isinstance(other, xr.Dataset): + raise TypeError("cannot add a Dataset to a Variable in-place") + self_data, other_data, dims = _broadcast_compat_data(self, other) + if dims != self.dims: + raise ValueError("dimensions cannot change for in-place operations") + with np.errstate(all="ignore"): + self.values = f(self_data, other_data) + return self + + def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): + """A (private) method to convert datetime array to numeric dtype + See duck_array_ops.datetime_to_numeric + """ + numeric_array = duck_array_ops.datetime_to_numeric( + self.data, offset, datetime_unit, dtype + ) + return type(self)(self.dims, numeric_array, self._attrs) + + def _unravel_argminmax( + self, + argminmax: str, + dim: Dims, + axis: int | None, + keep_attrs: bool | None, + skipna: bool | None, + ) -> Variable | dict[Hashable, Variable]: + """Apply argmin or argmax over one or more dimensions, returning the result as a + dict of DataArray that can be passed directly to isel. + """ + if dim is None and axis is None: + warnings.warn( + "Behaviour of argmin/argmax with neither dim nor axis argument will " + "change to return a dict of indices of each dimension. To get a " + "single, flat index, please use np.argmin(da.data) or " + "np.argmax(da.data) instead of da.argmin() or da.argmax().", + DeprecationWarning, + stacklevel=3, + ) + + argminmax_func = getattr(duck_array_ops, argminmax) + + if dim is ...: + # In future, should do this also when (dim is None and axis is None) + dim = self.dims + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return self.reduce( + argminmax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + ) + + # Get a name for the new dimension that does not conflict with any existing + # dimension + newdimname = "_unravel_argminmax_dim_0" + count = 1 + while newdimname in self.dims: + newdimname = f"_unravel_argminmax_dim_{count}" + count += 1 + + stacked = self.stack({newdimname: dim}) + + result_dims = stacked.dims[:-1] + reduce_shape = tuple(self.sizes[d] for d in dim) + + result_flat_indices = stacked.reduce(argminmax_func, axis=-1, skipna=skipna) + + result_unravelled_indices = duck_array_ops.unravel_index( + result_flat_indices.data, reduce_shape + ) + + result = { + d: Variable(dims=result_dims, data=i) + for d, i in zip(dim, result_unravelled_indices) + } + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + for v in result.values(): + v.attrs = self.attrs + + return result + + def argmin( + self, + dim: Dims = None, + axis: int | None = None, + keep_attrs: bool | None = None, + skipna: bool | None = None, + ) -> Variable | dict[Hashable, Variable]: + """Index or indices of the minimum of the Variable over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a Variable with dtype int. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : "...", str, Iterable of Hashable or None, optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Variable or dict of Variable + + See Also + -------- + DataArray.argmin, DataArray.idxmin + """ + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) + + def argmax( + self, + dim: Dims = None, + axis: int | None = None, + keep_attrs: bool | None = None, + skipna: bool | None = None, + ) -> Variable | dict[Hashable, Variable]: + """Index or indices of the maximum of the Variable over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a Variable with dtype int. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : "...", str, Iterable of Hashable or None, optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Variable or dict of Variable + + See Also + -------- + DataArray.argmax, DataArray.idxmax + """ + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + + def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: + """ + Use sparse-array as backend. + """ + from xarray.namedarray._typing import _default as _default_named + + if sparse_format is _default: + sparse_format = _default_named + + if fill_value is _default: + fill_value = _default_named + + out = super()._as_sparse(sparse_format, fill_value) + return cast("Variable", out) + + def _to_dense(self) -> Variable: + """ + Change backend from sparse to np.array. + """ + out = super()._to_dense() + return cast("Variable", out) + + def chunk( # type: ignore[override] + self, + chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + name: str | None = None, + lock: bool | None = None, + inline_array: bool | None = None, + chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None, + from_array_kwargs: Any = None, + **chunks_kwargs: Any, + ) -> Self: + """Coerce this array's data into a dask array with the given chunks. + + If this variable is a non-dask array, it will be converted to dask + array. If it's a dask array, it will be rechunked to the given chunk + sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Parameters + ---------- + chunks : int, tuple or dict, optional + Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or + ``{'x': 5, 'y': 5}``. + name : str, optional + Used to generate the name for this array in the internal dask + graph. Does not need not be unique. + lock : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + inline_array : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided. + + Returns + ------- + chunked : xarray.Variable + + See Also + -------- + Variable.chunks + Variable.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + + if is_extension_array_dtype(self): + raise ValueError( + f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." + ) + + if from_array_kwargs is None: + from_array_kwargs = {} + + # TODO deprecate passing these dask-specific arguments explicitly. In future just pass everything via from_array_kwargs + _from_array_kwargs = consolidate_dask_from_array_kwargs( + from_array_kwargs, + name=name, + lock=lock, + inline_array=inline_array, + ) + + return super().chunk( + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=_from_array_kwargs, + **chunks_kwargs, + ) + + +class IndexVariable(Variable): + """Wrapper for accommodating a pandas.Index in an xarray.Variable. + + IndexVariable preserve loaded values in the form of a pandas.Index instead + of a NumPy array. Hence, their values are immutable and must always be one- + dimensional. + + They also have a name property, which is the name of their sole dimension + unless another name is given. + """ + + __slots__ = () + + # TODO: PandasIndexingAdapter doesn't match the array api: + _data: PandasIndexingAdapter # type: ignore[assignment] + + def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + super().__init__(dims, data, attrs, encoding, fastpath) + if self.ndim != 1: + raise ValueError(f"{type(self).__name__} objects must be 1-dimensional") + + # Unlike in Variable, always eagerly load values into memory + if not isinstance(self._data, PandasIndexingAdapter): + self._data = PandasIndexingAdapter(self._data) + + def __dask_tokenize__(self) -> object: + from dask.base import normalize_token + + # Don't waste time converting pd.Index to np.ndarray + return normalize_token( + (type(self), self._dims, self._data.array, self._attrs or None) + ) + + def load(self): + # data is already loaded into memory for IndexVariable + return self + + # https://github.com/python/mypy/issues/1465 + @Variable.data.setter # type: ignore[attr-defined] + def data(self, data): + raise ValueError( + f"Cannot assign to the .data attribute of dimension coordinate a.k.a IndexVariable {self.name!r}. " + f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." + ) + + @Variable.values.setter # type: ignore[attr-defined] + def values(self, values): + raise ValueError( + f"Cannot assign to the .values attribute of dimension coordinate a.k.a IndexVariable {self.name!r}. " + f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." + ) + + def chunk( + self, + chunks={}, + name=None, + lock=False, + inline_array=False, + chunked_array_type=None, + from_array_kwargs=None, + ): + # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk() + return self.copy(deep=False) + + def _as_sparse(self, sparse_format=_default, fill_value=_default): + # Dummy + return self.copy(deep=False) + + def _to_dense(self): + # Dummy + return self.copy(deep=False) + + def _finalize_indexing_result(self, dims, data): + if getattr(data, "ndim", 0) != 1: + # returns Variable rather than IndexVariable if multi-dimensional + return Variable(dims, data, self._attrs, self._encoding) + else: + return self._replace(dims=dims, data=data) + + def __setitem__(self, key, value): + raise TypeError(f"{type(self).__name__} values cannot be modified") + + @classmethod + def concat( + cls, + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", + ): + """Specialized version of Variable.concat for IndexVariable objects. + + This exists because we want to avoid converting Index objects to NumPy + arrays, if possible. + """ + from xarray.core.merge import merge_attrs + + if not isinstance(dim, str): + (dim,) = dim.dims + + variables = list(variables) + first_var = variables[0] + + if any(not isinstance(v, cls) for v in variables): + raise TypeError( + "IndexVariable.concat requires that all input " + "variables be IndexVariable objects" + ) + + indexes = [v._data.array for v in variables] + + if not indexes: + data = [] + else: + data = indexes[0].append(indexes[1:]) + + if positions is not None: + indices = nputils.inverse_permutation(np.concatenate(positions)) + data = data.take(indices) + + # keep as str if possible as pandas.Index uses object (converts to numpy array) + data = maybe_coerce_to_str(data, variables) + + attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) + if not shortcut: + for var in variables: + if var.dims != first_var.dims: + raise ValueError("inconsistent dimensions") + + return cls(first_var.dims, data, attrs) + + def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None): + """Returns a copy of this object. + + `deep` is ignored since data is stored in the form of + pandas.Index, which is already immutable. Dimensions, attributes + and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: True + Deep is ignored when data is given. Whether the data array is + loaded into memory and copied onto the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. + """ + if data is None: + ndata = self._data + + if deep: + ndata = copy.deepcopy(ndata, None) + + else: + ndata = as_compatible_data(data) + if self.shape != ndata.shape: # type: ignore[attr-defined] + raise ValueError( + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] + ) + + attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) + encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding) + + return self._replace(data=ndata, attrs=attrs, encoding=encoding) + + def equals(self, other, equiv=None): + # if equiv is specified, super up + if equiv is not None: + return super().equals(other, equiv) + + # otherwise use the native index equals, rather than looking at _data + other = getattr(other, "variable", other) + try: + return self.dims == other.dims and self._data_equals(other) + except (TypeError, AttributeError): + return False + + def _data_equals(self, other): + return self._to_index().equals(other._to_index()) + + def to_index_variable(self) -> IndexVariable: + """Return this variable as an xarray.IndexVariable""" + return self.copy(deep=False) + + to_coord = utils.alias(to_index_variable, "to_coord") + + def _to_index(self) -> pd.Index: + # n.b. creating a new pandas.Index from an old pandas.Index is + # basically free as pandas.Index objects are immutable. + # n.b.2. this method returns the multi-index instance for + # a pandas multi-index level variable. + assert self.ndim == 1 + index = self._data.array + if isinstance(index, pd.MultiIndex): + # set default names for multi-index unnamed levels so that + # we can safely rename dimension / coordinate later + valid_level_names = [ + name or f"{self.dims[0]}_level_{i}" + for i, name in enumerate(index.names) + ] + index = index.set_names(valid_level_names) + else: + index = index.set_names(self.name) + return index + + def to_index(self) -> pd.Index: + """Convert this variable to a pandas.Index""" + index = self._to_index() + level = getattr(self._data, "level", None) + if level is not None: + # return multi-index level converted to a single index + return index.get_level_values(level) + else: + return index + + @property + def level_names(self) -> list[str] | None: + """Return MultiIndex level names or None if this IndexVariable has no + MultiIndex. + """ + index = self.to_index() + if isinstance(index, pd.MultiIndex): + return index.names + else: + return None + + def get_level_variable(self, level): + """Return a new IndexVariable from a given MultiIndex level.""" + if self.level_names is None: + raise ValueError(f"IndexVariable {self.name!r} has no MultiIndex") + index = self.to_index() + return type(self)(self.dims, index.get_level_values(level)) + + @property + def name(self) -> Hashable: + return self.dims[0] + + @name.setter + def name(self, value) -> NoReturn: + raise AttributeError("cannot modify name of IndexVariable in-place") + + def _inplace_binary_op(self, other, f): + raise TypeError( + "Values of an IndexVariable are immutable and can not be modified inplace" + ) + + +def _unified_dims(variables): + # validate dimensions + all_dims = {} + for var in variables: + var_dims = var.dims + _raise_if_any_duplicate_dimensions(var_dims, err_context="Broadcasting") + + for d, s in zip(var_dims, var.shape): + if d not in all_dims: + all_dims[d] = s + elif all_dims[d] != s: + raise ValueError( + "operands cannot be broadcast together " + f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}" + ) + return all_dims + + +def _broadcast_compat_variables(*variables): + """Create broadcast compatible variables, with the same dimensions. + + Unlike the result of broadcast_variables(), some variables may have + dimensions of size 1 instead of the size of the broadcast dimension. + """ + dims = tuple(_unified_dims(variables)) + return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables) + + +def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]: + """Given any number of variables, return variables with matching dimensions + and broadcast data. + + The data on the returned variables will be a view of the data on the + corresponding original arrays, but dimensions will be reordered and + inserted so that both broadcast arrays have the same dimensions. The new + dimensions are sorted in order of appearance in the first variable's + dimensions followed by the second variable's dimensions. + """ + dims_map = _unified_dims(variables) + dims_tuple = tuple(dims_map) + return tuple( + var.set_dims(dims_map) if var.dims != dims_tuple else var for var in variables + ) + + +def _broadcast_compat_data(self, other): + if not OPTIONS["arithmetic_broadcast"]: + if (isinstance(other, Variable) and self.dims != other.dims) or ( + is_duck_array(other) and self.ndim != other.ndim + ): + raise ValueError( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ) + + if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]): + # `other` satisfies the necessary Variable API for broadcast_variables + new_self, new_other = _broadcast_compat_variables(self, other) + self_data = new_self.data + other_data = new_other.data + dims = new_self.dims + else: + # rely on numpy broadcasting rules + self_data = self.data + other_data = other + dims = self.dims + return self_data, other_data, dims + + +def concat( + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", +): + """Concatenate variables along a new or existing dimension. + + Parameters + ---------- + variables : iterable of Variable + Arrays to stack together. Each variable is expected to have + matching dimensions and shape except for along the stacked + dimension. + dim : str or DataArray, optional + Name of the dimension to stack along. This can either be a new + dimension name, in which case it is added along axis=0, or an + existing dimension name, in which case the location of the + dimension is unchanged. Where to insert the new dimension is + determined by the first variable. + positions : None or list of array-like, optional + List of integer arrays which specifies the integer positions to which + to assign each dataset along the concatenated dimension. If not + supplied, objects are concatenated in the provided order. + shortcut : bool, optional + This option is used internally to speed-up groupby operations. + If `shortcut` is True, some checks of internal consistency between + arrays to concatenate are skipped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + Returns + ------- + stacked : Variable + Concatenated Variable formed by stacking all the supplied variables + along the given dimension. + """ + variables = list(variables) + if all(isinstance(v, IndexVariable) for v in variables): + return IndexVariable.concat(variables, dim, positions, shortcut, combine_attrs) + else: + return Variable.concat(variables, dim, positions, shortcut, combine_attrs) + + +def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: + """Calculate the dimensions corresponding to a set of variables. + + Returns dictionary mapping from dimension names to sizes. Raises ValueError + if any of the dimension sizes conflict. + """ + dims: dict[Hashable, int] = {} + last_used = {} + scalar_vars = {k for k, v in variables.items() if not v.dims} + for k, var in variables.items(): + for dim, size in zip(var.dims, var.shape): + if dim in scalar_vars: + raise ValueError( + f"dimension {dim!r} already exists as a scalar variable" + ) + if dim not in dims: + dims[dim] = size + last_used[dim] = k + elif dims[dim] != size: + raise ValueError( + f"conflicting sizes for dimension {dim!r}: " + f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" + ) + return dims diff --git a/test/fixtures/whole_applications/xarray/xarray/core/weighted.py b/test/fixtures/whole_applications/xarray/xarray/core/weighted.py new file mode 100644 index 0000000..8cb90ac --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/core/weighted.py @@ -0,0 +1,589 @@ +from __future__ import annotations + +from collections.abc import Hashable, Iterable, Sequence +from typing import TYPE_CHECKING, Generic, Literal, cast + +import numpy as np +from numpy.typing import ArrayLike + +from xarray.core import duck_array_ops, utils +from xarray.core.alignment import align, broadcast +from xarray.core.computation import apply_ufunc, dot +from xarray.core.types import Dims, T_DataArray, T_Xarray +from xarray.namedarray.utils import is_duck_dask_array +from xarray.util.deprecation_helpers import _deprecate_positional_args + +# Weighted quantile methods are a subset of the numpy supported quantile methods. +QUANTILE_METHODS = Literal[ + "linear", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "median_unbiased", + "normal_unbiased", +] + +_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ + Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). + + Parameters + ---------- + dim : Hashable or Iterable of Hashable, optional + Dimension(s) over which to apply the weighted ``{fcn}``. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + + Returns + ------- + reduced : {cls} + New {cls} object with weighted ``{fcn}`` applied to its data and + the indicated dimension(s) removed. + + Notes + ----- + Returns {on_zero} if the ``weights`` sum to 0.0 along the reduced + dimension(s). + """ + +_SUM_OF_WEIGHTS_DOCSTRING = """ + Calculate the sum of weights, accounting for missing values in the data. + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to sum the weights. + keep_attrs : bool, optional + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + + Returns + ------- + reduced : {cls} + New {cls} object with the sum of the weights over the given dimension. + """ + +_WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """ + Apply a weighted ``quantile`` to this {cls}'s data along some dimension(s). + + Weights are interpreted as *sampling weights* (or probability weights) and + describe how a sample is scaled to the whole population [1]_. There are + other possible interpretations for weights, *precision weights* describing the + precision of observations, or *frequency weights* counting the number of identical + observations, however, they are not implemented here. + + For compatibility with NumPy's non-weighted ``quantile`` (which is used by + ``DataArray.quantile`` and ``Dataset.quantile``), the only interpolation + method supported by this weighted version corresponds to the default "linear" + option of ``numpy.quantile``. This is "Type 7" option, described in Hyndman + and Fan (1996) [2]_. The implementation is largely inspired by a blog post + from A. Akinshin's (2023) [3]_. + + Parameters + ---------- + q : float or sequence of float + Quantile to compute, which must be between 0 and 1 inclusive. + dim : str or sequence of str, optional + Dimension(s) over which to apply the weighted ``quantile``. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + + Returns + ------- + quantiles : {cls} + New {cls} object with weighted ``quantile`` applied to its data and + the indicated dimension(s) removed. + + See Also + -------- + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile + + Notes + ----- + Returns NaN if the ``weights`` sum to 0.0 along the reduced + dimension(s). + + References + ---------- + .. [1] https://notstatschat.rbind.io/2020/08/04/weights-in-statistics/ + .. [2] Hyndman, R. J. & Fan, Y. (1996). Sample Quantiles in Statistical Packages. + The American Statistician, 50(4), 361–365. https://doi.org/10.2307/2684934 + .. [3] Akinshin, A. (2023) "Weighted quantile estimators" arXiv:2304.07265 [stat.ME] + https://arxiv.org/abs/2304.07265 + """ + + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + + +class Weighted(Generic[T_Xarray]): + """An object that implements weighted operations. + + You should create a Weighted object by using the ``DataArray.weighted`` or + ``Dataset.weighted`` methods. + + See Also + -------- + Dataset.weighted + DataArray.weighted + """ + + __slots__ = ("obj", "weights") + + def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None: + """ + Create a Weighted object + + Parameters + ---------- + obj : DataArray or Dataset + Object over which the weighted reduction operation is applied. + weights : DataArray + An array of weights associated with the values in the obj. + Each value in the obj contributes to the reduction operation + according to its associated weight. + + Notes + ----- + ``weights`` must be a ``DataArray`` and cannot contain missing values. + Missing values can be replaced by ``weights.fillna(0)``. + """ + + from xarray.core.dataarray import DataArray + + if not isinstance(weights, DataArray): + raise ValueError("`weights` must be a DataArray") + + def _weight_check(w): + # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670 + if duck_array_ops.isnull(w).any(): + raise ValueError( + "`weights` cannot contain missing values. " + "Missing values can be replaced by `weights.fillna(0)`." + ) + return w + + if is_duck_dask_array(weights.data): + # assign to copy - else the check is not triggered + weights = weights.copy( + data=weights.data.map_blocks(_weight_check, dtype=weights.dtype), + deep=False, + ) + + else: + _weight_check(weights.data) + + self.obj: T_Xarray = obj + self.weights: T_DataArray = weights + + def _check_dim(self, dim: Dims): + """raise an error if any dimension is missing""" + + dims: list[Hashable] + if isinstance(dim, str) or not isinstance(dim, Iterable): + dims = [dim] if dim else [] + else: + dims = list(dim) + all_dims = set(self.obj.dims).union(set(self.weights.dims)) + missing_dims = set(dims) - all_dims + if missing_dims: + raise ValueError( + f"Dimensions {tuple(missing_dims)} not found in {self.__class__.__name__} dimensions {tuple(all_dims)}" + ) + + @staticmethod + def _reduce( + da: T_DataArray, + weights: T_DataArray, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """reduce using dot; equivalent to (da * weights).sum(dim, skipna) + + for internal use only + """ + + # need to infer dims as we use `dot` + if dim is None: + dim = ... + + # need to mask invalid values in da, as `dot` does not implement skipna + if skipna or (skipna is None and da.dtype.kind in "cfO"): + da = da.fillna(0.0) + + # `dot` does not broadcast arrays, so this avoids creating a large + # DataArray (if `weights` has additional dimensions) + return dot(da, weights, dim=dim) + + def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray: + """Calculate the sum of weights, accounting for missing values""" + + # we need to mask data values that are nan; else the weights are wrong + mask = da.notnull() + + # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True + # (and not 2); GH4074 + if self.weights.dtype == bool: + sum_of_weights = self._reduce( + mask, + duck_array_ops.astype(self.weights, dtype=int), + dim=dim, + skipna=False, + ) + else: + sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) + + # 0-weights are not valid + valid_weights = sum_of_weights != 0.0 + + return sum_of_weights.where(valid_weights) + + def _sum_of_squares( + self, + da: T_DataArray, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s).""" + + demeaned = da - da.weighted(self.weights).mean(dim=dim) + + return self._reduce((demeaned**2), self.weights, dim=dim, skipna=skipna) + + def _weighted_sum( + self, + da: T_DataArray, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """Reduce a DataArray by a weighted ``sum`` along some dimension(s).""" + + return self._reduce(da, self.weights, dim=dim, skipna=skipna) + + def _weighted_mean( + self, + da: T_DataArray, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" + + weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) + + sum_of_weights = self._sum_of_weights(da, dim=dim) + + return weighted_sum / sum_of_weights + + def _weighted_var( + self, + da: T_DataArray, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """Reduce a DataArray by a weighted ``var`` along some dimension(s).""" + + sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna) + + sum_of_weights = self._sum_of_weights(da, dim=dim) + + return sum_of_squares / sum_of_weights + + def _weighted_std( + self, + da: T_DataArray, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """Reduce a DataArray by a weighted ``std`` along some dimension(s).""" + + return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna))) + + def _weighted_quantile( + self, + da: T_DataArray, + q: ArrayLike, + dim: Dims = None, + skipna: bool | None = None, + ) -> T_DataArray: + """Apply a weighted ``quantile`` to a DataArray along some dimension(s).""" + + def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: + """Return the interpolation parameter.""" + # Note that options are not yet exposed in the public API. + h: np.ndarray + if method == "linear": + h = (n - 1) * q + 1 + elif method == "interpolated_inverted_cdf": + h = n * q + elif method == "hazen": + h = n * q + 0.5 + elif method == "weibull": + h = (n + 1) * q + elif method == "median_unbiased": + h = (n + 1 / 3) * q + 1 / 3 + elif method == "normal_unbiased": + h = (n + 1 / 4) * q + 3 / 8 + else: + raise ValueError(f"Invalid method: {method}.") + return h.clip(1, n) + + def _weighted_quantile_1d( + data: np.ndarray, + weights: np.ndarray, + q: np.ndarray, + skipna: bool, + method: QUANTILE_METHODS = "linear", + ) -> np.ndarray: + # This algorithm has been adapted from: + # https://aakinshin.net/posts/weighted-quantiles/#reference-implementation + is_nan = np.isnan(data) + if skipna: + # Remove nans from data and weights + not_nan = ~is_nan + data = data[not_nan] + weights = weights[not_nan] + elif is_nan.any(): + # Return nan if data contains any nan + return np.full(q.size, np.nan) + + # Filter out data (and weights) associated with zero weights, which also flattens them + nonzero_weights = weights != 0 + data = data[nonzero_weights] + weights = weights[nonzero_weights] + n = data.size + + if n == 0: + # Possibly empty after nan or zero weight filtering above + return np.full(q.size, np.nan) + + # Kish's effective sample size + nw = weights.sum() ** 2 / (weights**2).sum() + + # Sort data and weights + sorter = np.argsort(data) + data = data[sorter] + weights = weights[sorter] + + # Normalize and sum the weights + weights = weights / weights.sum() + weights_cum = np.append(0, weights.cumsum()) + + # Vectorize the computation by transposing q with respect to weights + q = np.atleast_2d(q).T + + # Get the interpolation parameter for each q + h = _get_h(nw, q, method) + + # Find the samples contributing to the quantile computation (at *positions* between (h-1)/nw and h/nw) + u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum)) + + # Compute their relative weight + v = u * nw - h + 1 + w = np.diff(v) + + # Apply the weights + return (data * w).sum(axis=1) + + if skipna is None and da.dtype.kind in "cfO": + skipna = True + + q = np.atleast_1d(np.asarray(q, dtype=np.float64)) + + if q.ndim > 1: + raise ValueError("q must be a scalar or 1d") + + if np.any((q < 0) | (q > 1)): + raise ValueError("q values must be between 0 and 1") + + if dim is None: + dim = da.dims + + if utils.is_scalar(dim): + dim = [dim] + + # To satisfy mypy + dim = cast(Sequence, dim) + + # need to align *and* broadcast + # - `_weighted_quantile_1d` requires arrays with the same shape + # - broadcast does an outer join, which can introduce NaN to weights + # - therefore we first need to do align(..., join="inner") + + # TODO: use broadcast(..., join="inner") once available + # see https://github.com/pydata/xarray/issues/6304 + + da, weights = align(da, self.weights, join="inner") + da, weights = broadcast(da, weights) + + result = apply_ufunc( + _weighted_quantile_1d, + da, + weights, + input_core_dims=[dim, dim], + output_core_dims=[["quantile"]], + output_dtypes=[np.float64], + dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}), + dask="parallelized", + vectorize=True, + kwargs={"q": q, "skipna": skipna}, + ) + + result = result.transpose("quantile", ...) + result = result.assign_coords(quantile=q).squeeze() + + return result + + def _implementation(self, func, dim, **kwargs): + raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") + + @_deprecate_positional_args("v2023.10.0") + def sum_of_weights( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + ) -> T_Xarray: + return self._implementation( + self._sum_of_weights, dim=dim, keep_attrs=keep_attrs + ) + + @_deprecate_positional_args("v2023.10.0") + def sum_of_squares( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + ) -> T_Xarray: + return self._implementation( + self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + @_deprecate_positional_args("v2023.10.0") + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + ) -> T_Xarray: + return self._implementation( + self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + @_deprecate_positional_args("v2023.10.0") + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + ) -> T_Xarray: + return self._implementation( + self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + @_deprecate_positional_args("v2023.10.0") + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + ) -> T_Xarray: + return self._implementation( + self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + @_deprecate_positional_args("v2023.10.0") + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + ) -> T_Xarray: + return self._implementation( + self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + def quantile( + self, + q: ArrayLike, + *, + dim: Dims = None, + keep_attrs: bool | None = None, + skipna: bool = True, + ) -> T_Xarray: + return self._implementation( + self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs + ) + + def __repr__(self) -> str: + """provide a nice str repr of our Weighted object""" + + klass = self.__class__.__name__ + weight_dims = ", ".join(map(str, self.weights.dims)) + return f"{klass} with weights along dimensions: {weight_dims}" + + +class DataArrayWeighted(Weighted["DataArray"]): + def _implementation(self, func, dim, **kwargs) -> DataArray: + self._check_dim(dim) + + dataset = self.obj._to_temp_dataset() + dataset = dataset.map(func, dim=dim, **kwargs) + return self.obj._from_temp_dataset(dataset) + + +class DatasetWeighted(Weighted["Dataset"]): + def _implementation(self, func, dim, **kwargs) -> Dataset: + self._check_dim(dim) + + return self.obj.map(func, dim=dim, **kwargs) + + +def _inject_docstring(cls, cls_name): + cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name) + + cls.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="sum", on_zero="0" + ) + + cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="mean", on_zero="NaN" + ) + + cls.sum_of_squares.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="sum_of_squares", on_zero="0" + ) + + cls.var.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="var", on_zero="NaN" + ) + + cls.std.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format( + cls=cls_name, fcn="std", on_zero="NaN" + ) + + cls.quantile.__doc__ = _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE.format(cls=cls_name) + + +_inject_docstring(DataArrayWeighted, "DataArray") +_inject_docstring(DatasetWeighted, "Dataset") diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.flake8 b/test/fixtures/whole_applications/xarray/xarray/datatree_/.flake8 new file mode 100644 index 0000000..f1e3f92 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.flake8 @@ -0,0 +1,15 @@ +[flake8] +ignore = + # whitespace before ':' - doesn't work well with black + E203 + # module level import not at top of file + E402 + # line too long - let black worry about that + E501 + # do not assign a lambda expression, use a def + E731 + # line break before binary operator + W503 +exclude= + .eggs + doc diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.git_archival.txt b/test/fixtures/whole_applications/xarray/xarray/datatree_/.git_archival.txt new file mode 100644 index 0000000..408f6f6 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.git_archival.txt @@ -0,0 +1,4 @@ +node: bef04067dd87f9f0c1a3ae7840299e0bbdd595a8 +node-date: 2024-06-13T07:05:11-06:00 +describe-name: %(describe:tags=true) +ref-names: tag: v2024.06.0 diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/dependabot.yml b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/dependabot.yml new file mode 100644 index 0000000..d1d1190 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: pip + directory: "/" + schedule: + interval: daily + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions every weekday + interval: "daily" diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/pull_request_template.md b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/pull_request_template.md new file mode 100644 index 0000000..8270498 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/pull_request_template.md @@ -0,0 +1,7 @@ + + +- [ ] Closes #xxxx +- [ ] Tests added +- [ ] Passes `pre-commit run --all-files` +- [ ] New functions/methods are listed in `api.rst` +- [ ] Changes are summarized in `docs/source/whats-new.rst` diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/workflows/main.yaml b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/workflows/main.yaml new file mode 100644 index 0000000..37034fc --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/workflows/main.yaml @@ -0,0 +1,97 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + schedule: + - cron: "0 0 * * *" + +jobs: + + test: + name: ${{ matrix.python-version }}-build + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@main + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment.yml + extra-specs: | + python=${{ matrix.python-version }} + + - name: Conda info + run: conda info + + - name: Install datatree + run: | + python -m pip install -e . --no-deps --force-reinstall + + - name: Conda list + run: conda list + + - name: Running Tests + run: | + python -m pytest --cov=./ --cov-report=xml --verbose + + - name: Upload code coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: false + + + test-upstream: + name: ${{ matrix.python-version }}-dev-build + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@main + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment.yml + extra-specs: | + python=${{ matrix.python-version }} + + - name: Conda info + run: conda info + + - name: Install dev reqs + run: | + python -m pip install --no-deps --upgrade \ + git+https://github.com/pydata/xarray \ + git+https://github.com/Unidata/netcdf4-python + + python -m pip install -e . --no-deps --force-reinstall + + - name: Conda list + run: conda list + + - name: Running Tests + run: | + python -m pytest --verbose diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/workflows/pypipublish.yaml b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/workflows/pypipublish.yaml new file mode 100644 index 0000000..7dc36d8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.github/workflows/pypipublish.yaml @@ -0,0 +1,84 @@ +name: Build distribution +on: + release: + types: + - published + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-artifacts: + runs-on: ubuntu-latest + if: github.repository == 'xarray-contrib/datatree' + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + name: Install Python + with: + python-version: 3.9 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install build + + - name: Build tarball and wheels + run: | + git clean -xdf + git restore -SW . + python -m build --sdist --wheel . + + + - uses: actions/upload-artifact@v4 + with: + name: releases + path: dist + + test-built-dist: + needs: build-artifacts + runs-on: ubuntu-latest + steps: + - uses: actions/setup-python@v5 + name: Install Python + with: + python-version: '3.10' + - uses: actions/download-artifact@v4 + with: + name: releases + path: dist + - name: List contents of built dist + run: | + ls -ltrh + ls -ltrh dist + + - name: Verify the built dist/wheel is valid + run: | + python -m pip install --upgrade pip + python -m pip install dist/xarray_datatree*.whl + python -c "import datatree; print(datatree.__version__)" + + upload-to-pypi: + needs: test-built-dist + if: github.event_name == 'release' + runs-on: ubuntu-latest + steps: + - uses: actions/download-artifact@v4 + with: + name: releases + path: dist + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@v1.8.11 + with: + user: ${{ secrets.PYPI_USERNAME }} + password: ${{ secrets.PYPI_PASSWORD }} + verbose: true diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.gitignore b/test/fixtures/whole_applications/xarray/xarray/datatree_/.gitignore new file mode 100644 index 0000000..88af994 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.gitignore @@ -0,0 +1,136 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/source/generated + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# version +_version.py + +# Ignore vscode specific settings +.vscode/ diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/.pre-commit-config.yaml b/test/fixtures/whole_applications/xarray/xarray/datatree_/.pre-commit-config.yaml new file mode 100644 index 0000000..ea73c38 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/.pre-commit-config.yaml @@ -0,0 +1,58 @@ +# https://pre-commit.com/ +ci: + autoupdate_schedule: monthly +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + # isort should run before black as black sometimes tweaks the isort output + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + # https://github.com/python/black#version-control-integration + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + - repo: https://github.com/keewis/blackdoc + rev: v0.3.9 + hooks: + - id: blackdoc + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + # - repo: https://github.com/Carreau/velin + # rev: 0.0.8 + # hooks: + # - id: velin + # args: ["--write", "--compact"] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + # Copied from setup.cfg + exclude: "properties|asv_bench|docs" + additional_dependencies: [ + # Type stubs + types-python-dateutil, + types-pkg_resources, + types-PyYAML, + types-pytz, + # Dependencies that are typed + numpy, + typing-extensions>=4.1.0, + ] + # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 + # - repo: https://github.com/asottile/pyupgrade + # rev: v1.22.1 + # hooks: + # - id: pyupgrade + # args: + # - "--py3-only" + # # remove on f-strings in Py3.7 + # - "--keep-percent-format" diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/LICENSE b/test/fixtures/whole_applications/xarray/xarray/datatree_/LICENSE new file mode 100644 index 0000000..d68e723 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2022 onwards, datatree developers + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/README.md b/test/fixtures/whole_applications/xarray/xarray/datatree_/README.md new file mode 100644 index 0000000..e41a13b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/README.md @@ -0,0 +1,95 @@ +# datatree + +| CI | [![GitHub Workflow Status][github-ci-badge]][github-ci-link] [![Code Coverage Status][codecov-badge]][codecov-link] [![pre-commit.ci status][pre-commit.ci-badge]][pre-commit.ci-link] | +| :---------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| **Docs** | [![Documentation Status][rtd-badge]][rtd-link] | +| **Package** | [![Conda][conda-badge]][conda-link] [![PyPI][pypi-badge]][pypi-link] | +| **License** | [![License][license-badge]][repo-link] | + + +**Datatree is a prototype implementation of a tree-like hierarchical data structure for xarray.** + +Datatree was born after the xarray team recognised a [need for a new hierarchical data structure](https://github.com/pydata/xarray/issues/4118), +that was more flexible than a single `xarray.Dataset` object. +The initial motivation was to represent netCDF files / Zarr stores with multiple nested groups in a single in-memory object, +but `datatree.DataTree` objects have many other uses. + +### DEPRECATION NOTICE + +Datatree is in the process of being merged upstream into xarray (as of [v0.0.14](https://github.com/xarray-contrib/datatree/releases/tag/v0.0.14), see xarray issue [#8572](https://github.com/pydata/xarray/issues/8572)). We are aiming to preserve the record of contributions to this repository during the migration process. However whilst we will hapily accept new PRs to this repository, this repo will be deprecated and any PRs since [v0.0.14](https://github.com/xarray-contrib/datatree/releases/tag/v0.0.14) might be later copied across to xarray without full git attribution. + +Hopefully for users the disruption will be minimal - and just mean that in some future version of xarray you only need to do `from xarray import DataTree` rather than `from datatree import DataTree`. Once the migration is complete this repository will be archived. + +### Installation +You can install datatree via pip: +```shell +pip install xarray-datatree +``` + +or via conda-forge +```shell +conda install -c conda-forge xarray-datatree +``` + +### Why Datatree? + +You might want to use datatree for: + +- Organising many related datasets, e.g. results of the same experiment with different parameters, or simulations of the same system using different models, +- Analysing similar data at multiple resolutions simultaneously, such as when doing a convergence study, +- Comparing heterogenous but related data, such as experimental and theoretical data, +- I/O with nested data formats such as netCDF / Zarr groups. + +[**Talk slides on Datatree from AMS-python 2023**](https://speakerdeck.com/tomnicholas/xarray-datatree-hierarchical-data-structures-for-multi-model-science) + +### Features + +The approach used here is based on benbovy's [`DatasetNode` example](https://gist.github.com/benbovy/92e7c76220af1aaa4b3a0b65374e233a) - the basic idea is that each tree node wraps a up to a single `xarray.Dataset`. The differences are that this effort: +- Uses a node structure inspired by [anytree](https://github.com/xarray-contrib/datatree/issues/7) for the tree, +- Implements path-like getting and setting, +- Has functions for mapping user-supplied functions over every node in the tree, +- Automatically dispatches *some* of `xarray.Dataset`'s API over every node in the tree (such as `.isel`), +- Has a bunch of tests, +- Has a printable representation that currently looks like this: +drawing + +### Get Started + +You can create a `DataTree` object in 3 ways: +1) Load from a netCDF file (or Zarr store) that has groups via `open_datatree()`. +2) Using the init method of `DataTree`, which creates an individual node. + You can then specify the nodes' relationships to one other, either by setting `.parent` and `.children` attributes, + or through `__get/setitem__` access, e.g. `dt['path/to/node'] = DataTree()`. +3) Create a tree from a dictionary of paths to datasets using `DataTree.from_dict()`. + +### Development Roadmap + +Datatree currently lives in a separate repository to the main xarray package. +This allows the datatree developers to make changes to it, experiment, and improve it faster. + +Eventually we plan to fully integrate datatree upstream into xarray's main codebase, at which point the [github.com/xarray-contrib/datatree](https://github.com/xarray-contrib/datatree>) repository will be archived. +This should not cause much disruption to code that depends on datatree - you will likely only have to change the import line (i.e. from ``from datatree import DataTree`` to ``from xarray import DataTree``). + +However, until this full integration occurs, datatree's API should not be considered to have the same [level of stability as xarray's](https://docs.xarray.dev/en/stable/contributing.html#backwards-compatibility). + +### User Feedback + +We really really really want to hear your opinions on datatree! +At this point in development, user feedback is critical to help us create something that will suit everyone's needs. +Please raise any thoughts, issues, suggestions or bugs, no matter how small or large, on the [github issue tracker](https://github.com/xarray-contrib/datatree/issues). + + +[github-ci-badge]: https://img.shields.io/github/actions/workflow/status/xarray-contrib/datatree/main.yaml?branch=main&label=CI&logo=github +[github-ci-link]: https://github.com/xarray-contrib/datatree/actions?query=workflow%3ACI +[codecov-badge]: https://img.shields.io/codecov/c/github/xarray-contrib/datatree.svg?logo=codecov +[codecov-link]: https://codecov.io/gh/xarray-contrib/datatree +[rtd-badge]: https://img.shields.io/readthedocs/xarray-datatree/latest.svg +[rtd-link]: https://xarray-datatree.readthedocs.io/en/latest/?badge=latest +[pypi-badge]: https://img.shields.io/pypi/v/xarray-datatree?logo=pypi +[pypi-link]: https://pypi.org/project/xarray-datatree +[conda-badge]: https://img.shields.io/conda/vn/conda-forge/xarray-datatree?logo=anaconda +[conda-link]: https://anaconda.org/conda-forge/xarray-datatree +[license-badge]: https://img.shields.io/github/license/xarray-contrib/datatree +[repo-link]: https://github.com/xarray-contrib/datatree +[pre-commit.ci-badge]: https://results.pre-commit.ci/badge/github/xarray-contrib/datatree/main.svg +[pre-commit.ci-link]: https://results.pre-commit.ci/latest/github/xarray-contrib/datatree/main diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/ci/doc.yml b/test/fixtures/whole_applications/xarray/xarray/datatree_/ci/doc.yml new file mode 100644 index 0000000..f3b95f7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/ci/doc.yml @@ -0,0 +1,25 @@ +name: datatree-doc +channels: + - conda-forge +dependencies: + - pip + - python>=3.9 + - netcdf4 + - scipy + - sphinx>=4.2.0 + - sphinx-copybutton + - sphinx-panels + - sphinx-autosummary-accessors + - sphinx-book-theme >= 0.0.38 + - nbsphinx + - sphinxcontrib-srclinks + - pickleshare + - pydata-sphinx-theme>=0.4.3 + - ipython + - h5netcdf + - zarr + - xarray + - pip: + - -e .. + - sphinxext-rediraffe + - sphinxext-opengraph diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/ci/environment.yml b/test/fixtures/whole_applications/xarray/xarray/datatree_/ci/environment.yml new file mode 100644 index 0000000..fc0c6d9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/ci/environment.yml @@ -0,0 +1,16 @@ +name: datatree-test +channels: + - conda-forge + - nodefaults +dependencies: + - python>=3.9 + - netcdf4 + - pytest + - flake8 + - black + - codecov + - pytest-cov + - h5netcdf + - zarr + - pip: + - xarray>=2022.05.0.dev0 diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/codecov.yml b/test/fixtures/whole_applications/xarray/xarray/datatree_/codecov.yml new file mode 100644 index 0000000..44fd739 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/codecov.yml @@ -0,0 +1,21 @@ +codecov: + require_ci_to_pass: false + max_report_age: off + +comment: false + +ignore: + - 'datatree/tests/*' + - 'setup.py' + - 'conftest.py' + +coverage: + precision: 2 + round: down + status: + project: + default: + target: 95 + informational: true + patch: off + changes: false diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/conftest.py b/test/fixtures/whole_applications/xarray/xarray/datatree_/conftest.py new file mode 100644 index 0000000..7ef1917 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/conftest.py @@ -0,0 +1,3 @@ +import pytest + +pytest.register_assert_rewrite("datatree.testing") diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/datatree/__init__.py b/test/fixtures/whole_applications/xarray/xarray/datatree_/datatree/__init__.py new file mode 100644 index 0000000..51c5f1b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/datatree/__init__.py @@ -0,0 +1,7 @@ +# import public API +from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError + +__all__ = ( + "InvalidTreeError", + "NotFoundInTreeError", +) diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/datatree/py.typed b/test/fixtures/whole_applications/xarray/xarray/datatree_/datatree/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/Makefile b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/Makefile new file mode 100644 index 0000000..6e9b405 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/Makefile @@ -0,0 +1,183 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source + +.PHONY: help clean html rtdhtml dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " rtdhtml Build html using same settings used on ReadtheDocs" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +rtdhtml: + $(SPHINXBUILD) -T -j auto -E -W --keep-going -b html -d $(BUILDDIR)/doctrees -D language=en . $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/complexity.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/complexity.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/complexity" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/complexity" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/README.md b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/README.md new file mode 100644 index 0000000..ca2bf72 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/README.md @@ -0,0 +1,14 @@ +# README - docs + +## Build the documentation locally + +```bash +cd docs # From project's root +make clean +rm -rf source/generated # remove autodoc artefacts, that are not removed by `make clean` +make html +``` + +## Access the documentation locally + +Open `docs/_build/html/index.html` in a web browser diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/make.bat b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/make.bat new file mode 100644 index 0000000..2df9a8c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/make.bat @@ -0,0 +1,242 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\complexity.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\complexity.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/api.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/api.rst new file mode 100644 index 0000000..d325d24 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/api.rst @@ -0,0 +1,362 @@ +.. currentmodule:: datatree + +############# +API reference +############# + +DataTree +======== + +Creating a DataTree +------------------- + +Methods of creating a datatree. + +.. autosummary:: + :toctree: generated/ + + DataTree + DataTree.from_dict + +Tree Attributes +--------------- + +Attributes relating to the recursive tree-like structure of a ``DataTree``. + +.. autosummary:: + :toctree: generated/ + + DataTree.parent + DataTree.children + DataTree.name + DataTree.path + DataTree.root + DataTree.is_root + DataTree.is_leaf + DataTree.leaves + DataTree.level + DataTree.depth + DataTree.width + DataTree.subtree + DataTree.descendants + DataTree.siblings + DataTree.lineage + DataTree.parents + DataTree.ancestors + DataTree.groups + +Data Contents +------------- + +Interface to the data objects (optionally) stored inside a single ``DataTree`` node. +This interface echoes that of ``xarray.Dataset``. + +.. autosummary:: + :toctree: generated/ + + DataTree.dims + DataTree.sizes + DataTree.data_vars + DataTree.coords + DataTree.attrs + DataTree.encoding + DataTree.indexes + DataTree.nbytes + DataTree.ds + DataTree.to_dataset + DataTree.has_data + DataTree.has_attrs + DataTree.is_empty + DataTree.is_hollow + +Dictionary Interface +-------------------- + +``DataTree`` objects also have a dict-like interface mapping keys to either ``xarray.DataArray``s or to child ``DataTree`` nodes. + +.. autosummary:: + :toctree: generated/ + + DataTree.__getitem__ + DataTree.__setitem__ + DataTree.__delitem__ + DataTree.update + DataTree.get + DataTree.items + DataTree.keys + DataTree.values + +Tree Manipulation +----------------- + +For manipulating, traversing, navigating, or mapping over the tree structure. + +.. autosummary:: + :toctree: generated/ + + DataTree.orphan + DataTree.same_tree + DataTree.relative_to + DataTree.iter_lineage + DataTree.find_common_ancestor + DataTree.map_over_subtree + map_over_subtree + DataTree.pipe + DataTree.match + DataTree.filter + +Pathlib-like Interface +---------------------- + +``DataTree`` objects deliberately echo some of the API of `pathlib.PurePath`. + +.. autosummary:: + :toctree: generated/ + + DataTree.name + DataTree.parent + DataTree.parents + DataTree.relative_to + +Missing: + +.. + + ``DataTree.glob`` + ``DataTree.joinpath`` + ``DataTree.with_name`` + ``DataTree.walk`` + ``DataTree.rename`` + ``DataTree.replace`` + +DataTree Contents +----------------- + +Manipulate the contents of all nodes in a tree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.copy + DataTree.assign_coords + DataTree.merge + DataTree.rename + DataTree.rename_vars + DataTree.rename_dims + DataTree.swap_dims + DataTree.expand_dims + DataTree.drop_vars + DataTree.drop_dims + DataTree.set_coords + DataTree.reset_coords + +DataTree Node Contents +---------------------- + +Manipulate the contents of a single DataTree node. + +.. autosummary:: + :toctree: generated/ + + DataTree.assign + DataTree.drop_nodes + +Comparisons +=========== + +Compare one ``DataTree`` object to another. + +.. autosummary:: + :toctree: generated/ + + DataTree.isomorphic + DataTree.equals + DataTree.identical + +Indexing +======== + +Index into all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.isel + DataTree.sel + DataTree.drop_sel + DataTree.drop_isel + DataTree.head + DataTree.tail + DataTree.thin + DataTree.squeeze + DataTree.interp + DataTree.interp_like + DataTree.reindex + DataTree.reindex_like + DataTree.set_index + DataTree.reset_index + DataTree.reorder_levels + DataTree.query + +.. + + Missing: + ``DataTree.loc`` + + +Missing Value Handling +====================== + +.. autosummary:: + :toctree: generated/ + + DataTree.isnull + DataTree.notnull + DataTree.combine_first + DataTree.dropna + DataTree.fillna + DataTree.ffill + DataTree.bfill + DataTree.interpolate_na + DataTree.where + DataTree.isin + +Computation +=========== + +Apply a computation to the data in all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.map + DataTree.reduce + DataTree.diff + DataTree.quantile + DataTree.differentiate + DataTree.integrate + DataTree.map_blocks + DataTree.polyfit + DataTree.curvefit + +Aggregation +=========== + +Aggregate data in all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.all + DataTree.any + DataTree.argmax + DataTree.argmin + DataTree.idxmax + DataTree.idxmin + DataTree.max + DataTree.min + DataTree.mean + DataTree.median + DataTree.prod + DataTree.sum + DataTree.std + DataTree.var + DataTree.cumsum + DataTree.cumprod + +ndarray methods +=============== + +Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data in all nodes in the subtree. + +.. autosummary:: + :toctree: generated/ + + DataTree.argsort + DataTree.astype + DataTree.clip + DataTree.conj + DataTree.conjugate + DataTree.round + DataTree.rank + +Reshaping and reorganising +========================== + +Reshape or reorganise the data in all nodes in the subtree. + +.. autosummary:: + :toctree: generated/ + + DataTree.transpose + DataTree.stack + DataTree.unstack + DataTree.shift + DataTree.roll + DataTree.pad + DataTree.sortby + DataTree.broadcast_like + +Plotting +======== + +I/O +=== + +Open a datatree from an on-disk store or serialize the tree. + +.. autosummary:: + :toctree: generated/ + + open_datatree + DataTree.to_dict + DataTree.to_netcdf + DataTree.to_zarr + +.. + + Missing: + ``open_mfdatatree`` + +Tutorial +======== + +Testing +======= + +Test that two DataTree objects are similar. + +.. autosummary:: + :toctree: generated/ + + testing.assert_isomorphic + testing.assert_equal + testing.assert_identical + +Exceptions +========== + +Exceptions raised when manipulating trees. + +.. autosummary:: + :toctree: generated/ + + TreeIsomorphismError + InvalidTreeError + NotFoundInTreeError + +Advanced API +============ + +Relatively advanced API for users or developers looking to understand the internals, or extend functionality. + +.. autosummary:: + :toctree: generated/ + + DataTree.variables + register_datatree_accessor + +.. + + Missing: + ``DataTree.set_close`` diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/conf.py b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/conf.py new file mode 100644 index 0000000..430dbb5 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/conf.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# flake8: noqa +# Ignoring F401: imported but unused + +# complexity documentation build configuration file, created by +# sphinx-quickstart on Tue Jul 9 22:26:36 2013. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import inspect +import os +import sys + +import sphinx_autosummary_accessors # type: ignore + +import datatree # type: ignore + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# sys.path.insert(0, os.path.abspath('.')) + +cwd = os.getcwd() +parent = os.path.dirname(cwd) +sys.path.insert(0, parent) + + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx.ext.linkcode", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "sphinx.ext.napoleon", + "sphinx_copybutton", + "sphinxext.opengraph", + "sphinx_autosummary_accessors", + "IPython.sphinxext.ipython_console_highlighting", + "IPython.sphinxext.ipython_directive", + "nbsphinx", + "sphinxcontrib.srclinks", +] + +extlinks = { + "issue": ("https://github.com/xarray-contrib/datatree/issues/%s", "GH#%s"), + "pull": ("https://github.com/xarray-contrib/datatree/pull/%s", "GH#%s"), +} +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] + +# Generate the API documentation when building +autosummary_generate = True + + +# Napoleon configurations + +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_use_param = False +napoleon_use_rtype = False +napoleon_preprocess_types = True +napoleon_type_aliases = { + # general terms + "sequence": ":term:`sequence`", + "iterable": ":term:`iterable`", + "callable": ":py:func:`callable`", + "dict_like": ":term:`dict-like `", + "dict-like": ":term:`dict-like `", + "path-like": ":term:`path-like `", + "mapping": ":term:`mapping`", + "file-like": ":term:`file-like `", + # special terms + # "same type as caller": "*same type as caller*", # does not work, yet + # "same type as values": "*same type as values*", # does not work, yet + # stdlib type aliases + "MutableMapping": "~collections.abc.MutableMapping", + "sys.stdout": ":obj:`sys.stdout`", + "timedelta": "~datetime.timedelta", + "string": ":class:`string `", + # numpy terms + "array_like": ":term:`array_like`", + "array-like": ":term:`array-like `", + "scalar": ":term:`scalar`", + "array": ":term:`array`", + "hashable": ":term:`hashable `", + # matplotlib terms + "color-like": ":py:func:`color-like `", + "matplotlib colormap name": ":doc:`matplotlib colormap name `", + "matplotlib axes object": ":py:class:`matplotlib axes object `", + "colormap": ":py:class:`colormap `", + # objects without namespace: xarray + "DataArray": "~xarray.DataArray", + "Dataset": "~xarray.Dataset", + "Variable": "~xarray.Variable", + "DatasetGroupBy": "~xarray.core.groupby.DatasetGroupBy", + "DataArrayGroupBy": "~xarray.core.groupby.DataArrayGroupBy", + # objects without namespace: numpy + "ndarray": "~numpy.ndarray", + "MaskedArray": "~numpy.ma.MaskedArray", + "dtype": "~numpy.dtype", + "ComplexWarning": "~numpy.ComplexWarning", + # objects without namespace: pandas + "Index": "~pandas.Index", + "MultiIndex": "~pandas.MultiIndex", + "CategoricalIndex": "~pandas.CategoricalIndex", + "TimedeltaIndex": "~pandas.TimedeltaIndex", + "DatetimeIndex": "~pandas.DatetimeIndex", + "Series": "~pandas.Series", + "DataFrame": "~pandas.DataFrame", + "Categorical": "~pandas.Categorical", + "Path": "~~pathlib.Path", + # objects with abbreviated namespace (from pandas) + "pd.Index": "~pandas.Index", + "pd.NaT": "~pandas.NaT", +} + +# The suffix of source filenames. +source_suffix = ".rst" + +# The encoding of source files. +# source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "Datatree" +copyright = "2021 onwards, Tom Nicholas and its Contributors" +author = "Tom Nicholas" + +html_show_sourcelink = True +srclink_project = "https://github.com/xarray-contrib/datatree" +srclink_branch = "main" +srclink_src_path = "docs/source" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = datatree.__version__ +# The full version, including alpha/beta/rc tags. +release = datatree.__version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ["_build"] + +# The reST default role (used for this markup: `text`) to use for all documents. +# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + + +# -- Intersphinx links --------------------------------------------------------- + +intersphinx_mapping = { + "python": ("https://docs.python.org/3.8/", None), + "numpy": ("https://numpy.org/doc/stable", None), + "xarray": ("https://xarray.pydata.org/en/stable/", None), +} + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "sphinx_book_theme" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + "repository_url": "https://github.com/xarray-contrib/datatree", + "repository_branch": "main", + "path_to_docs": "docs/source", + "use_repository_button": True, + "use_issues_button": True, + "use_edit_page_button": True, +} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +# html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = "datatree_doc" + + +# -- Options for LaTeX output -------------------------------------------------- + +latex_elements: dict = { + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ("index", "datatree.tex", "Datatree Documentation", author, "manual") +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. +# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [("index", "datatree", "Datatree Documentation", [author], 1)] + +# If true, show URL addresses after external links. +# man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------------ + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + "index", + "datatree", + "Datatree Documentation", + author, + "datatree", + "Tree-like hierarchical data structure for xarray.", + "Miscellaneous", + ) +] + +# Documents to append as an appendix to all manuals. +# texinfo_appendices = [] + +# If false, no module index is generated. +# texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +# texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. +# texinfo_no_detailmenu = False + + +# based on numpy doc/source/conf.py +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + obj = getattr(obj, part) + except AttributeError: + return None + + try: + fn = inspect.getsourcefile(inspect.unwrap(obj)) + except TypeError: + fn = None + if not fn: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except OSError: + lineno = None + + if lineno: + linespec = f"#L{lineno}-L{lineno + len(source) - 1}" + else: + linespec = "" + + fn = os.path.relpath(fn, start=os.path.dirname(datatree.__file__)) + + if "+" in datatree.__version__: + return f"https://github.com/xarray-contrib/datatree/blob/main/datatree/{fn}{linespec}" + else: + return ( + f"https://github.com/xarray-contrib/datatree/blob/" + f"v{datatree.__version__}/datatree/{fn}{linespec}" + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/contributing.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/contributing.rst new file mode 100644 index 0000000..b070c07 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/contributing.rst @@ -0,0 +1,136 @@ +======================== +Contributing to Datatree +======================== + +Contributions are highly welcomed and appreciated. Every little help counts, +so do not hesitate! + +.. contents:: Contribution links + :depth: 2 + +.. _submitfeedback: + +Feature requests and feedback +----------------------------- + +Do you like Datatree? Share some love on Twitter or in your blog posts! + +We'd also like to hear about your propositions and suggestions. Feel free to +`submit them as issues `_ and: + +* Explain in detail how they should work. +* Keep the scope as narrow as possible. This will make it easier to implement. + +.. _reportbugs: + +Report bugs +----------- + +Report bugs for Datatree in the `issue tracker `_. + +If you are reporting a bug, please include: + +* Your operating system name and version. +* Any details about your local setup that might be helpful in troubleshooting, + specifically the Python interpreter version, installed libraries, and Datatree + version. +* Detailed steps to reproduce the bug. + +If you can write a demonstration test that currently fails but should pass +(xfail), that is a very useful commit to make as well, even if you cannot +fix the bug itself. + +.. _fixbugs: + +Fix bugs +-------- + +Look through the `GitHub issues for bugs `_. + +Talk to developers to find out how you can fix specific bugs. + +Write documentation +------------------- + +Datatree could always use more documentation. What exactly is needed? + +* More complementary documentation. Have you perhaps found something unclear? +* Docstrings. There can never be too many of them. +* Blog posts, articles and such -- they're all very appreciated. + +You can also edit documentation files directly in the GitHub web interface, +without using a local copy. This can be convenient for small fixes. + +To build the documentation locally, you first need to install the following +tools: + +- `Sphinx `__ +- `sphinx_rtd_theme `__ +- `sphinx-autosummary-accessors `__ + +You can then build the documentation with the following commands:: + + $ cd docs + $ make html + +The built documentation should be available in the ``docs/_build/`` folder. + +.. _`pull requests`: +.. _pull-requests: + +Preparing Pull Requests +----------------------- + +#. Fork the + `Datatree GitHub repository `__. It's + fine to use ``Datatree`` as your fork repository name because it will live + under your user. + +#. Clone your fork locally using `git `_ and create a branch:: + + $ git clone git@github.com:{YOUR_GITHUB_USERNAME}/Datatree.git + $ cd Datatree + + # now, to fix a bug or add feature create your own branch off "master": + + $ git checkout -b your-bugfix-feature-branch-name master + +#. Install `pre-commit `_ and its hook on the Datatree repo:: + + $ pip install --user pre-commit + $ pre-commit install + + Afterwards ``pre-commit`` will run whenever you commit. + + https://pre-commit.com/ is a framework for managing and maintaining multi-language pre-commit hooks + to ensure code-style and code formatting is consistent. + +#. Install dependencies into a new conda environment:: + + $ conda env update -f ci/environment.yml + +#. Run all the tests + + Now running tests is as simple as issuing this command:: + + $ conda activate datatree-dev + $ pytest --junitxml=test-reports/junit.xml --cov=./ --verbose + + This command will run tests via the "pytest" tool. + +#. You can now edit your local working copy and run the tests again as necessary. Please follow PEP-8 for naming. + + When committing, ``pre-commit`` will re-format the files if necessary. + +#. Commit and push once your tests pass and you are happy with your change(s):: + + $ git commit -a -m "" + $ git push -u + +#. Finally, submit a pull request through the GitHub website using this data:: + + head-fork: YOUR_GITHUB_USERNAME/Datatree + compare: your-branch-name + + base-fork: TomNicholas/datatree + base: master diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/data-structures.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/data-structures.rst new file mode 100644 index 0000000..02e4a31 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/data-structures.rst @@ -0,0 +1,197 @@ +.. currentmodule:: datatree + +.. _data structures: + +Data Structures +=============== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + import datatree + + np.random.seed(123456) + np.set_printoptions(threshold=10) + + %xmode minimal + +.. note:: + + This page builds on the information given in xarray's main page on + `data structures `_, so it is suggested that you + are familiar with those first. + +DataTree +-------- + +:py:class:`DataTree` is xarray's highest-level data structure, able to organise heterogeneous data which +could not be stored inside a single :py:class:`Dataset` object. This includes representing the recursive structure of multiple +`groups`_ within a netCDF file or `Zarr Store`_. + +.. _groups: https://www.unidata.ucar.edu/software/netcdf/workshops/2011/groups-types/GroupsIntro.html +.. _Zarr Store: https://zarr.readthedocs.io/en/stable/tutorial.html#groups + +Each ``DataTree`` object (or "node") contains the same data that a single ``xarray.Dataset`` would (i.e. ``DataArray`` objects +stored under hashable keys), and so has the same key properties: + +- ``dims``: a dictionary mapping of dimension names to lengths, for the variables in this node, +- ``data_vars``: a dict-like container of DataArrays corresponding to variables in this node, +- ``coords``: another dict-like container of DataArrays, corresponding to coordinate variables in this node, +- ``attrs``: dict to hold arbitary metadata relevant to data in this node. + +A single ``DataTree`` object acts much like a single ``Dataset`` object, and has a similar set of dict-like methods +defined upon it. However, ``DataTree``'s can also contain other ``DataTree`` objects, so they can be thought of as nested dict-like +containers of both ``xarray.DataArray``'s and ``DataTree``'s. + +A single datatree object is known as a "node", and its position relative to other nodes is defined by two more key +properties: + +- ``children``: An ordered dictionary mapping from names to other ``DataTree`` objects, known as its' "child nodes". +- ``parent``: The single ``DataTree`` object whose children this datatree is a member of, known as its' "parent node". + +Each child automatically knows about its parent node, and a node without a parent is known as a "root" node +(represented by the ``parent`` attribute pointing to ``None``). +Nodes can have multiple children, but as each child node has at most one parent, there can only ever be one root node in a given tree. + +The overall structure is technically a `connected acyclic undirected rooted graph`, otherwise known as a +`"Tree" `_. + +.. note:: + + Technically a ``DataTree`` with more than one child node forms an `"Ordered Tree" `_, + because the children are stored in an Ordered Dictionary. However, this distinction only really matters for a few + edge cases involving operations on multiple trees simultaneously, and can safely be ignored by most users. + + +``DataTree`` objects can also optionally have a ``name`` as well as ``attrs``, just like a ``DataArray``. +Again these are not normally used unless explicitly accessed by the user. + + +.. _creating a datatree: + +Creating a DataTree +~~~~~~~~~~~~~~~~~~~ + +One way to create a ``DataTree`` from scratch is to create each node individually, +specifying the nodes' relationship to one another as you create each one. + +The ``DataTree`` constructor takes: + +- ``data``: The data that will be stored in this node, represented by a single ``xarray.Dataset``, or a named ``xarray.DataArray``. +- ``parent``: The parent node (if there is one), given as a ``DataTree`` object. +- ``children``: The various child nodes (if there are any), given as a mapping from string keys to ``DataTree`` objects. +- ``name``: A string to use as the name of this node. + +Let's make a single datatree node with some example data in it: + +.. ipython:: python + + from datatree import DataTree + + ds1 = xr.Dataset({"foo": "orange"}) + dt = DataTree(name="root", data=ds1) # create root node + + dt + +At this point our node is also the root node, as every tree has a root node. + +We can add a second node to this tree either by referring to the first node in the constructor of the second: + +.. ipython:: python + + ds2 = xr.Dataset({"bar": 0}, coords={"y": ("y", [0, 1, 2])}) + # add a child by referring to the parent node + node2 = DataTree(name="a", parent=dt, data=ds2) + +or by dynamically updating the attributes of one node to refer to another: + +.. ipython:: python + + # add a second child by first creating a new node ... + ds3 = xr.Dataset({"zed": np.NaN}) + node3 = DataTree(name="b", data=ds3) + # ... then updating its .parent property + node3.parent = dt + +Our tree now has three nodes within it: + +.. ipython:: python + + dt + +It is at tree construction time that consistency checks are enforced. For instance, if we try to create a `cycle` the constructor will raise an error: + +.. ipython:: python + :okexcept: + + dt.parent = node3 + +Alternatively you can also create a ``DataTree`` object from + +- An ``xarray.Dataset`` using ``Dataset.to_node()`` (not yet implemented), +- A dictionary mapping directory-like paths to either ``DataTree`` nodes or data, using :py:meth:`DataTree.from_dict()`, +- A netCDF or Zarr file on disk with :py:func:`open_datatree()`. See :ref:`reading and writing files `. + + +DataTree Contents +~~~~~~~~~~~~~~~~~ + +Like ``xarray.Dataset``, ``DataTree`` implements the python mapping interface, but with values given by either ``xarray.DataArray`` objects or other ``DataTree`` objects. + +.. ipython:: python + + dt["a"] + dt["foo"] + +Iterating over keys will iterate over both the names of variables and child nodes. + +We can also access all the data in a single node through a dataset-like view + +.. ipython:: python + + dt["a"].ds + +This demonstrates the fact that the data in any one node is equivalent to the contents of a single ``xarray.Dataset`` object. +The ``DataTree.ds`` property returns an immutable view, but we can instead extract the node's data contents as a new (and mutable) +``xarray.Dataset`` object via :py:meth:`DataTree.to_dataset()`: + +.. ipython:: python + + dt["a"].to_dataset() + +Like with ``Dataset``, you can access the data and coordinate variables of a node separately via the ``data_vars`` and ``coords`` attributes: + +.. ipython:: python + + dt["a"].data_vars + dt["a"].coords + + +Dictionary-like methods +~~~~~~~~~~~~~~~~~~~~~~~ + +We can update a datatree in-place using Python's standard dictionary syntax, similar to how we can for Dataset objects. +For example, to create this example datatree from scratch, we could have written: + +# TODO update this example using ``.coords`` and ``.data_vars`` as setters, + +.. ipython:: python + + dt = DataTree(name="root") + dt["foo"] = "orange" + dt["a"] = DataTree(data=xr.Dataset({"bar": 0}, coords={"y": ("y", [0, 1, 2])})) + dt["a/b/zed"] = np.NaN + dt + +To change the variables in a node of a ``DataTree``, you can use all the standard dictionary +methods, including ``values``, ``items``, ``__delitem__``, ``get`` and +:py:meth:`DataTree.update`. +Note that assigning a ``DataArray`` object to a ``DataTree`` variable using ``__setitem__`` or ``update`` will +:ref:`automatically align ` the array(s) to the original node's indexes. + +If you copy a ``DataTree`` using the :py:func:`copy` function or the :py:meth:`DataTree.copy` method it will copy the subtree, +meaning that node and children below it, but no parents above it. +Like for ``Dataset``, this copy is shallow by default, but you can copy all the underlying data arrays by calling ``dt.copy(deep=True)``. diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/hierarchical-data.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/hierarchical-data.rst new file mode 100644 index 0000000..d4f5884 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/hierarchical-data.rst @@ -0,0 +1,639 @@ +.. currentmodule:: datatree + +.. _hierarchical-data: + +Working With Hierarchical Data +============================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + from datatree import DataTree + + np.random.seed(123456) + np.set_printoptions(threshold=10) + + %xmode minimal + +Why Hierarchical Data? +---------------------- + +Many real-world datasets are composed of multiple differing components, +and it can often be be useful to think of these in terms of a hierarchy of related groups of data. +Examples of data which one might want organise in a grouped or hierarchical manner include: + +- Simulation data at multiple resolutions, +- Observational data about the same system but from multiple different types of sensors, +- Mixed experimental and theoretical data, +- A systematic study recording the same experiment but with different parameters, +- Heterogenous data, such as demographic and metereological data, + +or even any combination of the above. + +Often datasets like this cannot easily fit into a single :py:class:`xarray.Dataset` object, +or are more usefully thought of as groups of related ``xarray.Dataset`` objects. +For this purpose we provide the :py:class:`DataTree` class. + +This page explains in detail how to understand and use the different features of the :py:class:`DataTree` class for your own hierarchical data needs. + +.. _node relationships: + +Node Relationships +------------------ + +.. _creating a family tree: + +Creating a Family Tree +~~~~~~~~~~~~~~~~~~~~~~ + +The three main ways of creating a ``DataTree`` object are described briefly in :ref:`creating a datatree`. +Here we go into more detail about how to create a tree node-by-node, using a famous family tree from the Simpsons cartoon as an example. + +Let's start by defining nodes representing the two siblings, Bart and Lisa Simpson: + +.. ipython:: python + + bart = DataTree(name="Bart") + lisa = DataTree(name="Lisa") + +Each of these node objects knows their own :py:class:`~DataTree.name`, but they currently have no relationship to one another. +We can connect them by creating another node representing a common parent, Homer Simpson: + +.. ipython:: python + + homer = DataTree(name="Homer", children={"Bart": bart, "Lisa": lisa}) + +Here we set the children of Homer in the node's constructor. +We now have a small family tree + +.. ipython:: python + + homer + +where we can see how these individual Simpson family members are related to one another. +The nodes representing Bart and Lisa are now connected - we can confirm their sibling rivalry by examining the :py:class:`~DataTree.siblings` property: + +.. ipython:: python + + list(bart.siblings) + +But oops, we forgot Homer's third daughter, Maggie! Let's add her by updating Homer's :py:class:`~DataTree.children` property to include her: + +.. ipython:: python + + maggie = DataTree(name="Maggie") + homer.children = {"Bart": bart, "Lisa": lisa, "Maggie": maggie} + homer + +Let's check that Maggie knows who her Dad is: + +.. ipython:: python + + maggie.parent.name + +That's good - updating the properties of our nodes does not break the internal consistency of our tree, as changes of parentage are automatically reflected on both nodes. + + These children obviously have another parent, Marge Simpson, but ``DataTree`` nodes can only have a maximum of one parent. + Genealogical `family trees are not even technically trees `_ in the mathematical sense - + the fact that distant relatives can mate makes it a directed acyclic graph. + Trees of ``DataTree`` objects cannot represent this. + +Homer is currently listed as having no parent (the so-called "root node" of this tree), but we can update his :py:class:`~DataTree.parent` property: + +.. ipython:: python + + abe = DataTree(name="Abe") + homer.parent = abe + +Abe is now the "root" of this tree, which we can see by examining the :py:class:`~DataTree.root` property of any node in the tree + +.. ipython:: python + + maggie.root.name + +We can see the whole tree by printing Abe's node or just part of the tree by printing Homer's node: + +.. ipython:: python + + abe + homer + +We can see that Homer is aware of his parentage, and we say that Homer and his children form a "subtree" of the larger Simpson family tree. + +In episode 28, Abe Simpson reveals that he had another son, Herbert "Herb" Simpson. +We can add Herbert to the family tree without displacing Homer by :py:meth:`~DataTree.assign`-ing another child to Abe: + +.. ipython:: python + + herbert = DataTree(name="Herb") + abe.assign({"Herbert": herbert}) + +.. note:: + This example shows a minor subtlety - the returned tree has Homer's brother listed as ``"Herbert"``, + but the original node was named "Herbert". Not only are names overriden when stored as keys like this, + but the new node is a copy, so that the original node that was reference is unchanged (i.e. ``herbert.name == "Herb"`` still). + In other words, nodes are copied into trees, not inserted into them. + This is intentional, and mirrors the behaviour when storing named ``xarray.DataArray`` objects inside datasets. + +Certain manipulations of our tree are forbidden, if they would create an inconsistent result. +In episode 51 of the show Futurama, Philip J. Fry travels back in time and accidentally becomes his own Grandfather. +If we try similar time-travelling hijinks with Homer, we get a :py:class:`InvalidTreeError` raised: + +.. ipython:: python + :okexcept: + + abe.parent = homer + +.. _evolutionary tree: + +Ancestry in an Evolutionary Tree +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let's use a different example of a tree to discuss more complex relationships between nodes - the phylogenetic tree, or tree of life. + +.. ipython:: python + + vertebrates = DataTree.from_dict( + name="Vertebrae", + d={ + "/Sharks": None, + "/Bony Skeleton/Ray-finned Fish": None, + "/Bony Skeleton/Four Limbs/Amphibians": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Hair/Primates": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Hair/Rodents & Rabbits": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Two Fenestrae/Dinosaurs": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Two Fenestrae/Birds": None, + }, + ) + + primates = vertebrates["/Bony Skeleton/Four Limbs/Amniotic Egg/Hair/Primates"] + dinosaurs = vertebrates[ + "/Bony Skeleton/Four Limbs/Amniotic Egg/Two Fenestrae/Dinosaurs" + ] + +We have used the :py:meth:`~DataTree.from_dict` constructor method as an alternate way to quickly create a whole tree, +and :ref:`filesystem paths` (to be explained shortly) to select two nodes of interest. + +.. ipython:: python + + vertebrates + +This tree shows various families of species, grouped by their common features (making it technically a `"Cladogram" `_, +rather than an evolutionary tree). + +Here both the species and the features used to group them are represented by ``DataTree`` node objects - there is no distinction in types of node. +We can however get a list of only the nodes we used to represent species by using the fact that all those nodes have no children - they are "leaf nodes". +We can check if a node is a leaf with :py:meth:`~DataTree.is_leaf`, and get a list of all leaves with the :py:class:`~DataTree.leaves` property: + +.. ipython:: python + + primates.is_leaf + [node.name for node in vertebrates.leaves] + +Pretending that this is a true evolutionary tree for a moment, we can find the features of the evolutionary ancestors (so-called "ancestor" nodes), +the distinguishing feature of the common ancestor of all vertebrate life (the root node), +and even the distinguishing feature of the common ancestor of any two species (the common ancestor of two nodes): + +.. ipython:: python + + [node.name for node in primates.ancestors] + primates.root.name + primates.find_common_ancestor(dinosaurs).name + +We can only find a common ancestor between two nodes that lie in the same tree. +If we try to find the common evolutionary ancestor between primates and an Alien species that has no relationship to Earth's evolutionary tree, +an error will be raised. + +.. ipython:: python + :okexcept: + + alien = DataTree(name="Xenomorph") + primates.find_common_ancestor(alien) + + +.. _navigating trees: + +Navigating Trees +---------------- + +There are various ways to access the different nodes in a tree. + +Properties +~~~~~~~~~~ + +We can navigate trees using the :py:class:`~DataTree.parent` and :py:class:`~DataTree.children` properties of each node, for example: + +.. ipython:: python + + lisa.parent.children["Bart"].name + +but there are also more convenient ways to access nodes. + +Dictionary-like interface +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Children are stored on each node as a key-value mapping from name to child node. +They can be accessed and altered via the :py:class:`~DataTree.__getitem__` and :py:class:`~DataTree.__setitem__` syntax. +In general :py:class:`~DataTree.DataTree` objects support almost the entire set of dict-like methods, +including :py:meth:`~DataTree.keys`, :py:class:`~DataTree.values`, :py:class:`~DataTree.items`, +:py:meth:`~DataTree.__delitem__` and :py:meth:`~DataTree.update`. + +.. ipython:: python + + vertebrates["Bony Skeleton"]["Ray-finned Fish"] + +Note that the dict-like interface combines access to child ``DataTree`` nodes and stored ``DataArrays``, +so if we have a node that contains both children and data, calling :py:meth:`~DataTree.keys` will list both names of child nodes and +names of data variables: + +.. ipython:: python + + dt = DataTree( + data=xr.Dataset({"foo": 0, "bar": 1}), + children={"a": DataTree(), "b": DataTree()}, + ) + print(dt) + list(dt.keys()) + +This also means that the names of variables and of child nodes must be different to one another. + +Attribute-like access +~~~~~~~~~~~~~~~~~~~~~ + +You can also select both variables and child nodes through dot indexing + +.. ipython:: python + + dt.foo + dt.a + +.. _filesystem paths: + +Filesystem-like Paths +~~~~~~~~~~~~~~~~~~~~~ + +Hierarchical trees can be thought of as analogous to file systems. +Each node is like a directory, and each directory can contain both more sub-directories and data. + +.. note:: + + You can even make the filesystem analogy concrete by using :py:func:`~DataTree.open_mfdatatree` or :py:func:`~DataTree.save_mfdatatree` # TODO not yet implemented - see GH issue 51 + +Datatree objects support a syntax inspired by unix-like filesystems, +where the "path" to a node is specified by the keys of each intermediate node in sequence, +separated by forward slashes. +This is an extension of the conventional dictionary ``__getitem__`` syntax to allow navigation across multiple levels of the tree. + +Like with filepaths, paths within the tree can either be relative to the current node, e.g. + +.. ipython:: python + + abe["Homer/Bart"].name + abe["./Homer/Bart"].name # alternative syntax + +or relative to the root node. +A path specified from the root (as opposed to being specified relative to an arbitrary node in the tree) is sometimes also referred to as a +`"fully qualified name" `_, +or as an "absolute path". +The root node is referred to by ``"/"``, so the path from the root node to its grand-child would be ``"/child/grandchild"``, e.g. + +.. ipython:: python + + # absolute path will start from root node + lisa["/Homer/Bart"].name + +Relative paths between nodes also support the ``"../"`` syntax to mean the parent of the current node. +We can use this with ``__setitem__`` to add a missing entry to our evolutionary tree, but add it relative to a more familiar node of interest: + +.. ipython:: python + + primates["../../Two Fenestrae/Crocodiles"] = DataTree() + print(vertebrates) + +Given two nodes in a tree, we can also find their relative path: + +.. ipython:: python + + bart.relative_to(lisa) + +You can use this filepath feature to build a nested tree from a dictionary of filesystem-like paths and corresponding ``xarray.Dataset`` objects in a single step. +If we have a dictionary where each key is a valid path, and each value is either valid data or ``None``, +we can construct a complex tree quickly using the alternative constructor :py:meth:`DataTree.from_dict()`: + +.. ipython:: python + + d = { + "/": xr.Dataset({"foo": "orange"}), + "/a": xr.Dataset({"bar": 0}, coords={"y": ("y", [0, 1, 2])}), + "/a/b": xr.Dataset({"zed": np.NaN}), + "a/c/d": None, + } + dt = DataTree.from_dict(d) + dt + +.. note:: + + Notice that using the path-like syntax will also create any intermediate empty nodes necessary to reach the end of the specified path + (i.e. the node labelled `"c"` in this case.) + This is to help avoid lots of redundant entries when creating deeply-nested trees using :py:meth:`DataTree.from_dict`. + +.. _iterating over trees: + +Iterating over trees +~~~~~~~~~~~~~~~~~~~~ + +You can iterate over every node in a tree using the subtree :py:class:`~DataTree.subtree` property. +This returns an iterable of nodes, which yields them in depth-first order. + +.. ipython:: python + + for node in vertebrates.subtree: + print(node.path) + +A very useful pattern is to use :py:class:`~DataTree.subtree` conjunction with the :py:class:`~DataTree.path` property to manipulate the nodes however you wish, +then rebuild a new tree using :py:meth:`DataTree.from_dict()`. + +For example, we could keep only the nodes containing data by looping over all nodes, +checking if they contain any data using :py:class:`~DataTree.has_data`, +then rebuilding a new tree using only the paths of those nodes: + +.. ipython:: python + + non_empty_nodes = {node.path: node.ds for node in dt.subtree if node.has_data} + DataTree.from_dict(non_empty_nodes) + +You can see this tree is similar to the ``dt`` object above, except that it is missing the empty nodes ``a/c`` and ``a/c/d``. + +(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.root.name)``.) + +.. _manipulating trees: + +Manipulating Trees +------------------ + +Subsetting Tree Nodes +~~~~~~~~~~~~~~~~~~~~~ + +We can subset our tree to select only nodes of interest in various ways. + +Similarly to on a real filesystem, matching nodes by common patterns in their paths is often useful. +We can use :py:meth:`DataTree.match` for this: + +.. ipython:: python + + dt = DataTree.from_dict( + { + "/a/A": None, + "/a/B": None, + "/b/A": None, + "/b/B": None, + } + ) + result = dt.match("*/B") + result + +We can also subset trees by the contents of the nodes. +:py:meth:`DataTree.filter` retains only the nodes of a tree that meet a certain condition. +For example, we could recreate the Simpson's family tree with the ages of each individual, then filter for only the adults: +First lets recreate the tree but with an `age` data variable in every node: + +.. ipython:: python + + simpsons = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + simpsons + +Now let's filter out the minors: + +.. ipython:: python + + simpsons.filter(lambda node: node["age"] > 18) + +The result is a new tree, containing only the nodes matching the condition. + +(Yes, under the hood :py:meth:`~DataTree.filter` is just syntactic sugar for the pattern we showed you in :ref:`iterating over trees` !) + +.. _Tree Contents: + +Tree Contents +------------- + +Hollow Trees +~~~~~~~~~~~~ + +A concept that can sometimes be useful is that of a "Hollow Tree", which means a tree with data stored only at the leaf nodes. +This is useful because certain useful tree manipulation operations only make sense for hollow trees. + +You can check if a tree is a hollow tree by using the :py:class:`~DataTree.is_hollow` property. +We can see that the Simpson's family is not hollow because the data variable ``"age"`` is present at some nodes which +have children (i.e. Abe and Homer). + +.. ipython:: python + + simpsons.is_hollow + +.. _tree computation: + +Computation +----------- + +`DataTree` objects are also useful for performing computations, not just for organizing data. + +Operations and Methods on Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To show how applying operations across a whole tree at once can be useful, +let's first create a example scientific dataset. + +.. ipython:: python + + def time_stamps(n_samples, T): + """Create an array of evenly-spaced time stamps""" + return xr.DataArray( + data=np.linspace(0, 2 * np.pi * T, n_samples), dims=["time"] + ) + + + def signal_generator(t, f, A, phase): + """Generate an example electrical-like waveform""" + return A * np.sin(f * t.data + phase) + + + time_stamps1 = time_stamps(n_samples=15, T=1.5) + time_stamps2 = time_stamps(n_samples=10, T=1.0) + + voltages = DataTree.from_dict( + { + "/oscilloscope1": xr.Dataset( + { + "potential": ( + "time", + signal_generator(time_stamps1, f=2, A=1.2, phase=0.5), + ), + "current": ( + "time", + signal_generator(time_stamps1, f=2, A=1.2, phase=1), + ), + }, + coords={"time": time_stamps1}, + ), + "/oscilloscope2": xr.Dataset( + { + "potential": ( + "time", + signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.2), + ), + "current": ( + "time", + signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.7), + ), + }, + coords={"time": time_stamps2}, + ), + } + ) + voltages + +Most xarray computation methods also exist as methods on datatree objects, +so you can for example take the mean value of these two timeseries at once: + +.. ipython:: python + + voltages.mean(dim="time") + +This works by mapping the standard :py:meth:`xarray.Dataset.mean()` method over the dataset stored in each node of the +tree one-by-one. + +The arguments passed to the method are used for every node, so the values of the arguments you pass might be valid for one node and invalid for another + +.. ipython:: python + :okexcept: + + voltages.isel(time=12) + +Notice that the error raised helpfully indicates which node of the tree the operation failed on. + +Arithmetic Methods on Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Arithmetic methods are also implemented, so you can e.g. add a scalar to every dataset in the tree at once. +For example, we can advance the timeline of the Simpsons by a decade just by + +.. ipython:: python + + simpsons + 10 + +See that the same change (fast-forwarding by adding 10 years to the age of each character) has been applied to every node. + +Mapping Custom Functions Over Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can map custom computation over each node in a tree using :py:meth:`DataTree.map_over_subtree`. +You can map any function, so long as it takes `xarray.Dataset` objects as one (or more) of the input arguments, +and returns one (or more) xarray datasets. + +.. note:: + + Functions passed to :py:func:`map_over_subtree` cannot alter nodes in-place. + Instead they must return new `xarray.Dataset` objects. + +For example, we can define a function to calculate the Root Mean Square of a timeseries + +.. ipython:: python + + def rms(signal): + return np.sqrt(np.mean(signal**2)) + +Then calculate the RMS value of these signals: + +.. ipython:: python + + voltages.map_over_subtree(rms) + +.. _multiple trees: + +We can also use the :py:func:`map_over_subtree` decorator to promote a function which accepts datasets into one which +accepts datatrees. + +Operating on Multiple Trees +--------------------------- + +The examples so far have involved mapping functions or methods over the nodes of a single tree, +but we can generalize this to mapping functions over multiple trees at once. + +Comparing Trees for Isomorphism +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For it to make sense to map a single non-unary function over the nodes of multiple trees at once, +each tree needs to have the same structure. Specifically two trees can only be considered similar, or "isomorphic", +if they have the same number of nodes, and each corresponding node has the same number of children. +We can check if any two trees are isomorphic using the :py:meth:`DataTree.isomorphic` method. + +.. ipython:: python + :okexcept: + + dt1 = DataTree.from_dict({"a": None, "a/b": None}) + dt2 = DataTree.from_dict({"a": None}) + dt1.isomorphic(dt2) + + dt3 = DataTree.from_dict({"a": None, "b": None}) + dt1.isomorphic(dt3) + + dt4 = DataTree.from_dict({"A": None, "A/B": xr.Dataset({"foo": 1})}) + dt1.isomorphic(dt4) + +If the trees are not isomorphic a :py:class:`~TreeIsomorphismError` will be raised. +Notice that corresponding tree nodes do not need to have the same name or contain the same data in order to be considered isomorphic. + +Arithmetic Between Multiple Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Arithmetic operations like multiplication are binary operations, so as long as we have two isomorphic trees, +we can do arithmetic between them. + +.. ipython:: python + + currents = DataTree.from_dict( + { + "/oscilloscope1": xr.Dataset( + { + "current": ( + "time", + signal_generator(time_stamps1, f=2, A=1.2, phase=1), + ), + }, + coords={"time": time_stamps1}, + ), + "/oscilloscope2": xr.Dataset( + { + "current": ( + "time", + signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.7), + ), + }, + coords={"time": time_stamps2}, + ), + } + ) + currents + + currents.isomorphic(voltages) + +We could use this feature to quickly calculate the electrical power in our signal, P=IV. + +.. ipython:: python + + power = currents * voltages + power diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/index.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/index.rst new file mode 100644 index 0000000..a88a574 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/index.rst @@ -0,0 +1,61 @@ +.. currentmodule:: datatree + +Datatree +======== + +**Datatree is a prototype implementation of a tree-like hierarchical data structure for xarray.** + +Why Datatree? +~~~~~~~~~~~~~ + +Datatree was born after the xarray team recognised a `need for a new hierarchical data structure `_, +that was more flexible than a single :py:class:`xarray.Dataset` object. +The initial motivation was to represent netCDF files / Zarr stores with multiple nested groups in a single in-memory object, +but :py:class:`~datatree.DataTree` objects have many other uses. + +You might want to use datatree for: + +- Organising many related datasets, e.g. results of the same experiment with different parameters, or simulations of the same system using different models, +- Analysing similar data at multiple resolutions simultaneously, such as when doing a convergence study, +- Comparing heterogenous but related data, such as experimental and theoretical data, +- I/O with nested data formats such as netCDF / Zarr groups. + +Development Roadmap +~~~~~~~~~~~~~~~~~~~ + +Datatree currently lives in a separate repository to the main xarray package. +This allows the datatree developers to make changes to it, experiment, and improve it faster. + +Eventually we plan to fully integrate datatree upstream into xarray's main codebase, at which point the `github.com/xarray-contrib/datatree `_ repository will be archived. +This should not cause much disruption to code that depends on datatree - you will likely only have to change the import line (i.e. from ``from datatree import DataTree`` to ``from xarray import DataTree``). + +However, until this full integration occurs, datatree's API should not be considered to have the same `level of stability as xarray's `_. + +User Feedback +~~~~~~~~~~~~~ + +We really really really want to hear your opinions on datatree! +At this point in development, user feedback is critical to help us create something that will suit everyone's needs. +Please raise any thoughts, issues, suggestions or bugs, no matter how small or large, on the `github issue tracker `_. + +.. toctree:: + :maxdepth: 2 + :caption: Documentation Contents + + Installation + Quick Overview + Tutorial + Data Model + Hierarchical Data + Reading and Writing Files + API Reference + Terminology + Contributing Guide + What's New + GitHub repository + +Feedback +-------- + +If you encounter any errors, problems with **Datatree**, or have any suggestions, please open an issue +on `GitHub `_. diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/installation.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/installation.rst new file mode 100644 index 0000000..b268274 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/installation.rst @@ -0,0 +1,38 @@ +.. currentmodule:: datatree + +============ +Installation +============ + +Datatree can be installed in three ways: + +Using the `conda `__ package manager that comes with the +Anaconda/Miniconda distribution: + +.. code:: bash + + $ conda install xarray-datatree --channel conda-forge + +Using the `pip `__ package manager: + +.. code:: bash + + $ python -m pip install xarray-datatree + +To install a development version from source: + +.. code:: bash + + $ git clone https://github.com/xarray-contrib/datatree + $ cd datatree + $ python -m pip install -e . + + +You will just need xarray as a required dependency, with netcdf4, zarr, and h5netcdf as optional dependencies to allow file I/O. + +.. note:: + + Datatree is very much still in the early stages of development. There may be functions that are present but whose + internals are not yet implemented, or significant changes to the API in future. + That said, if you try it out and find some behaviour that looks like a bug to you, please report it on the + `issue tracker `_! diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/io.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/io.rst new file mode 100644 index 0000000..2f2dabf --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/io.rst @@ -0,0 +1,54 @@ +.. currentmodule:: datatree + +.. _io: + +Reading and Writing Files +========================= + +.. note:: + + This page builds on the information given in xarray's main page on + `reading and writing files `_, + so it is suggested that you are familiar with those first. + + +netCDF +------ + +Groups +~~~~~~ + +Whilst netCDF groups can only be loaded individually as Dataset objects, a whole file of many nested groups can be loaded +as a single :py:class:`DataTree` object. +To open a whole netCDF file as a tree of groups use the :py:func:`open_datatree` function. +To save a DataTree object as a netCDF file containing many groups, use the :py:meth:`DataTree.to_netcdf` method. + + +.. _netcdf.group.warning: + +.. warning:: + ``DataTree`` objects do not follow the exact same data model as netCDF files, which means that perfect round-tripping + is not always possible. + + In particular in the netCDF data model dimensions are entities that can exist regardless of whether any variable possesses them. + This is in contrast to `xarray's data model `_ + (and hence :ref:`datatree's data model `) in which the dimensions of a (Dataset/Tree) + object are simply the set of dimensions present across all variables in that dataset. + + This means that if a netCDF file contains dimensions but no variables which possess those dimensions, + these dimensions will not be present when that file is opened as a DataTree object. + Saving this DataTree object to file will therefore not preserve these "unused" dimensions. + +Zarr +---- + +Groups +~~~~~~ + +Nested groups in zarr stores can be represented by loading the store as a :py:class:`DataTree` object, similarly to netCDF. +To open a whole zarr store as a tree of groups use the :py:func:`open_datatree` function. +To save a DataTree object as a zarr store containing many groups, use the :py:meth:`DataTree.to_zarr()` method. + +.. note:: + Note that perfect round-tripping should always be possible with a zarr store (:ref:`unlike for netCDF files `), + as zarr does not support "unused" dimensions. diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/quick-overview.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/quick-overview.rst new file mode 100644 index 0000000..4743b08 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/quick-overview.rst @@ -0,0 +1,84 @@ +.. currentmodule:: datatree + +############## +Quick overview +############## + +DataTrees +--------- + +:py:class:`DataTree` is a tree-like container of :py:class:`xarray.DataArray` objects, organised into multiple mutually alignable groups. +You can think of it like a (recursive) ``dict`` of :py:class:`xarray.Dataset` objects. + +Let's first make some example xarray datasets (following on from xarray's +`quick overview `_ page): + +.. ipython:: python + + import numpy as np + import xarray as xr + + data = xr.DataArray(np.random.randn(2, 3), dims=("x", "y"), coords={"x": [10, 20]}) + ds = xr.Dataset(dict(foo=data, bar=("x", [1, 2]), baz=np.pi)) + ds + + ds2 = ds.interp(coords={"x": [10, 12, 14, 16, 18, 20]}) + ds2 + + ds3 = xr.Dataset( + dict(people=["alice", "bob"], heights=("people", [1.57, 1.82])), + coords={"species": "human"}, + ) + ds3 + +Now we'll put this data into a multi-group tree: + +.. ipython:: python + + from datatree import DataTree + + dt = DataTree.from_dict({"simulation/coarse": ds, "simulation/fine": ds2, "/": ds3}) + dt + +This creates a datatree with various groups. We have one root group, containing information about individual people. +(This root group can be named, but here is unnamed, so is referred to with ``"/"``, same as the root of a unix-like filesystem.) +The root group then has one subgroup ``simulation``, which contains no data itself but does contain another two subgroups, +named ``fine`` and ``coarse``. + +The (sub-)sub-groups ``fine`` and ``coarse`` contain two very similar datasets. +They both have an ``"x"`` dimension, but the dimension is of different lengths in each group, which makes the data in each group unalignable. +In the root group we placed some completely unrelated information, showing how we can use a tree to store heterogenous data. + +The constraints on each group are therefore the same as the constraint on dataarrays within a single dataset. + +We created the sub-groups using a filesystem-like syntax, and accessing groups works the same way. +We can access individual dataarrays in a similar fashion + +.. ipython:: python + + dt["simulation/coarse/foo"] + +and we can also pull out the data in a particular group as a ``Dataset`` object using ``.ds``: + +.. ipython:: python + + dt["simulation/coarse"].ds + +Operations map over subtrees, so we can take a mean over the ``x`` dimension of both the ``fine`` and ``coarse`` groups just by + +.. ipython:: python + + avg = dt["simulation"].mean(dim="x") + avg + +Here the ``"x"`` dimension used is always the one local to that sub-group. + +You can do almost everything you can do with ``Dataset`` objects with ``DataTree`` objects +(including indexing and arithmetic), as operations will be mapped over every sub-group in the tree. +This allows you to work with multiple groups of non-alignable variables at once. + +.. note:: + + If all of your variables are mutually alignable + (i.e. they live on the same grid, such that every common dimension name maps to the same length), + then you probably don't need :py:class:`DataTree`, and should consider just sticking with ``xarray.Dataset``. diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/terminology.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/terminology.rst new file mode 100644 index 0000000..e481a01 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/terminology.rst @@ -0,0 +1,34 @@ +.. currentmodule:: datatree + +.. _terminology: + +This page extends `xarray's page on terminology `_. + +Terminology +=========== + +.. glossary:: + + DataTree + A tree-like collection of ``Dataset`` objects. A *tree* is made up of one or more *nodes*, + each of which can store the same information as a single ``Dataset`` (accessed via `.ds`). + This data is stored in the same way as in a ``Dataset``, i.e. in the form of data variables + (see **Variable** in the `corresponding xarray terminology page `_), + dimensions, coordinates, and attributes. + + The nodes in a tree are linked to one another, and each node is it's own instance of ``DataTree`` object. + Each node can have zero or more *children* (stored in a dictionary-like manner under their corresponding *names*), + and those child nodes can themselves have children. + If a node is a child of another node that other node is said to be its *parent*. Nodes can have a maximum of one parent, + and if a node has no parent it is said to be the *root* node of that *tree*. + + Subtree + A section of a *tree*, consisting of a *node* along with all the child nodes below it + (and the child nodes below them, i.e. all so-called *descendant* nodes). + Excludes the parent node and all nodes above. + + Group + Another word for a subtree, reflecting how the hierarchical structure of a ``DataTree`` allows for grouping related data together. + Analogous to a single + `netCDF group `_ or + `Zarr group `_. diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/tutorial.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/tutorial.rst new file mode 100644 index 0000000..6e33bd3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/tutorial.rst @@ -0,0 +1,7 @@ +.. currentmodule:: datatree + +======== +Tutorial +======== + +Coming soon! diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/whats-new.rst b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/whats-new.rst new file mode 100644 index 0000000..2f6e4f8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/docs/source/whats-new.rst @@ -0,0 +1,426 @@ +.. currentmodule:: datatree + +What's New +========== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xray + import xarray + import xarray as xr + import datatree + + np.random.seed(123456) + +.. _whats-new.v0.0.14: + +v0.0.14 (unreleased) +-------------------- + +New Features +~~~~~~~~~~~~ + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Renamed `DataTree.lineage` to `DataTree.parents` to match `pathlib` vocabulary + (:issue:`283`, :pull:`286`) +- Minimum required version of xarray is now 2023.12.0, i.e. the latest version. + This is required to prevent recent changes to xarray's internals from breaking datatree. + (:issue:`293`, :pull:`294`) + By `Tom Nicholas `_. +- Change default write mode of :py:meth:`DataTree.to_zarr` to ``'w-'`` to match ``xarray`` + default and prevent accidental directory overwrites. (:issue:`274`, :pull:`275`) + By `Sam Levang `_. + +Deprecations +~~~~~~~~~~~~ + +- Renamed `DataTree.lineage` to `DataTree.parents` to match `pathlib` vocabulary + (:issue:`283`, :pull:`286`). `lineage` is now deprecated and use of `parents` is encouraged. + By `Etienne Schalk `_. + +Bug fixes +~~~~~~~~~ +- Keep attributes on nodes containing no data in :py:func:`map_over_subtree`. (:issue:`278`, :pull:`279`) + By `Sam Levang `_. + +Documentation +~~~~~~~~~~~~~ +- Use ``napoleon`` instead of ``numpydoc`` to align with xarray documentation + (:issue:`284`, :pull:`298`). + By `Etienne Schalk `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +.. _whats-new.v0.0.13: + +v0.0.13 (27/10/2023) +-------------------- + +New Features +~~~~~~~~~~~~ + +- New :py:meth:`DataTree.match` method for glob-like pattern matching of node paths. (:pull:`267`) + By `Tom Nicholas `_. +- New :py:meth:`DataTree.is_hollow` property for checking if data is only contained at the leaf nodes. (:pull:`272`) + By `Tom Nicholas `_. +- Indicate which node caused the problem if error encountered while applying user function using :py:func:`map_over_subtree` + (:issue:`190`, :pull:`264`). Only works when using python 3.11 or later. + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Nodes containing only attributes but no data are now ignored by :py:func:`map_over_subtree` (:issue:`262`, :pull:`263`) + By `Tom Nicholas `_. +- Disallow altering of given dataset inside function called by :py:func:`map_over_subtree` (:pull:`269`, reverts part of :pull:`194`). + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Fix unittests on i386. (:pull:`249`) + By `Antonio Valentino `_. +- Ensure nodepath class is compatible with python 3.12 (:pull:`260`) + By `Max Grover `_. + +Documentation +~~~~~~~~~~~~~ + +- Added new sections to page on ``Working with Hierarchical Data`` (:pull:`180`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +* No longer use the deprecated `distutils` package. + +.. _whats-new.v0.0.12: + +v0.0.12 (03/07/2023) +-------------------- + +New Features +~~~~~~~~~~~~ + +- Added a :py:func:`DataTree.level`, :py:func:`DataTree.depth`, and :py:func:`DataTree.width` property (:pull:`208`). + By `Tom Nicholas `_. +- Allow dot-style (or "attribute-like") access to child nodes and variables, with ipython autocomplete. (:issue:`189`, :pull:`98`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecations +~~~~~~~~~~~~ + +- Dropped support for python 3.8 (:issue:`212`, :pull:`214`) + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Allow for altering of given dataset inside function called by :py:func:`map_over_subtree` (:issue:`188`, :pull:`194`). + By `Tom Nicholas `_. +- copy subtrees without creating ancestor nodes (:pull:`201`) + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +Internal Changes +~~~~~~~~~~~~~~~~ + +.. _whats-new.v0.0.11: + +v0.0.11 (01/09/2023) +-------------------- + +Big update with entirely new pages in the docs, +new methods (``.drop_nodes``, ``.filter``, ``.leaves``, ``.descendants``), and bug fixes! + +New Features +~~~~~~~~~~~~ + +- Added a :py:meth:`DataTree.drop_nodes` method (:issue:`161`, :pull:`175`). + By `Tom Nicholas `_. +- New, more specific exception types for tree-related errors (:pull:`169`). + By `Tom Nicholas `_. +- Added a new :py:meth:`DataTree.descendants` property (:pull:`170`). + By `Tom Nicholas `_. +- Added a :py:meth:`DataTree.leaves` property (:pull:`177`). + By `Tom Nicholas `_. +- Added a :py:meth:`DataTree.filter` method (:pull:`184`). + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`DataTree.copy` copy method now only copies the subtree, not the parent nodes (:pull:`171`). + By `Tom Nicholas `_. +- Grafting a subtree onto another tree now leaves name of original subtree object unchanged (:issue:`116`, :pull:`172`, :pull:`178`). + By `Tom Nicholas `_. +- Changed the :py:meth:`DataTree.assign` method to just work on the local node (:pull:`181`). + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +- Fix bug with :py:meth:`DataTree.relative_to` method (:issue:`133`, :pull:`160`). + By `Tom Nicholas `_. +- Fix links to API docs in all documentation (:pull:`183`). + By `Tom Nicholas `_. + +Documentation +~~~~~~~~~~~~~ + +- Changed docs theme to match xarray's main documentation. (:pull:`173`) + By `Tom Nicholas `_. +- Added ``Terminology`` page. (:pull:`174`) + By `Tom Nicholas `_. +- Added page on ``Working with Hierarchical Data`` (:pull:`179`) + By `Tom Nicholas `_. +- Added context content to ``Index`` page (:pull:`182`) + By `Tom Nicholas `_. +- Updated the README (:pull:`187`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.v0.0.10: + +v0.0.10 (12/07/2022) +-------------------- + +Adds accessors and a `.pipe()` method. + +New Features +~~~~~~~~~~~~ + +- Add the ability to register accessors on ``DataTree`` objects, by using ``register_datatree_accessor``. (:pull:`144`) + By `Tom Nicholas `_. +- Allow method chaining with a new :py:meth:`DataTree.pipe` method (:issue:`151`, :pull:`156`). + By `Justus Magin `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +- Allow ``Datatree`` objects as values in :py:meth:`DataTree.from_dict` (:pull:`159`). + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +- Added ``Reading and Writing Files`` page. (:pull:`158`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Avoid reading from same file twice with fsspec3 (:pull:`130`) + By `William Roberts `_. + + +.. _whats-new.v0.0.9: + +v0.0.9 (07/14/2022) +------------------- + +New Features +~~~~~~~~~~~~ + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +Documentation +~~~~~~~~~~~~~ +- Switch docs theme (:pull:`123`). + By `JuliusBusecke `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.v0.0.7: + +v0.0.7 (07/11/2022) +------------------- + +New Features +~~~~~~~~~~~~ + +- Improve the HTML repr by adding tree-style lines connecting groups and sub-groups (:pull:`109`). + By `Benjamin Woods `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The ``DataTree.ds`` attribute now returns a view onto an immutable Dataset-like object, instead of an actual instance + of ``xarray.Dataset``. This make break existing ``isinstance`` checks or ``assert`` comparisons. (:pull:`99`) + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +- Modifying the contents of a ``DataTree`` object via the ``DataTree.ds`` attribute is now forbidden, which prevents + any possibility of the contents of a ``DataTree`` object and its ``.ds`` attribute diverging. (:issue:`38`, :pull:`99`) + By `Tom Nicholas `_. +- Fixed a bug so that names of children now always match keys under which parents store them (:pull:`99`). + By `Tom Nicholas `_. + +Documentation +~~~~~~~~~~~~~ + +- Added ``Data Structures`` page describing the internal structure of a ``DataTree`` object, and its relation to + ``xarray.Dataset`` objects. (:pull:`103`) + By `Tom Nicholas `_. +- API page updated with all the methods that are copied from ``xarray.Dataset``. (:pull:`41`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Refactored ``DataTree`` class to store a set of ``xarray.Variable`` objects instead of a single ``xarray.Dataset``. + This approach means that the ``DataTree`` class now effectively copies and extends the internal structure of + ``xarray.Dataset``. (:pull:`41`) + By `Tom Nicholas `_. +- Refactored to use intermediate ``NamedNode`` class, separating implementation of methods requiring a ``name`` + attribute from those not requiring it. + By `Tom Nicholas `_. +- Made ``testing.test_datatree.create_test_datatree`` into a pytest fixture (:pull:`107`). + By `Benjamin Woods `_. + + + +.. _whats-new.v0.0.6: + +v0.0.6 (06/03/2022) +------------------- + +Various small bug fixes, in preparation for more significant changes in the next version. + +Bug fixes +~~~~~~~~~ + +- Fixed bug with checking that assigning parent or new children did not create a loop in the tree (:pull:`105`) + By `Tom Nicholas `_. +- Do not call ``__exit__`` on Zarr store when opening (:pull:`90`) + By `Matt McCormick `_. +- Fix netCDF encoding for compression (:pull:`95`) + By `Joe Hamman `_. +- Added validity checking for node names (:pull:`106`) + By `Tom Nicholas `_. + +.. _whats-new.v0.0.5: + +v0.0.5 (05/05/2022) +------------------- + +- Major refactor of internals, moving from the ``DataTree.children`` attribute being a ``Tuple[DataTree]`` to being a + ``OrderedDict[str, DataTree]``. This was necessary in order to integrate better with xarray's dictionary-like API, + solve several issues, simplify the code internally, remove dependencies, and enable new features. (:pull:`76`) + By `Tom Nicholas `_. + +New Features +~~~~~~~~~~~~ + +- Syntax for accessing nodes now supports file-like paths, including parent nodes via ``"../"``, relative paths, the + root node via ``"/"``, and the current node via ``"."``. (Internally it actually uses ``pathlib`` now.) + By `Tom Nicholas `_. +- New path-like API methods, such as ``.relative_to``, ``.find_common_ancestor``, and ``.same_tree``. +- Some new dictionary-like methods, such as ``DataTree.get`` and ``DataTree.update``. (:pull:`76`) + By `Tom Nicholas `_. +- New HTML repr, which will automatically display in a jupyter notebook. (:pull:`78`) + By `Tom Nicholas `_. +- New delitem method so you can delete nodes. (:pull:`88`) + By `Tom Nicholas `_. +- New ``to_dict`` method. (:pull:`82`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Node names are now optional, which means that the root of the tree can be unnamed. This has knock-on effects for + a lot of the API. +- The ``__init__`` signature for ``DataTree`` has changed, so that ``name`` is now an optional kwarg. +- Files will now be loaded as a slightly different tree, because the root group no longer needs to be given a default + name. +- Removed tag-like access to nodes. +- Removes the option to delete all data in a node by assigning None to the node (in favour of deleting data by replacing + the node's ``.ds`` attribute with an empty Dataset), or to create a new empty node in the same way (in favour of + assigning an empty DataTree object instead). +- Removes the ability to create a new node by assigning a ``Dataset`` object to ``DataTree.__setitem__``. +- Several other minor API changes such as ``.pathstr`` -> ``.path``, and ``from_dict``'s dictionary argument now being + required. (:pull:`76`) + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ + +- No longer depends on the anytree library (:pull:`76`) + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Fixed indentation issue with the string repr (:pull:`86`) + By `Tom Nicholas `_. + +Documentation +~~~~~~~~~~~~~ + +- Quick-overview page updated to match change in path syntax (:pull:`76`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Basically every file was changed in some way to accommodate (:pull:`76`). +- No longer need the utility functions for string manipulation that were defined in ``utils.py``. +- A considerable amount of code copied over from the internals of anytree (e.g. in ``render.py`` and ``iterators.py``). + The Apache license for anytree has now been bundled with datatree. (:pull:`76`). + By `Tom Nicholas `_. + +.. _whats-new.v0.0.4: + +v0.0.4 (31/03/2022) +------------------- + +- Ensure you get the pretty tree-like string representation by default in ipython (:pull:`73`). + By `Tom Nicholas `_. +- Now available on conda-forge (as xarray-datatree)! (:pull:`71`) + By `Anderson Banihirwe `_. +- Allow for python 3.8 (:pull:`70`). + By `Don Setiawan `_. + +.. _whats-new.v0.0.3: + +v0.0.3 (30/03/2022) +------------------- + +- First released version available on both pypi (as xarray-datatree)! diff --git a/test/fixtures/whole_applications/xarray/xarray/datatree_/readthedocs.yml b/test/fixtures/whole_applications/xarray/xarray/datatree_/readthedocs.yml new file mode 100644 index 0000000..9b04939 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/datatree_/readthedocs.yml @@ -0,0 +1,7 @@ +version: 2 +conda: + environment: ci/doc.yml +build: + os: 'ubuntu-20.04' + tools: + python: 'mambaforge-4.10' diff --git a/test/fixtures/whole_applications/xarray/xarray/indexes/__init__.py b/test/fixtures/whole_applications/xarray/xarray/indexes/__init__.py new file mode 100644 index 0000000..b1bf7a1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/indexes/__init__.py @@ -0,0 +1,8 @@ +"""Xarray index objects for label-based selection and alignment of Dataset / +DataArray objects. + +""" + +from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex + +__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/__init__.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/_aggregations.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/_aggregations.py new file mode 100644 index 0000000..9f58aeb --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/_aggregations.py @@ -0,0 +1,950 @@ +"""Mixin classes with reduction operations.""" + +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable + +from xarray.core import duck_array_ops +from xarray.core.types import Dims, Self + + +class NamedArrayAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.count() + Size: 8B + array(5) + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + >>> na + Size: 6B + array([ True, True, True, True, True, False]) + + >>> na.all() + Size: 1B + array(False) + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + >>> na + Size: 6B + array([ True, True, True, True, True, False]) + + >>> na.any() + Size: 1B + array(True) + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.max() + Size: 8B + array(3.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.max(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.min() + Size: 8B + array(0.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.min(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.mean() + Size: 8B + array(1.6) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.mean(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.prod() + Size: 8B + array(0.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.prod(skipna=False) + Size: 8B + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> na.prod(skipna=True, min_count=2) + Size: 8B + array(0.) + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.sum() + Size: 8B + array(8.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.sum(skipna=False) + Size: 8B + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> na.sum(skipna=True, min_count=2) + Size: 8B + array(8.) + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.std() + Size: 8B + array(1.0198039) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.std(skipna=False) + Size: 8B + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> na.std(skipna=True, ddof=1) + Size: 8B + array(1.14017543) + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.var() + Size: 8B + array(1.04) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.var(skipna=False) + Size: 8B + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> na.var(skipna=True, ddof=1) + Size: 8B + array(1.3) + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.median() + Size: 8B + array(2.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.median(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.cumsum() + Size: 48B + array([1., 3., 6., 6., 8., 8.]) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.cumsum(skipna=False) + Size: 48B + array([ 1., 3., 6., 6., 8., nan]) + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.cumprod() + Size: 48B + array([1., 2., 6., 0., 0., 0.]) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.cumprod(skipna=False) + Size: 48B + array([ 1., 2., 6., 0., 0., nan]) + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + **kwargs, + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/_array_api.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/_array_api.py new file mode 100644 index 0000000..acbfc8a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/_array_api.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any + +import numpy as np + +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _Axes, + _Axis, + _default, + _Dim, + _DType, + _ScalarType, + _ShapeType, + _SupportsImag, + _SupportsReal, +) +from xarray.namedarray.core import NamedArray + + +def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: + if isinstance(x._data, _arrayapi): + return x._data.__array_namespace__() + + return np + + +# %% Creation Functions + + +def astype( + x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True +) -> NamedArray[_ShapeType, _DType]: + """ + Copies an array to a specified data type irrespective of Type Promotion Rules rules. + + Parameters + ---------- + x : NamedArray + Array to cast. + dtype : _DType + Desired data type. + copy : bool, optional + Specifies whether to copy an array when the specified dtype matches the data + type of the input array x. + If True, a newly allocated array must always be returned. + If False and the specified dtype matches the data type of the input array, + the input array must be returned; otherwise, a newly allocated array must be + returned. Default: True. + + Returns + ------- + out : NamedArray + An array having the specified data type. The returned array must have the + same shape as x. + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5])) + >>> narr + Size: 16B + array([1.5, 2.5]) + >>> astype(narr, np.dtype(np.int32)) + Size: 8B + array([1, 2], dtype=int32) + """ + if isinstance(x._data, _arrayapi): + xp = x._data.__array_namespace__() + return x._new(data=xp.astype(x._data, dtype, copy=copy)) + + # np.astype doesn't exist yet: + return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] + + +# %% Elementwise Functions + + +def imag( + x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the imaginary component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) + >>> imag(narr) + Size: 16B + array([2., 4.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.imag(x._data)) + return out + + +def real( + x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the real component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) + >>> real(narr) + Size: 16B + array([1., 2.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.real(x._data)) + return out + + +# %% Manipulation functions +def expand_dims( + x: NamedArray[Any, _DType], + /, + *, + dim: _Dim | Default = _default, + axis: _Axis = 0, +) -> NamedArray[Any, _DType]: + """ + Expands the shape of an array by inserting a new dimension of size one at the + position specified by dims. + + Parameters + ---------- + x : + Array to expand. + dim : + Dimension name. New dimension will be stored in the axis position. + axis : + (Not recommended) Axis position (zero-based). Default is 0. + + Returns + ------- + out : + An expanded output array having the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> expand_dims(x) + Size: 32B + array([[[1., 2.], + [3., 4.]]]) + >>> expand_dims(x, dim="z") + Size: 32B + array([[[1., 2.], + [3., 4.]]]) + """ + xp = _get_data_namespace(x) + dims = x.dims + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) + d.insert(axis, dim) + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) + return out + + +def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: + """ + Permutes the dimensions of an array. + + Parameters + ---------- + x : + Array to permute. + axes : + Permutation of the dimensions of x. + + Returns + ------- + out : + An array with permuted dimensions. The returned array must have the same + data type as x. + + """ + + dims = x.dims + new_dims = tuple(dims[i] for i in axes) + if isinstance(x._data, _arrayapi): + xp = _get_data_namespace(x) + out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes)) + else: + out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined] + return out diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/_typing.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/_typing.py new file mode 100644 index 0000000..b715973 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/_typing.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import sys +from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum +from types import ModuleType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Literal, + Protocol, + SupportsIndex, + TypeVar, + Union, + overload, + runtime_checkable, +) + +import numpy as np + +try: + if sys.version_info >= (3, 11): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias +except ImportError: + if TYPE_CHECKING: + raise + else: + Self: Any = None + + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + +# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +_dtype = np.dtype +_DType = TypeVar("_DType", bound=np.dtype[Any]) +_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) +# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic` + +_ScalarType = TypeVar("_ScalarType", bound=np.generic) +_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) + + +# A protocol for anything with the dtype attribute +@runtime_checkable +class _SupportsDType(Protocol[_DType_co]): + @property + def dtype(self) -> _DType_co: ... + + +_DTypeLike = Union[ + np.dtype[_ScalarType], + type[_ScalarType], + _SupportsDType[np.dtype[_ScalarType]], +] + +# For unknown shapes Dask uses np.nan, array_api uses None: +_IntOrUnknown = int +_Shape = tuple[_IntOrUnknown, ...] +_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] +_ShapeType = TypeVar("_ShapeType", bound=Any) +_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) + +_Axis = int +_Axes = tuple[_Axis, ...] +_AxisLike = Union[_Axis, _Axes] + +_Chunks = tuple[_Shape, ...] +_NormalizedChunks = tuple[tuple[int, ...], ...] +# FYI in some cases we don't allow `None`, which this doesn't take account of. +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +# We allow the tuple form of this (though arguably we could transition to named dims only) +T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] + +_Dim = Hashable +_Dims = tuple[_Dim, ...] + +_DimsLike = Union[str, Iterable[_Dim]] + +# https://data-apis.org/array-api/latest/API_specification/indexing.html +# TODO: np.array_api was bugged and didn't allow (None,), but should! +# https://github.com/numpy/numpy/pull/25022 +# https://github.com/data-apis/array-api/pull/674 +_IndexKey = Union[int, slice, "ellipsis"] +_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...] +_IndexKeyLike = Union[_IndexKey, _IndexKeys] + +_AttrsLike = Union[Mapping[Any, Any], None] + + +class _SupportsReal(Protocol[_T_co]): + @property + def real(self) -> _T_co: ... + + +class _SupportsImag(Protocol[_T_co]): + @property + def imag(self) -> _T_co: ... + + +@runtime_checkable +class _array(Protocol[_ShapeType_co, _DType_co]): + """ + Minimal duck array named array uses. + + Corresponds to np.ndarray. + """ + + @property + def shape(self) -> _Shape: ... + + @property + def dtype(self) -> _DType_co: ... + + +@runtime_checkable +class _arrayfunction( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + @overload + def __getitem__( + self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], / + ) -> _arrayfunction[Any, _DType_co]: ... + + @overload + def __getitem__(self, key: _IndexKeyLike, /) -> Any: ... + + def __getitem__( + self, + key: ( + _IndexKeyLike + | _arrayfunction[Any, Any] + | tuple[_arrayfunction[Any, Any], ...] + ), + /, + ) -> _arrayfunction[Any, _DType_co] | Any: ... + + @overload + def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: ... + + @overload + def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]: ... + + def __array__( + self, dtype: _DType | None = ..., / + ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ... + + # TODO: Should return the same subclass but with a new dtype generic. + # https://github.com/python/typing/issues/548 + def __array_ufunc__( + self, + ufunc: Any, + method: Any, + *inputs: Any, + **kwargs: Any, + ) -> Any: ... + + # TODO: Should return the same subclass but with a new dtype generic. + # https://github.com/python/typing/issues/548 + def __array_function__( + self, + func: Callable[..., Any], + types: Iterable[type], + args: Iterable[Any], + kwargs: Mapping[str, Any], + ) -> Any: ... + + @property + def imag(self) -> _arrayfunction[_ShapeType_co, Any]: ... + + @property + def real(self) -> _arrayfunction[_ShapeType_co, Any]: ... + + +@runtime_checkable +class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]): + """ + Duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + def __getitem__( + self, + key: ( + _IndexKeyLike | Any + ), # TODO: Any should be _arrayapi[Any, _dtype[np.integer]] + /, + ) -> _arrayapi[Any, Any]: ... + + def __array_namespace__(self) -> ModuleType: ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_arrayfunction_or_api = (_arrayfunction, _arrayapi) + +duckarray = Union[ + _arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co] +] + +# Corresponds to np.typing.NDArray: +DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]] + + +@runtime_checkable +class _chunkedarray( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Minimal chunked duck array. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: ... + + +@runtime_checkable +class _chunkedarrayfunction( + _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Chunked duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: ... + + +@runtime_checkable +class _chunkedarrayapi( + _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Chunked duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_chunkedarrayfunction_or_api = (_chunkedarrayfunction, _chunkedarrayapi) +chunkedduckarray = Union[ + _chunkedarrayfunction[_ShapeType_co, _DType_co], + _chunkedarrayapi[_ShapeType_co, _DType_co], +] + + +@runtime_checkable +class _sparsearray( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Minimal sparse duck array. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: ... + + +@runtime_checkable +class _sparsearrayfunction( + _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Sparse duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: ... + + +@runtime_checkable +class _sparsearrayapi( + _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Sparse duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi) +sparseduckarray = Union[ + _sparsearrayfunction[_ShapeType_co, _DType_co], + _sparsearrayapi[_ShapeType_co, _DType_co], +] + +ErrorOptions = Literal["raise", "ignore"] +ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/core.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/core.py new file mode 100644 index 0000000..960ab9d --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/core.py @@ -0,0 +1,1160 @@ +from __future__ import annotations + +import copy +import math +import sys +import warnings +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + TypeVar, + cast, + overload, +) + +import numpy as np + +# TODO: get rid of this after migrating this class to array API +from xarray.core import dtypes, formatting, formatting_html +from xarray.core.indexing import ( + ExplicitlyIndexed, + ImplicitToExplicitIndexingAdapter, + OuterIndexer, +) +from xarray.namedarray._aggregations import NamedArrayAggregations +from xarray.namedarray._typing import ( + ErrorOptionsWithWarn, + _arrayapi, + _arrayfunction_or_api, + _chunkedarray, + _default, + _dtype, + _DType_co, + _ScalarType_co, + _ShapeType_co, + _sparsearrayfunction_or_api, + _SupportsImag, + _SupportsReal, +) +from xarray.namedarray.parallelcompat import guess_chunkmanager +from xarray.namedarray.pycompat import to_numpy +from xarray.namedarray.utils import ( + either_dict_or_kwargs, + infix_dims, + is_dict_like, + is_duck_dask_array, + to_0d_object_array, +) + +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + from xarray.core.types import Dims + from xarray.namedarray._typing import ( + Default, + _AttrsLike, + _Chunks, + _Dim, + _Dims, + _DimsLike, + _DType, + _IntOrUnknown, + _ScalarType, + _Shape, + _ShapeType, + duckarray, + ) + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + try: + from dask.typing import ( + Graph, + NestedKeys, + PostComputeCallable, + PostPersistCallable, + SchedulerGetCallable, + ) + except ImportError: + Graph: Any # type: ignore[no-redef] + NestedKeys: Any # type: ignore[no-redef] + SchedulerGetCallable: Any # type: ignore[no-redef] + PostComputeCallable: Any # type: ignore[no-redef] + PostPersistCallable: Any # type: ignore[no-redef] + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + T_NamedArray = TypeVar("T_NamedArray", bound="_NamedArray[Any]") + T_NamedArrayInteger = TypeVar( + "T_NamedArrayInteger", bound="_NamedArray[np.integer[Any]]" + ) + + +@overload +def _new( + x: NamedArray[Any, _DType_co], + dims: _DimsLike | Default = ..., + data: duckarray[_ShapeType, _DType] = ..., + attrs: _AttrsLike | Default = ..., +) -> NamedArray[_ShapeType, _DType]: ... + + +@overload +def _new( + x: NamedArray[_ShapeType_co, _DType_co], + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., +) -> NamedArray[_ShapeType_co, _DType_co]: ... + + +def _new( + x: NamedArray[Any, _DType_co], + dims: _DimsLike | Default = _default, + data: duckarray[_ShapeType, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, _DType_co]: + """ + Create a new array with new typing information. + + Parameters + ---------- + x : NamedArray + Array to create a new array from + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + dims_ = copy.copy(x._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if x._attrs is None else x._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(x)(dims_, copy.copy(x._data), attrs_) + else: + cls_ = cast("type[NamedArray[_ShapeType, _DType]]", type(x)) + return cls_(dims_, data, attrs_) + + +@overload +def from_array( + dims: _DimsLike, + data: duckarray[_ShapeType, _DType], + attrs: _AttrsLike = ..., +) -> NamedArray[_ShapeType, _DType]: ... + + +@overload +def from_array( + dims: _DimsLike, + data: ArrayLike, + attrs: _AttrsLike = ..., +) -> NamedArray[Any, Any]: ... + + +def from_array( + dims: _DimsLike, + data: duckarray[_ShapeType, _DType] | ArrayLike, + attrs: _AttrsLike = None, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: + """ + Create a Named array from an array-like object. + + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or ArrayLike + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + """ + if isinstance(data, NamedArray): + raise TypeError( + "Array is already a Named array. Use 'data.data' to retrieve the data array" + ) + + # TODO: dask.array.ma.MaskedArray also exists, better way? + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and + # xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + + return NamedArray(dims, data, attrs) + + if isinstance(data, _arrayfunction_or_api): + return NamedArray(dims, data, attrs) + + if isinstance(data, tuple): + return NamedArray(dims, to_0d_object_array(data), attrs) + + # validate whether the data is valid data types. + return NamedArray(dims, np.asarray(data), attrs) + + +class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]): + """ + A wrapper around duck arrays with named dimensions + and attributes which describe a single Array. + Numeric operations on this object implement array broadcasting and + dimension alignment based on dimension names, + rather than axis order. + + + Parameters + ---------- + dims : str or iterable of hashable + Name(s) of the dimension(s). + data : array-like or duck-array + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + + Raises + ------ + ValueError + If the `dims` length does not match the number of data dimensions (ndim). + + + Examples + -------- + >>> data = np.array([1.5, 2, 3], dtype=float) + >>> narr = NamedArray(("x",), data, {"units": "m"}) # TODO: Better name than narr? + """ + + __slots__ = ("_data", "_dims", "_attrs") + + _data: duckarray[Any, _DType_co] + _dims: _Dims + _attrs: dict[Any, Any] | None + + def __init__( + self, + dims: _DimsLike, + data: duckarray[Any, _DType_co], + attrs: _AttrsLike = None, + ): + self._data = data + self._dims = self._parse_dimensions(dims) + self._attrs = dict(attrs) if attrs else None + + def __init_subclass__(cls, **kwargs: Any) -> None: + if NamedArray in cls.__bases__ and (cls._new == NamedArray._new): + # Type hinting does not work for subclasses unless _new is + # overridden with the correct class. + raise TypeError( + "Subclasses of `NamedArray` must override the `_new` method." + ) + super().__init_subclass__(**kwargs) + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[_ShapeType, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> NamedArray[_ShapeType, _DType]: ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> NamedArray[_ShapeType_co, _DType_co]: ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> NamedArray[_ShapeType, _DType] | NamedArray[_ShapeType_co, _DType_co]: + """ + Create a new array with new typing information. + + _new has to be reimplemented each time NamedArray is subclassed, + otherwise type hints will not be correct. The same is likely true + for methods that relied on _new. + + Parameters + ---------- + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + return _new(self, dims, data, attrs) + + def _replace( + self, + dims: _DimsLike | Default = _default, + data: duckarray[_ShapeType_co, _DType_co] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Self: + """ + Create a new array with the same typing information. + + The types for each argument cannot change, + use self._new if that is a risk. + + Parameters + ---------- + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + return cast("Self", self._new(dims, data, attrs)) + + def _copy( + self, + deep: bool = True, + data: duckarray[_ShapeType_co, _DType_co] | None = None, + memo: dict[int, Any] | None = None, + ) -> Self: + if data is None: + ndata = self._data + if deep: + ndata = copy.deepcopy(ndata, memo=memo) + else: + ndata = data + self._check_shape(ndata) + + attrs = ( + copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs) + ) + + return self._replace(data=ndata, attrs=attrs) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + def copy( + self, + deep: bool = True, + data: duckarray[_ShapeType_co, _DType_co] | None = None, + ) -> Self: + """Returns a copy of this object. + + If `deep=True`, the data array is loaded into memory and copied onto + the new object. Dimensions, attributes and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: True + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : NamedArray + New object with dimensions, attributes, and optionally + data copied from original. + + + """ + return self._copy(deep=deep, data=data) + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def size(self) -> _IntOrUnknown: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + def __len__(self) -> _IntOrUnknown: + try: + return self.shape[0] + except Exception as exc: + raise TypeError("len() of unsized object") from exc + + @property + def dtype(self) -> _DType_co: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype + + @property + def shape(self) -> _Shape: + """ + Get the shape of the array. + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def nbytes(self) -> _IntOrUnknown: + """ + Total bytes consumed by the elements of the data array. + + If the underlying data array does not include ``nbytes``, estimates + the bytes consumed based on the ``size`` and ``dtype``. + """ + from xarray.namedarray._array_api import _get_data_namespace + + if hasattr(self._data, "nbytes"): + return self._data.nbytes # type: ignore[no-any-return] + + if hasattr(self.dtype, "itemsize"): + itemsize = self.dtype.itemsize + elif isinstance(self._data, _arrayapi): + xp = _get_data_namespace(self) + + if xp.isdtype(self.dtype, "bool"): + itemsize = 1 + elif xp.isdtype(self.dtype, "integral"): + itemsize = xp.iinfo(self.dtype).bits // 8 + else: + itemsize = xp.finfo(self.dtype).bits // 8 + else: + raise TypeError( + "cannot compute the number of bytes (no array API nor nbytes / itemsize)" + ) + + return self.size * itemsize + + @property + def dims(self) -> _Dims: + """Tuple of dimension names with which this NamedArray is associated.""" + return self._dims + + @dims.setter + def dims(self, value: _DimsLike) -> None: + self._dims = self._parse_dimensions(value) + + def _parse_dimensions(self, dims: _DimsLike) -> _Dims: + dims = (dims,) if isinstance(dims, str) else tuple(dims) + if len(dims) != self.ndim: + raise ValueError( + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" + ) + if len(set(dims)) < len(dims): + repeated_dims = {d for d in dims if dims.count(d) > 1} + warnings.warn( + f"Duplicate dimension names present: dimensions {repeated_dims} appear more than once in dims={dims}. " + "We do not yet support duplicate dimension names, but we do allow initial construction of the object. " + "We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. " + "To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.", + UserWarning, + ) + return dims + + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary of local attributes on this NamedArray.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) if value else None + + def _check_shape(self, new_data: duckarray[Any, _DType_co]) -> None: + if new_data.shape != self.shape: + raise ValueError( + f"replacement data must match the {self.__class__.__name__}'s shape. " + f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}" + ) + + @property + def data(self) -> duckarray[Any, _DType_co]: + """ + The NamedArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + """ + + return self._data + + @data.setter + def data(self, data: duckarray[Any, _DType_co]) -> None: + self._check_shape(data) + self._data = data + + @property + def imag( + self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var] + ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]: + """ + The imaginary part of the array. + + See Also + -------- + numpy.ndarray.imag + """ + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api import imag + + return imag(self) + + return self._new(data=self._data.imag) + + @property + def real( + self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var] + ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]: + """ + The real part of the array. + + See Also + -------- + numpy.ndarray.real + """ + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api import real + + return real(self) + return self._new(data=self._data.real) + + def __dask_tokenize__(self) -> object: + # Use v.data, instead of v._data, in order to cope with the wrappers + # around NetCDF and the like + from dask.base import normalize_token + + return normalize_token((type(self), self._dims, self.data, self._attrs or None)) + + def __dask_graph__(self) -> Graph | None: + if is_duck_dask_array(self._data): + return self._data.__dask_graph__() + else: + # TODO: Should this method just raise instead? + # raise NotImplementedError("Method requires self.data to be a dask array") + return None + + def __dask_keys__(self) -> NestedKeys: + if is_duck_dask_array(self._data): + return self._data.__dask_keys__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_layers__(self) -> Sequence[str]: + if is_duck_dask_array(self._data): + return self._data.__dask_layers__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_optimize__( + self, + ) -> Callable[..., dict[Any, Any]]: + if is_duck_dask_array(self._data): + return self._data.__dask_optimize__ # type: ignore[no-any-return] + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_scheduler__(self) -> SchedulerGetCallable: + if is_duck_dask_array(self._data): + return self._data.__dask_scheduler__ + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postcompute__( + self, + ) -> tuple[PostComputeCallable, tuple[Any, ...]]: + if is_duck_dask_array(self._data): + array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call] + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postpersist__( + self, + ) -> tuple[ + Callable[ + [Graph, PostPersistCallable[Any], Any, Any], + Self, + ], + tuple[Any, ...], + ]: + if is_duck_dask_array(self._data): + a: tuple[PostPersistCallable[Any], tuple[Any, ...]] + a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call] + array_func, array_args = a + + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def _dask_finalize( + self, + results: Graph, + array_func: PostPersistCallable[Any], + *args: Any, + **kwargs: Any, + ) -> Self: + data = array_func(results, *args, **kwargs) + return type(self)(self._dims, data, attrs=self._attrs) + + @overload + def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... + + @overload + def get_axis_num(self, dim: Hashable) -> int: ... + + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + """Return axis number(s) corresponding to dimension(s) in this array. + + Parameters + ---------- + dim : str or iterable of str + Dimension name(s) for which to lookup axes. + + Returns + ------- + int or tuple of int + Axis number or numbers corresponding to the given dimensions. + """ + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) + + def _get_axis_num(self: Any, dim: Hashable) -> int: + _raise_if_any_duplicate_dimensions(self.dims) + try: + return self.dims.index(dim) # type: ignore[no-any-return] + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + + @property + def chunks(self) -> _Chunks | None: + """ + Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if + the underlying data is not a dask array. + + See Also + -------- + NamedArray.chunk + NamedArray.chunksizes + xarray.unify_chunks + """ + data = self._data + if isinstance(data, _chunkedarray): + return data.chunks + else: + return None + + @property + def chunksizes( + self, + ) -> Mapping[_Dim, _Shape]: + """ + Mapping from dimension names to block lengths for this namedArray's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes + instead of a tuple of chunk shapes. + + See Also + -------- + NamedArray.chunk + NamedArray.chunks + xarray.unify_chunks + """ + data = self._data + if isinstance(data, _chunkedarray): + return dict(zip(self.dims, data.chunks)) + else: + return {} + + @property + def sizes(self) -> dict[_Dim, _IntOrUnknown]: + """Ordered mapping from dimension names to lengths.""" + return dict(zip(self.dims, self.shape)) + + def chunk( + self, + chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None, + from_array_kwargs: Any = None, + **chunks_kwargs: Any, + ) -> Self: + """Coerce this array's data into a dask array with the given chunks. + + If this variable is a non-dask array, it will be converted to dask + array. If it's a dask array, it will be rechunked to the given chunk + sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Parameters + ---------- + chunks : int, tuple or dict, optional + Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or + ``{'x': 5, 'y': 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided. + + Returns + ------- + chunked : xarray.Variable + + See Also + -------- + Variable.chunks + Variable.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + + if from_array_kwargs is None: + from_array_kwargs = {} + + if chunks is None: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=FutureWarning, + ) + chunks = {} + + if isinstance(chunks, (float, str, int, tuple, list)): + # TODO we shouldn't assume here that other chunkmanagers can handle these types + # TODO should we call normalize_chunks here? + pass # dask.array.from_array can handle these directly + else: + chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + + if is_dict_like(chunks): + chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()} + + chunkmanager = guess_chunkmanager(chunked_array_type) + + data_old = self._data + if chunkmanager.is_chunked_array(data_old): + data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] + else: + if not isinstance(data_old, ExplicitlyIndexed): + ndata = data_old + else: + # Unambiguously handle array storage backends (like NetCDF4 and h5py) + # that can't handle general array indexing. For example, in netCDF4 you + # can do "outer" indexing along two dimensions independent, which works + # differently from how NumPy handles it. + # da.from_array works by using lazy indexing with a tuple of slices. + # Using OuterIndexer is a pragmatic choice: dask does not yet handle + # different indexing types in an explicit way: + # https://github.com/dask/dask/issues/2883 + ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] + + if is_dict_like(chunks): + chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) # type: ignore[assignment] + + data_chunked = chunkmanager.from_array(ndata, chunks, **from_array_kwargs) # type: ignore[arg-type] + + return self._replace(data=data_chunked) + + def to_numpy(self) -> np.ndarray[Any, Any]: + """Coerces wrapped data to numpy and returns a numpy.ndarray""" + # TODO an entrypoint so array libraries can choose coercion method? + return to_numpy(self._data) + + def as_numpy(self) -> Self: + """Coerces wrapped data into a numpy array, returning a Variable.""" + return self._replace(data=self.to_numpy()) + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> NamedArray[Any, Any]: + """Reduce this array by applying `func` along some dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of reducing an + np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default `func` is + applied over all dimensions. + axis : int or Sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + the reduction is calculated over the flattened array (by calling + `func(x)` without an axis argument). + keepdims : bool, default: False + If True, the dimensions which are reduced are left in the result + as dimensions of size one + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim == ...: + dim = None + if dim is not None and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim' arguments") + + if dim is not None: + axis = self.get_axis_num(dim) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + if axis is not None: + if isinstance(axis, tuple) and len(axis) == 1: + # unpack axis for the benefit of functions + # like np.argmin which can't handle tuple arguments + axis = axis[0] + data = func(self.data, axis=axis, **kwargs) + else: + data = func(self.data, **kwargs) + + if getattr(data, "shape", ()) == self.shape: + dims = self.dims + else: + removed_axes: Iterable[int] + if axis is None: + removed_axes = range(self.ndim) + else: + removed_axes = np.atleast_1d(axis) % self.ndim + if keepdims: + # Insert np.newaxis for removed dims + slices = tuple( + np.newaxis if i in removed_axes else slice(None, None) + for i in range(self.ndim) + ) + if getattr(data, "shape", None) is None: + # Reduce has produced a scalar value, not an array-like + data = np.asanyarray(data)[slices] + else: + data = data[slices] + dims = self.dims + else: + dims = tuple( + adim for n, adim in enumerate(self.dims) if n not in removed_axes + ) + + # Return NamedArray to handle IndexVariable when data is nD + return from_array(dims, data, attrs=self._attrs) + + def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: + """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" + # TODO: we should replace dask's native nonzero + # after https://github.com/dask/dask/issues/1076 is implemented. + # TODO: cast to ndarray and back to T_DuckArray is a workaround + nonzeros = np.nonzero(cast("NDArray[np.integer[Any]]", self.data)) + _attrs = self.attrs + return tuple( + cast("T_NamedArrayInteger", self._new((dim,), nz, _attrs)) + for nz, dim in zip(nonzeros, self.dims) + ) + + def __repr__(self) -> str: + return formatting.array_repr(self) + + def _repr_html_(self) -> str: + return formatting_html.array_repr(self) + + def _as_sparse( + self, + sparse_format: Literal["coo"] | Default = _default, + fill_value: ArrayLike | Default = _default, + ) -> NamedArray[Any, _DType_co]: + """ + Use sparse-array as backend. + """ + import sparse + + from xarray.namedarray._array_api import astype + + # TODO: what to do if dask-backended? + if fill_value is _default: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = dtypes.result_type(self.dtype, fill_value) + + if sparse_format is _default: + sparse_format = "coo" + try: + as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") + except AttributeError as exc: + raise ValueError(f"{sparse_format} is not a valid sparse format") from exc + + data = as_sparse(astype(self, dtype).data, fill_value=fill_value) + return self._new(data=data) + + def _to_dense(self) -> NamedArray[Any, _DType_co]: + """ + Change backend from sparse to np.array. + """ + if isinstance(self._data, _sparsearrayfunction_or_api): + data_dense: np.ndarray[Any, _DType_co] = self._data.todense() + return self._new(data=data_dense) + else: + raise TypeError("self.data is not a sparse array") + + def permute_dims( + self, + *dim: Iterable[_Dim] | ellipsis, + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions. + + Parameters + ---------- + *dim : Hashable, optional + By default, reverse the order of the dimensions. Otherwise, reorder the + dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + NamedArray: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + NamedArray + The returned NamedArray has permuted dimensions and data with the + same attributes as the original. + + + See Also + -------- + numpy.transpose + """ + + from xarray.namedarray._array_api import permute_dims + + if not dim: + dims = self.dims[::-1] + else: + dims = tuple(infix_dims(dim, self.dims, missing_dims)) # type: ignore[arg-type] + + if len(dims) < 2 or dims == self.dims: + # no need to transpose if only one dimension + # or dims are in same order + return self.copy(deep=False) + + axes_result = self.get_axis_num(dims) + axes = (axes_result,) if isinstance(axes_result, int) else axes_result + + return permute_dims(self, axes) + + @property + def T(self) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions.""" + if self.ndim != 2: + raise ValueError( + f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." + ) + + return self.permute_dims() + + def broadcast_to( + self, dim: Mapping[_Dim, int] | None = None, **dim_kwargs: Any + ) -> NamedArray[Any, _DType_co]: + """ + Broadcast the NamedArray to a new shape. New dimensions are not allowed. + + This method allows for the expansion of the array's dimensions to a specified shape. + It handles both positional and keyword arguments for specifying the dimensions to broadcast. + An error is raised if new dimensions are attempted to be added. + + Parameters + ---------- + dim : dict, str, sequence of str, optional + Dimensions to broadcast the array to. If a dict, keys are dimension names and values are the new sizes. + If a string or sequence of strings, existing dimensions are matched with a size of 1. + + **dim_kwargs : Any + Additional dimensions specified as keyword arguments. Each keyword argument specifies the name of an existing dimension and its size. + + Returns + ------- + NamedArray + A new NamedArray with the broadcasted dimensions. + + Examples + -------- + >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) + >>> array = xr.NamedArray(("x", "y"), data) + >>> array.sizes + {'x': 2, 'y': 2} + + >>> broadcasted = array.broadcast_to(x=2, y=2) + >>> broadcasted.sizes + {'x': 2, 'y': 2} + """ + + from xarray.core import duck_array_ops + + combined_dims = either_dict_or_kwargs(dim, dim_kwargs, "broadcast_to") + + # Check that no new dimensions are added + if new_dims := set(combined_dims) - set(self.dims): + raise ValueError( + f"Cannot add new dimensions: {new_dims}. Only existing dimensions are allowed. " + "Use `expand_dims` method to add new dimensions." + ) + + # Create a dictionary of the current dimensions and their sizes + current_shape = self.sizes + + # Update the current shape with the new dimensions, keeping the order of the original dimensions + broadcast_shape = {d: current_shape.get(d, 1) for d in self.dims} + broadcast_shape |= combined_dims + + # Ensure the dimensions are in the correct order + ordered_dims = list(broadcast_shape.keys()) + ordered_shape = tuple(broadcast_shape[d] for d in ordered_dims) + data = duck_array_ops.broadcast_to(self._data, ordered_shape) # type: ignore[no-untyped-call] # TODO: use array-api-compat function + return self._new(data=data, dims=ordered_dims) + + def expand_dims( + self, + dim: _Dim | Default = _default, + ) -> NamedArray[Any, _DType_co]: + """ + Expand the dimensions of the NamedArray. + + This method adds new dimensions to the object. The new dimensions are added at the beginning of the array. + + Parameters + ---------- + dim : Hashable, optional + Dimension name to expand the array to. This dimension will be added at the beginning of the array. + + Returns + ------- + NamedArray + A new NamedArray with expanded dimensions. + + + Examples + -------- + + >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) + >>> array = xr.NamedArray(("x", "y"), data) + + + # expand dimensions by specifying a new dimension name + >>> expanded = array.expand_dims(dim="z") + >>> expanded.dims + ('z', 'x', 'y') + + """ + + from xarray.namedarray._array_api import expand_dims + + return expand_dims(self, dim=dim) + + +_NamedArray = NamedArray[Any, np.dtype[_ScalarType_co]] + + +def _raise_if_any_duplicate_dimensions( + dims: _Dims, err_context: str = "This function" +) -> None: + if len(set(dims)) < len(dims): + repeated_dims = {d for d in dims if dims.count(d) > 1} + raise ValueError( + f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}" + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/daskmanager.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/daskmanager.py new file mode 100644 index 0000000..14744d2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/daskmanager.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable + +import numpy as np +from packaging.version import Version + +from xarray.core.indexing import ImplicitToExplicitIndexingAdapter +from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray +from xarray.namedarray.utils import is_duck_dask_array, module_available + +if TYPE_CHECKING: + from xarray.namedarray._typing import ( + T_Chunks, + _DType_co, + _NormalizedChunks, + duckarray, + ) + + try: + from dask.array import Array as DaskArray + except ImportError: + DaskArray = np.ndarray[Any, Any] # type: ignore[assignment, misc] + + +dask_available = module_available("dask") + + +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] + array_cls: type[DaskArray] + available: bool = dask_available + + def __init__(self) -> None: + # TODO can we replace this with a class attribute instead? + + from dask.array import Array + + self.array_cls = Array + + def is_chunked_array(self, data: duckarray[Any, Any]) -> bool: + return is_duck_dask_array(data) + + def chunks(self, data: Any) -> _NormalizedChunks: + return data.chunks # type: ignore[no-any-return] + + def normalize_chunks( + self, + chunks: T_Chunks | _NormalizedChunks, + shape: tuple[int, ...] | None = None, + limit: int | None = None, + dtype: _DType_co | None = None, + previous_chunks: _NormalizedChunks | None = None, + ) -> Any: + """Called by open_dataset""" + from dask.array.core import normalize_chunks + + return normalize_chunks( + chunks, + shape=shape, + limit=limit, + dtype=dtype, + previous_chunks=previous_chunks, + ) # type: ignore[no-untyped-call] + + def from_array( + self, data: Any, chunks: T_Chunks | _NormalizedChunks, **kwargs: Any + ) -> DaskArray | Any: + import dask.array as da + + if isinstance(data, ImplicitToExplicitIndexingAdapter): + # lazily loaded backend array classes should use NumPy array operations. + kwargs["meta"] = np.ndarray + + return da.from_array( + data, + chunks, + **kwargs, + ) # type: ignore[no-untyped-call] + + def compute( + self, *data: Any, **kwargs: Any + ) -> tuple[np.ndarray[Any, _DType_co], ...]: + from dask.array import compute + + return compute(*data, **kwargs) # type: ignore[no-untyped-call, no-any-return] + + @property + def array_api(self) -> Any: + from dask import array as da + + return da + + def reduction( # type: ignore[override] + self, + arr: T_ChunkedArray, + func: Callable[..., Any], + combine_func: Callable[..., Any] | None = None, + aggregate_func: Callable[..., Any] | None = None, + axis: int | Sequence[int] | None = None, + dtype: _DType_co | None = None, + keepdims: bool = False, + ) -> DaskArray | Any: + from dask.array import reduction + + return reduction( + arr, + chunk=func, + combine=combine_func, + aggregate=aggregate_func, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ) # type: ignore[no-untyped-call] + + def scan( # type: ignore[override] + self, + func: Callable[..., Any], + binop: Callable[..., Any], + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: _DType_co | None = None, + **kwargs: Any, + ) -> DaskArray | Any: + from dask.array.reductions import cumreduction + + return cumreduction( + func, + binop, + ident, + arr, + axis=axis, + dtype=dtype, + **kwargs, + ) # type: ignore[no-untyped-call] + + def apply_gufunc( + self, + func: Callable[..., Any], + signature: str, + *args: Any, + axes: Sequence[tuple[int, ...]] | None = None, + axis: int | None = None, + keepdims: bool = False, + output_dtypes: Sequence[_DType_co] | None = None, + output_sizes: dict[str, int] | None = None, + vectorize: bool | None = None, + allow_rechunk: bool = False, + meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None, + **kwargs: Any, + ) -> Any: + from dask.array.gufunc import apply_gufunc + + return apply_gufunc( + func, + signature, + *args, + axes=axes, + axis=axis, + keepdims=keepdims, + output_dtypes=output_dtypes, + output_sizes=output_sizes, + vectorize=vectorize, + allow_rechunk=allow_rechunk, + meta=meta, + **kwargs, + ) # type: ignore[no-untyped-call] + + def map_blocks( + self, + func: Callable[..., Any], + *args: Any, + dtype: _DType_co | None = None, + chunks: tuple[int, ...] | None = None, + drop_axis: int | Sequence[int] | None = None, + new_axis: int | Sequence[int] | None = None, + **kwargs: Any, + ) -> Any: + import dask + from dask.array import map_blocks + + if drop_axis is None and Version(dask.__version__) < Version("2022.9.1"): + # See https://github.com/pydata/xarray/pull/7019#discussion_r1196729489 + # TODO remove once dask minimum version >= 2022.9.1 + drop_axis = [] + + # pass through name, meta, token as kwargs + return map_blocks( + func, + *args, + dtype=dtype, + chunks=chunks, + drop_axis=drop_axis, + new_axis=new_axis, + **kwargs, + ) # type: ignore[no-untyped-call] + + def blockwise( + self, + func: Callable[..., Any], + out_ind: Iterable[Any], + *args: Any, + # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types + name: str | None = None, + token: Any | None = None, + dtype: _DType_co | None = None, + adjust_chunks: dict[Any, Callable[..., Any]] | None = None, + new_axes: dict[Any, int] | None = None, + align_arrays: bool = True, + concatenate: bool | None = None, + meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None, + **kwargs: Any, + ) -> DaskArray | Any: + from dask.array import blockwise + + return blockwise( + func, + out_ind, + *args, + name=name, + token=token, + dtype=dtype, + adjust_chunks=adjust_chunks, + new_axes=new_axes, + align_arrays=align_arrays, + concatenate=concatenate, + meta=meta, + **kwargs, + ) # type: ignore[no-untyped-call] + + def unify_chunks( + self, + *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types + **kwargs: Any, + ) -> tuple[dict[str, _NormalizedChunks], list[DaskArray]]: + from dask.array.core import unify_chunks + + return unify_chunks(*args, **kwargs) # type: ignore[no-any-return, no-untyped-call] + + def store( + self, + sources: Any | Sequence[Any], + targets: Any, + **kwargs: Any, + ) -> Any: + from dask.array import store + + return store( + sources=sources, + targets=targets, + **kwargs, + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/dtypes.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/dtypes.py new file mode 100644 index 0000000..7a83bd1 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/dtypes.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import functools +import sys +from typing import Any, Literal + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +import numpy as np + +from xarray.namedarray import utils + +# Use as a sentinel value to indicate a dtype appropriate NA value. +NA = utils.ReprObject("") + + +@functools.total_ordering +class AlwaysGreaterThan: + def __gt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan: + def __lt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]: + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any + if np.issubdtype(dtype, np.floating): + dtype_ = dtype + fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64("NaT") + dtype_ = dtype + elif np.issubdtype(dtype, np.integer): + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 + fill_value = np.nan + elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype + fill_value = np.nan + np.nan * 1j + elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype + fill_value = np.datetime64("NaT") + else: + dtype_ = object + fill_value = np.nan + + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value + + +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} + + +def get_fill_value(dtype: np.dtype[np.generic]) -> Any: + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def get_pos_infinity( + dtype: np.dtype[np.generic], max_for_int: bool = False +) -> float | complex | AlwaysGreaterThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).max if max_for_int else np.inf + if issubclass(dtype.type, np.complexfloating): + return np.inf + 1j * np.inf + + return INF + + +def get_neg_infinity( + dtype: np.dtype[np.generic], min_for_int: bool = False +) -> float | complex | AlwaysLessThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return -np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).min if min_for_int else -np.inf + if issubclass(dtype.type, np.complexfloating): + return -np.inf - 1j * np.inf + + return NINF + + +def is_datetime_like( + dtype: np.dtype[np.generic], +) -> TypeGuard[np.datetime64 | np.timedelta64]: + """Check if a dtype is a subclass of the numpy datetime types""" + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype[np.generic]: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/parallelcompat.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/parallelcompat.py new file mode 100644 index 0000000..dd555fe --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/parallelcompat.py @@ -0,0 +1,708 @@ +""" +The code in this module is an experiment in going from N=1 to N=2 parallel computing frameworks in xarray. +It could later be used as the basis for a public interface allowing any N frameworks to interoperate with xarray, +but for now it is just a private experiment. +""" + +from __future__ import annotations + +import functools +import sys +from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence +from importlib.metadata import EntryPoint, entry_points +from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypeVar + +import numpy as np + +from xarray.core.utils import emit_user_level_warning +from xarray.namedarray.pycompat import is_chunked_array + +if TYPE_CHECKING: + from xarray.namedarray._typing import ( + _Chunks, + _DType, + _DType_co, + _NormalizedChunks, + _ShapeType, + duckarray, + ) + + +class ChunkedArrayMixinProtocol(Protocol): + def rechunk(self, chunks: Any, **kwargs: Any) -> Any: ... + + @property + def dtype(self) -> np.dtype[Any]: ... + + @property + def chunks(self) -> _NormalizedChunks: ... + + def compute( + self, *data: Any, **kwargs: Any + ) -> tuple[np.ndarray[Any, _DType_co], ...]: ... + + +T_ChunkedArray = TypeVar("T_ChunkedArray", bound=ChunkedArrayMixinProtocol) + + +@functools.lru_cache(maxsize=1) +def list_chunkmanagers() -> dict[str, ChunkManagerEntrypoint[Any]]: + """ + Return a dictionary of available chunk managers and their ChunkManagerEntrypoint subclass objects. + + Returns + ------- + chunkmanagers : dict + Dictionary whose values are registered ChunkManagerEntrypoint subclass instances, and whose values + are the strings under which they are registered. + + Notes + ----- + # New selection mechanism introduced with Python 3.10. See GH6514. + """ + if sys.version_info >= (3, 10): + entrypoints = entry_points(group="xarray.chunkmanagers") + else: + entrypoints = entry_points().get("xarray.chunkmanagers", ()) + + return load_chunkmanagers(entrypoints) + + +def load_chunkmanagers( + entrypoints: Sequence[EntryPoint], +) -> dict[str, ChunkManagerEntrypoint[Any]]: + """Load entrypoints and instantiate chunkmanagers only once.""" + + loaded_entrypoints = {} + for entrypoint in entrypoints: + try: + loaded_entrypoints[entrypoint.name] = entrypoint.load() + except ModuleNotFoundError as e: + emit_user_level_warning( + f"Failed to load chunk manager entrypoint {entrypoint.name} due to {e}. Skipping.", + ) + pass + + available_chunkmanagers = { + name: chunkmanager() + for name, chunkmanager in loaded_entrypoints.items() + if chunkmanager.available + } + return available_chunkmanagers + + +def guess_chunkmanager( + manager: str | ChunkManagerEntrypoint[Any] | None, +) -> ChunkManagerEntrypoint[Any]: + """ + Get namespace of chunk-handling methods, guessing from what's available. + + If the name of a specific ChunkManager is given (e.g. "dask"), then use that. + Else use whatever is installed, defaulting to dask if there are multiple options. + """ + + chunkmanagers = list_chunkmanagers() + + if manager is None: + if len(chunkmanagers) == 1: + # use the only option available + manager = next(iter(chunkmanagers.keys())) + else: + # default to trying to use dask + manager = "dask" + + if isinstance(manager, str): + if manager not in chunkmanagers: + raise ValueError( + f"unrecognized chunk manager {manager} - must be one of: {list(chunkmanagers)}" + ) + + return chunkmanagers[manager] + elif isinstance(manager, ChunkManagerEntrypoint): + # already a valid ChunkManager so just pass through + return manager + else: + raise TypeError( + f"manager must be a string or instance of ChunkManagerEntrypoint, but received type {type(manager)}" + ) + + +def get_chunked_array_type(*args: Any) -> ChunkManagerEntrypoint[Any]: + """ + Detects which parallel backend should be used for given set of arrays. + + Also checks that all arrays are of same chunking type (i.e. not a mix of cubed and dask). + """ + + # TODO this list is probably redundant with something inside xarray.apply_ufunc + ALLOWED_NON_CHUNKED_TYPES = {int, float, np.ndarray} + + chunked_arrays = [ + a + for a in args + if is_chunked_array(a) and type(a) not in ALLOWED_NON_CHUNKED_TYPES + ] + + # Asserts all arrays are the same type (or numpy etc.) + chunked_array_types = {type(a) for a in chunked_arrays} + if len(chunked_array_types) > 1: + raise TypeError( + f"Mixing chunked array types is not supported, but received multiple types: {chunked_array_types}" + ) + elif len(chunked_array_types) == 0: + raise TypeError("Expected a chunked array but none were found") + + # iterate over defined chunk managers, seeing if each recognises this array type + chunked_arr = chunked_arrays[0] + chunkmanagers = list_chunkmanagers() + selected = [ + chunkmanager + for chunkmanager in chunkmanagers.values() + if chunkmanager.is_chunked_array(chunked_arr) + ] + if not selected: + raise TypeError( + f"Could not find a Chunk Manager which recognises type {type(chunked_arr)}" + ) + elif len(selected) >= 2: + raise TypeError(f"Multiple ChunkManagers recognise type {type(chunked_arr)}") + else: + return selected[0] + + +class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]): + """ + Interface between a particular parallel computing framework and xarray. + + This abstract base class must be subclassed by libraries implementing chunked array types, and + registered via the ``chunkmanagers`` entrypoint. + + Abstract methods on this class must be implemented, whereas non-abstract methods are only required in order to + enable a subset of xarray functionality, and by default will raise a ``NotImplementedError`` if called. + + Attributes + ---------- + array_cls + Type of the array class this parallel computing framework provides. + + Parallel frameworks need to provide an array class that supports the array API standard. + This attribute is used for array instance type checking at runtime. + """ + + array_cls: type[T_ChunkedArray] + available: bool = True + + @abstractmethod + def __init__(self) -> None: + """Used to set the array_cls attribute at import time.""" + raise NotImplementedError() + + def is_chunked_array(self, data: duckarray[Any, Any]) -> bool: + """ + Check if the given object is an instance of this type of chunked array. + + Compares against the type stored in the array_cls attribute by default. + + Parameters + ---------- + data : Any + + Returns + ------- + is_chunked : bool + + See Also + -------- + dask.is_dask_collection + """ + return isinstance(data, self.array_cls) + + @abstractmethod + def chunks(self, data: T_ChunkedArray) -> _NormalizedChunks: + """ + Return the current chunks of the given array. + + Returns chunks explicitly as a tuple of tuple of ints. + + Used internally by xarray objects' .chunks and .chunksizes properties. + + Parameters + ---------- + data : chunked array + + Returns + ------- + chunks : tuple[tuple[int, ...], ...] + + See Also + -------- + dask.array.Array.chunks + cubed.Array.chunks + """ + raise NotImplementedError() + + @abstractmethod + def normalize_chunks( + self, + chunks: _Chunks | _NormalizedChunks, + shape: _ShapeType | None = None, + limit: int | None = None, + dtype: _DType | None = None, + previous_chunks: _NormalizedChunks | None = None, + ) -> _NormalizedChunks: + """ + Normalize given chunking pattern into an explicit tuple of tuples representation. + + Exposed primarily because different chunking backends may want to make different decisions about how to + automatically chunk along dimensions not given explicitly in the input chunks. + + Called internally by xarray.open_dataset. + + Parameters + ---------- + chunks : tuple, int, dict, or string + The chunks to be normalized. + shape : Tuple[int] + The shape of the array + limit : int (optional) + The maximum block size to target in bytes, + if freedom is given to choose + dtype : np.dtype + previous_chunks : Tuple[Tuple[int]], optional + Chunks from a previous array that we should use for inspiration when + rechunking dimensions automatically. + + See Also + -------- + dask.array.core.normalize_chunks + """ + raise NotImplementedError() + + @abstractmethod + def from_array( + self, data: duckarray[Any, Any], chunks: _Chunks, **kwargs: Any + ) -> T_ChunkedArray: + """ + Create a chunked array from a non-chunked numpy-like array. + + Generally input should have a ``.shape``, ``.ndim``, ``.dtype`` and support numpy-style slicing. + + Called when the .chunk method is called on an xarray object that is not already chunked. + Also called within open_dataset (when chunks is not None) to create a chunked array from + an xarray lazily indexed array. + + Parameters + ---------- + data : array_like + chunks : int, tuple + How to chunk the array. + + See Also + -------- + dask.array.from_array + cubed.from_array + """ + raise NotImplementedError() + + def rechunk( + self, + data: T_ChunkedArray, + chunks: _NormalizedChunks | tuple[int, ...] | _Chunks, + **kwargs: Any, + ) -> Any: + """ + Changes the chunking pattern of the given array. + + Called when the .chunk method is called on an xarray object that is already chunked. + + Parameters + ---------- + data : dask array + Array to be rechunked. + chunks : int, tuple, dict or str, optional + The new block dimensions to create. -1 indicates the full size of the + corresponding dimension. Default is "auto" which automatically + determines chunk sizes. + + Returns + ------- + chunked array + + See Also + -------- + dask.array.Array.rechunk + cubed.Array.rechunk + """ + return data.rechunk(chunks, **kwargs) + + @abstractmethod + def compute( + self, *data: T_ChunkedArray | Any, **kwargs: Any + ) -> tuple[np.ndarray[Any, _DType_co], ...]: + """ + Computes one or more chunked arrays, returning them as eager numpy arrays. + + Called anytime something needs to computed, including multiple arrays at once. + Used by `.compute`, `.persist`, `.values`. + + Parameters + ---------- + *data : object + Any number of objects. If an object is an instance of the chunked array type, it is computed + and the in-memory result returned as a numpy array. All other types should be passed through unchanged. + + Returns + ------- + objs + The input, but with all chunked arrays now computed. + + See Also + -------- + dask.compute + cubed.compute + """ + raise NotImplementedError() + + @property + def array_api(self) -> Any: + """ + Return the array_api namespace following the python array API standard. + + See https://data-apis.org/array-api/latest/ . Currently used to access the array API function + ``full_like``, which is called within the xarray constructors ``xarray.full_like``, ``xarray.ones_like``, + ``xarray.zeros_like``, etc. + + See Also + -------- + dask.array + cubed.array_api + """ + raise NotImplementedError() + + def reduction( + self, + arr: T_ChunkedArray, + func: Callable[..., Any], + combine_func: Callable[..., Any] | None = None, + aggregate_func: Callable[..., Any] | None = None, + axis: int | Sequence[int] | None = None, + dtype: _DType_co | None = None, + keepdims: bool = False, + ) -> T_ChunkedArray: + """ + A general version of array reductions along one or more axes. + + Used inside some reductions like nanfirst, which is used by ``groupby.first``. + + Parameters + ---------- + arr : chunked array + Data to be reduced along one or more axes. + func : Callable(x_chunk, axis, keepdims) + First function to be executed when resolving the dask graph. + This function is applied in parallel to all original chunks of x. + See below for function parameters. + combine_func : Callable(x_chunk, axis, keepdims), optional + Function used for intermediate recursive aggregation (see + split_every below). If omitted, it defaults to aggregate_func. + aggregate_func : Callable(x_chunk, axis, keepdims) + Last function to be executed, producing the final output. It is always invoked, even when the reduced + Array counts a single chunk along the reduced axes. + axis : int or sequence of ints, optional + Axis or axes to aggregate upon. If omitted, aggregate along all axes. + dtype : np.dtype + data type of output. This argument was previously optional, but + leaving as ``None`` will now raise an exception. + keepdims : boolean, optional + Whether the reduction function should preserve the reduced axes, + leaving them at size ``output_size``, or remove them. + + Returns + ------- + chunked array + + See Also + -------- + dask.array.reduction + cubed.core.reduction + """ + raise NotImplementedError() + + def scan( + self, + func: Callable[..., Any], + binop: Callable[..., Any], + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: _DType_co | None = None, + **kwargs: Any, + ) -> T_ChunkedArray: + """ + General version of a 1D scan, also known as a cumulative array reduction. + + Used in ``ffill`` and ``bfill`` in xarray. + + Parameters + ---------- + func: callable + Cumulative function like np.cumsum or np.cumprod + binop: callable + Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` + ident: Number + Associated identity like ``np.cumsum->0`` or ``np.cumprod->1`` + arr: dask Array + axis: int, optional + dtype: dtype + + Returns + ------- + Chunked array + + See also + -------- + dask.array.cumreduction + """ + raise NotImplementedError() + + @abstractmethod + def apply_gufunc( + self, + func: Callable[..., Any], + signature: str, + *args: Any, + axes: Sequence[tuple[int, ...]] | None = None, + keepdims: bool = False, + output_dtypes: Sequence[_DType_co] | None = None, + vectorize: bool | None = None, + **kwargs: Any, + ) -> Any: + """ + Apply a generalized ufunc or similar python function to arrays. + + ``signature`` determines if the function consumes or produces core + dimensions. The remaining dimensions in given input arrays (``*args``) + are considered loop dimensions and are required to broadcast + naturally against each other. + + In other terms, this function is like ``np.vectorize``, but for + the blocks of chunked arrays. If the function itself shall also + be vectorized use ``vectorize=True`` for convenience. + + Called inside ``xarray.apply_ufunc``, which is called internally for most xarray operations. + Therefore this method must be implemented for the vast majority of xarray computations to be supported. + + Parameters + ---------- + func : callable + Function to call like ``func(*args, **kwargs)`` on input arrays + (``*args``) that returns an array or tuple of arrays. If multiple + arguments with non-matching dimensions are supplied, this function is + expected to vectorize (broadcast) over axes of positional arguments in + the style of NumPy universal functions [1]_ (if this is not the case, + set ``vectorize=True``). If this function returns multiple outputs, + ``output_core_dims`` has to be set as well. + signature: string + Specifies what core dimensions are consumed and produced by ``func``. + According to the specification of numpy.gufunc signature [2]_ + *args : numeric + Input arrays or scalars to the callable function. + axes: List of tuples, optional, keyword only + A list of tuples with indices of axes a generalized ufunc should operate on. + For instance, for a signature of ``"(i,j),(j,k)->(i,k)"`` appropriate for + matrix multiplication, the base elements are two-dimensional matrices + and these are taken to be stored in the two last axes of each argument. The + corresponding axes keyword would be ``[(-2, -1), (-2, -1), (-2, -1)]``. + For simplicity, for generalized ufuncs that operate on 1-dimensional arrays + (vectors), a single integer is accepted instead of a single-element tuple, + and for generalized ufuncs for which all outputs are scalars, the output + tuples can be omitted. + keepdims: bool, optional, keyword only + If this is set to True, axes which are reduced over will be left in the result as + a dimension with size one, so that the result will broadcast correctly against the + inputs. This option can only be used for generalized ufuncs that operate on inputs + that all have the same number of core dimensions and with outputs that have no core + dimensions , i.e., with signatures like ``"(i),(i)->()"`` or ``"(m,m)->()"``. + If used, the location of the dimensions in the output can be controlled with axes + and axis. + output_dtypes : Optional, dtype or list of dtypes, keyword only + Valid numpy dtype specification or list thereof. + If not given, a call of ``func`` with a small set of data + is performed in order to try to automatically determine the + output dtypes. + vectorize: bool, keyword only + If set to ``True``, ``np.vectorize`` is applied to ``func`` for + convenience. Defaults to ``False``. + **kwargs : dict + Extra keyword arguments to pass to `func` + + Returns + ------- + Single chunked array or tuple of chunked arrays + + See Also + -------- + dask.array.gufunc.apply_gufunc + cubed.apply_gufunc + + References + ---------- + .. [1] https://docs.scipy.org/doc/numpy/reference/ufuncs.html + .. [2] https://docs.scipy.org/doc/numpy/reference/c-api/generalized-ufuncs.html + """ + raise NotImplementedError() + + def map_blocks( + self, + func: Callable[..., Any], + *args: Any, + dtype: _DType_co | None = None, + chunks: tuple[int, ...] | None = None, + drop_axis: int | Sequence[int] | None = None, + new_axis: int | Sequence[int] | None = None, + **kwargs: Any, + ) -> Any: + """ + Map a function across all blocks of a chunked array. + + Called in elementwise operations, but notably not (currently) called within xarray.map_blocks. + + Parameters + ---------- + func : callable + Function to apply to every block in the array. + If ``func`` accepts ``block_info=`` or ``block_id=`` + as keyword arguments, these will be passed dictionaries + containing information about input and output chunks/arrays + during computation. See examples for details. + args : dask arrays or other objects + dtype : np.dtype, optional + The ``dtype`` of the output array. It is recommended to provide this. + If not provided, will be inferred by applying the function to a small + set of fake data. + chunks : tuple, optional + Chunk shape of resulting blocks if the function does not preserve + shape. If not provided, the resulting array is assumed to have the same + block structure as the first input array. + drop_axis : number or iterable, optional + Dimensions lost by the function. + new_axis : number or iterable, optional + New dimensions created by the function. Note that these are applied + after ``drop_axis`` (if present). + **kwargs : + Other keyword arguments to pass to function. Values must be constants + (not dask.arrays) + + See Also + -------- + dask.array.map_blocks + cubed.map_blocks + """ + raise NotImplementedError() + + def blockwise( + self, + func: Callable[..., Any], + out_ind: Iterable[Any], + *args: Any, # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types + adjust_chunks: dict[Any, Callable[..., Any]] | None = None, + new_axes: dict[Any, int] | None = None, + align_arrays: bool = True, + **kwargs: Any, + ) -> Any: + """ + Tensor operation: Generalized inner and outer products. + + A broad class of blocked algorithms and patterns can be specified with a + concise multi-index notation. The ``blockwise`` function applies an in-memory + function across multiple blocks of multiple inputs in a variety of ways. + Many chunked array operations are special cases of blockwise including + elementwise, broadcasting, reductions, tensordot, and transpose. + + Currently only called explicitly in xarray when performing multidimensional interpolation. + + Parameters + ---------- + func : callable + Function to apply to individual tuples of blocks + out_ind : iterable + Block pattern of the output, something like 'ijk' or (1, 2, 3) + *args : sequence of Array, index pairs + You may also pass literal arguments, accompanied by None index + e.g. (x, 'ij', y, 'jk', z, 'i', some_literal, None) + **kwargs : dict + Extra keyword arguments to pass to function + adjust_chunks : dict + Dictionary mapping index to function to be applied to chunk sizes + new_axes : dict, keyword only + New indexes and their dimension lengths + align_arrays: bool + Whether or not to align chunks along equally sized dimensions when + multiple arrays are provided. This allows for larger chunks in some + arrays to be broken into smaller ones that match chunk sizes in other + arrays such that they are compatible for block function mapping. If + this is false, then an error will be thrown if arrays do not already + have the same number of blocks in each dimension. + + See Also + -------- + dask.array.blockwise + cubed.core.blockwise + """ + raise NotImplementedError() + + def unify_chunks( + self, + *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types + **kwargs: Any, + ) -> tuple[dict[str, _NormalizedChunks], list[T_ChunkedArray]]: + """ + Unify chunks across a sequence of arrays. + + Called by xarray.unify_chunks. + + Parameters + ---------- + *args: sequence of Array, index pairs + Sequence like (x, 'ij', y, 'jk', z, 'i') + + See Also + -------- + dask.array.core.unify_chunks + cubed.core.unify_chunks + """ + raise NotImplementedError() + + def store( + self, + sources: T_ChunkedArray | Sequence[T_ChunkedArray], + targets: Any, + **kwargs: dict[str, Any], + ) -> Any: + """ + Store chunked arrays in array-like objects, overwriting data in target. + + This stores chunked arrays into object that supports numpy-style setitem + indexing (e.g. a Zarr Store). Allows storing values chunk by chunk so that it does not have to + fill up memory. For best performance you likely want to align the block size of + the storage target with the block size of your array. + + Used when writing to any registered xarray I/O backend. + + Parameters + ---------- + sources: Array or collection of Arrays + targets: array-like or collection of array-likes + These should support setitem syntax ``target[10:20] = ...``. + If sources is a single item, targets must be a single item; if sources is a + collection of arrays, targets must be a matching collection. + kwargs: + Parameters passed to compute/persist (only used if compute=True) + + See Also + -------- + dask.array.store + cubed.store + """ + raise NotImplementedError() diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/pycompat.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/pycompat.py new file mode 100644 index 0000000..3ce33d4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/pycompat.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from importlib import import_module +from types import ModuleType +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +from packaging.version import Version + +from xarray.core.utils import is_scalar +from xarray.namedarray.utils import is_duck_array, is_duck_dask_array + +integer_types = (int, np.integer) + +if TYPE_CHECKING: + ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"] + DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic + from xarray.namedarray._typing import _DType, _ShapeType, duckarray + + +class DuckArrayModule: + """ + Solely for internal isinstance and version checks. + + Motivated by having to only import pint when required (as pint currently imports xarray) + https://github.com/pydata/xarray/pull/5561#discussion_r664815718 + """ + + module: ModuleType | None + version: Version + type: DuckArrayTypes + available: bool + + def __init__(self, mod: ModType) -> None: + duck_array_module: ModuleType | None + duck_array_version: Version + duck_array_type: DuckArrayTypes + try: + duck_array_module = import_module(mod) + duck_array_version = Version(duck_array_module.__version__) + + if mod == "dask": + duck_array_type = (import_module("dask.array").Array,) + elif mod == "pint": + duck_array_type = (duck_array_module.Quantity,) + elif mod == "cupy": + duck_array_type = (duck_array_module.ndarray,) + elif mod == "sparse": + duck_array_type = (duck_array_module.SparseArray,) + elif mod == "cubed": + duck_array_type = (duck_array_module.Array,) + # Not a duck array module, but using this system regardless, to get lazy imports + elif mod == "numbagg": + duck_array_type = () + else: + raise NotImplementedError + + except (ImportError, AttributeError): # pragma: no cover + duck_array_module = None + duck_array_version = Version("0.0.0") + duck_array_type = () + + self.module = duck_array_module + self.version = duck_array_version + self.type = duck_array_type + self.available = duck_array_module is not None + + +_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {} + + +def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule: + if mod not in _cached_duck_array_modules: + duckmod = DuckArrayModule(mod) + _cached_duck_array_modules[mod] = duckmod + return duckmod + else: + return _cached_duck_array_modules[mod] + + +def array_type(mod: ModType) -> DuckArrayTypes: + """Quick wrapper to get the array class of the module.""" + return _get_cached_duck_array_module(mod).type + + +def mod_version(mod: ModType) -> Version: + """Quick wrapper to get the version of the module.""" + return _get_cached_duck_array_module(mod).version + + +def is_chunked_array(x: duckarray[Any, Any]) -> bool: + return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) + + +def is_0d_dask_array(x: duckarray[Any, Any]) -> bool: + return is_duck_dask_array(x) and is_scalar(x) + + +def to_numpy( + data: duckarray[Any, Any], **kwargs: dict[str, Any] +) -> np.ndarray[Any, np.dtype[Any]]: + from xarray.core.indexing import ExplicitlyIndexed + from xarray.namedarray.parallelcompat import get_chunked_array_type + + if isinstance(data, ExplicitlyIndexed): + data = data.get_duck_array() # type: ignore[no-untyped-call] + + # TODO first attempt to call .to_numpy() once some libraries implement it + if is_chunked_array(data): + chunkmanager = get_chunked_array_type(data) + data, *_ = chunkmanager.compute(data, **kwargs) + if isinstance(data, array_type("cupy")): + data = data.get() + # pint has to be imported dynamically as pint imports xarray + if isinstance(data, array_type("pint")): + data = data.magnitude + if isinstance(data, array_type("sparse")): + data = data.todense() + data = np.asarray(data) + + return data + + +def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, _DType]: + from xarray.core.indexing import ExplicitlyIndexed + from xarray.namedarray.parallelcompat import get_chunked_array_type + + if is_chunked_array(data): + chunkmanager = get_chunked_array_type(data) + loaded_data, *_ = chunkmanager.compute(data, **kwargs) # type: ignore[var-annotated] + return loaded_data + + if isinstance(data, ExplicitlyIndexed): + return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return] + elif is_duck_array(data): + return data + else: + return np.asarray(data) # type: ignore[return-value] diff --git a/test/fixtures/whole_applications/xarray/xarray/namedarray/utils.py b/test/fixtures/whole_applications/xarray/xarray/namedarray/utils.py new file mode 100644 index 0000000..b82a80b --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/namedarray/utils.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import importlib +import sys +import warnings +from collections.abc import Hashable, Iterable, Iterator, Mapping +from functools import lru_cache +from typing import TYPE_CHECKING, Any, TypeVar, cast + +import numpy as np +from packaging.version import Version + +from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard + + from numpy.typing import NDArray + + try: + from dask.array.core import Array as DaskArray + from dask.typing import DaskCollection + except ImportError: + DaskArray = NDArray # type: ignore + DaskCollection: Any = NDArray # type: ignore + + from xarray.namedarray._typing import _Dim, duckarray + + +K = TypeVar("K") +V = TypeVar("V") +T = TypeVar("T") + + +@lru_cache +def module_available(module: str, minversion: str | None = None) -> bool: + """Checks whether a module is installed without importing it. + + Use this for a lightweight check and lazy imports. + + Parameters + ---------- + module : str + Name of the module. + minversion : str, optional + Minimum version of the module + + Returns + ------- + available : bool + Whether the module is installed. + """ + if importlib.util.find_spec(module) is None: + return False + + if minversion is not None: + version = importlib.metadata.version(module) + + return Version(version) >= Version(minversion) + + return True + + +def is_dask_collection(x: object) -> TypeGuard[DaskCollection]: + if module_available("dask"): + from dask.base import is_dask_collection + + # use is_dask_collection function instead of dask.typing.DaskCollection + # see https://github.com/pydata/xarray/pull/8241#discussion_r1476276023 + return is_dask_collection(x) + return False + + +def is_duck_array(value: Any) -> TypeGuard[duckarray[Any, Any]]: + # TODO: replace is_duck_array with runtime checks via _arrayfunction_or_api protocol on + # python 3.12 and higher (see https://github.com/pydata/xarray/issues/8696#issuecomment-1924588981) + if isinstance(value, np.ndarray): + return True + return ( + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) + ) + + +def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]: + return is_duck_array(x) and is_dask_collection(x) + + +def to_0d_object_array( + value: object, +) -> NDArray[np.object_]: + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" + result = np.empty((), dtype=object) + result[()] = value + return result + + +def is_dict_like(value: Any) -> TypeGuard[Mapping[Any, Any]]: + return hasattr(value, "keys") and hasattr(value, "__getitem__") + + +def drop_missing_dims( + supplied_dims: Iterable[_Dim], + dims: Iterable[_Dim], + missing_dims: ErrorOptionsWithWarn, +) -> _DimsLike: + """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that + are not present in dims. + + Parameters + ---------- + supplied_dims : Iterable of Hashable + dims : Iterable of Hashable + missing_dims : {"raise", "warn", "ignore"} + """ + + if missing_dims == "raise": + supplied_dims_set = {val for val in supplied_dims if val is not ...} + if invalid := supplied_dims_set - set(dims): + raise ValueError( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return supplied_dims + + elif missing_dims == "warn": + if invalid := set(supplied_dims) - set(dims): + warnings.warn( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return [val for val in supplied_dims if val in dims or val is ...] + + elif missing_dims == "ignore": + return [val for val in supplied_dims if val in dims or val is ...] + + else: + raise ValueError( + f"Unrecognised option {missing_dims} for missing_dims argument" + ) + + +def infix_dims( + dims_supplied: Iterable[_Dim], + dims_all: Iterable[_Dim], + missing_dims: ErrorOptionsWithWarn = "raise", +) -> Iterator[_Dim]: + """ + Resolves a supplied list containing an ellipsis representing other items, to + a generator with the 'realized' list of all items + """ + if ... in dims_supplied: + dims_all_list = list(dims_all) + if len(set(dims_all)) != len(dims_all_list): + raise ValueError("Cannot use ellipsis with repeated dims") + if list(dims_supplied).count(...) > 1: + raise ValueError("More than one ellipsis supplied") + other_dims = [d for d in dims_all if d not in dims_supplied] + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + for d in existing_dims: + if d is ...: + yield from other_dims + else: + yield d + else: + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + if set(existing_dims) ^ set(dims_all): + raise ValueError( + f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included" + ) + yield from existing_dims + + +def either_dict_or_kwargs( + pos_kwargs: Mapping[Any, T] | None, + kw_kwargs: Mapping[str, T], + func_name: str, +) -> Mapping[Hashable, T]: + if pos_kwargs is None or pos_kwargs == {}: + # Need an explicit cast to appease mypy due to invariance; see + # https://github.com/python/mypy/issues/6228 + return cast(Mapping[Hashable, T], kw_kwargs) + + if not is_dict_like(pos_kwargs): + raise ValueError(f"the first argument to .{func_name} must be a dictionary") + if kw_kwargs: + raise ValueError( + f"cannot specify both keyword and positional arguments to .{func_name}" + ) + return pos_kwargs + + +class ReprObject: + """Object that prints as the given value, for use with sentinel values.""" + + __slots__ = ("_value",) + + _value: str + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + return self._value + + def __eq__(self, other: ReprObject | Any) -> bool: + # TODO: What type can other be? ArrayLike? + return self._value == other._value if isinstance(other, ReprObject) else False + + def __hash__(self) -> int: + return hash((type(self), self._value)) + + def __dask_tokenize__(self) -> object: + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) diff --git a/test/fixtures/whole_applications/xarray/xarray/plot/__init__.py b/test/fixtures/whole_applications/xarray/xarray/plot/__init__.py new file mode 100644 index 0000000..ae7a001 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/plot/__init__.py @@ -0,0 +1,36 @@ +""" +Use this module directly: + import xarray.plot as xplt + +Or use the methods on a DataArray or Dataset: + DataArray.plot._____ + Dataset.plot._____ +""" + +from xarray.plot.dataarray_plot import ( + contour, + contourf, + hist, + imshow, + line, + pcolormesh, + plot, + step, + surface, +) +from xarray.plot.dataset_plot import scatter +from xarray.plot.facetgrid import FacetGrid + +__all__ = [ + "plot", + "line", + "step", + "contour", + "contourf", + "hist", + "imshow", + "pcolormesh", + "FacetGrid", + "scatter", + "surface", +] diff --git a/test/fixtures/whole_applications/xarray/xarray/plot/accessor.py b/test/fixtures/whole_applications/xarray/xarray/plot/accessor.py new file mode 100644 index 0000000..9db4ae4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/plot/accessor.py @@ -0,0 +1,1272 @@ +from __future__ import annotations + +import functools +from collections.abc import Hashable, Iterable +from typing import TYPE_CHECKING, Any, Literal, NoReturn, overload + +import numpy as np + +# Accessor methods have the same name as plotting methods, so we need a different namespace +from xarray.plot import dataarray_plot, dataset_plot + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import LineCollection, PathCollection, QuadMesh + from matplotlib.colors import Normalize + from matplotlib.container import BarContainer + from matplotlib.contour import QuadContourSet + from matplotlib.image import AxesImage + from matplotlib.patches import Polygon + from matplotlib.quiver import Quiver + from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection + from numpy.typing import ArrayLike + + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import AspectOptions, HueStyleOptions, ScaleOptions + from xarray.plot.facetgrid import FacetGrid + + +class DataArrayPlotAccessor: + """ + Enables use of xarray.plot functions as attributes on a DataArray. + For example, DataArray.plot.imshow + """ + + _da: DataArray + + __slots__ = ("_da",) + __doc__ = dataarray_plot.plot.__doc__ + + def __init__(self, darray: DataArray) -> None: + self._da = darray + + # Should return Any such that the user does not run into problems + # with the many possible return values + @functools.wraps(dataarray_plot.plot, assigned=("__doc__", "__annotations__")) + def __call__(self, **kwargs) -> Any: + return dataarray_plot.plot(self._da, **kwargs) + + @functools.wraps(dataarray_plot.hist) + def hist( + self, *args, **kwargs + ) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: + return dataarray_plot.hist(self._da, *args, **kwargs) + + @overload + def line( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, + ) -> list[Line3D]: ... + + @overload + def line( + self, + *args: Any, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def line( + self, + *args: Any, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.line, assigned=("__doc__",)) + def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: + return dataarray_plot.line(self._da, *args, **kwargs) + + @overload + def step( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + **kwargs: Any, + ) -> list[Line3D]: ... + + @overload + def step( + self, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def step( + self, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.step, assigned=("__doc__",)) + def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: + return dataarray_plot.step(self._da, *args, **kwargs) + + @overload + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs, + ) -> PathCollection: ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs, + ) -> FacetGrid[DataArray]: ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: + return dataarray_plot.scatter(self._da, *args, **kwargs) + + @overload + def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> AxesImage: ... + + @overload + def imshow( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def imshow( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.imshow, assigned=("__doc__",)) + def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: + return dataarray_plot.imshow(self._da, *args, **kwargs) + + @overload + def contour( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> QuadContourSet: ... + + @overload + def contour( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def contour( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.contour, assigned=("__doc__",)) + def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: + return dataarray_plot.contour(self._da, *args, **kwargs) + + @overload + def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> QuadContourSet: ... + + @overload + def contourf( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def contourf( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: ... + + @functools.wraps(dataarray_plot.contourf, assigned=("__doc__",)) + def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: + return dataarray_plot.contourf(self._da, *args, **kwargs) + + @overload + def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> QuadMesh: ... + + @overload + def pcolormesh( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def pcolormesh( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__",)) + def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]: + return dataarray_plot.pcolormesh(self._da, *args, **kwargs) + + @overload + def surface( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> Poly3DCollection: ... + + @overload + def surface( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: ... + + @overload + def surface( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: ... + + @functools.wraps(dataarray_plot.surface, assigned=("__doc__",)) + def surface(self, *args, **kwargs) -> Poly3DCollection: + return dataarray_plot.surface(self._da, *args, **kwargs) + + +class DatasetPlotAccessor: + """ + Enables use of xarray.plot functions as attributes on a Dataset. + For example, Dataset.plot.scatter + """ + + _ds: Dataset + __slots__ = ("_ds",) + + def __init__(self, dataset: Dataset) -> None: + self._ds = dataset + + def __call__(self, *args, **kwargs) -> NoReturn: + raise ValueError( + "Dataset.plot cannot be called directly. Use " + "an explicit plot method, e.g. ds.plot.scatter(...)" + ) + + @overload + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> PathCollection: ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[Dataset]: ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[Dataset]: ... + + @functools.wraps(dataset_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: + return dataset_plot.scatter(self._ds, *args, **kwargs) + + @overload + def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> Quiver: ... + + @overload + def quiver( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid[Dataset]: ... + + @overload + def quiver( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid[Dataset]: ... + + @functools.wraps(dataset_plot.quiver, assigned=("__doc__",)) + def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: + return dataset_plot.quiver(self._ds, *args, **kwargs) + + @overload + def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> LineCollection: ... + + @overload + def streamplot( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid[Dataset]: ... + + @overload + def streamplot( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid[Dataset]: ... + + @functools.wraps(dataset_plot.streamplot, assigned=("__doc__",)) + def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]: + return dataset_plot.streamplot(self._ds, *args, **kwargs) diff --git a/test/fixtures/whole_applications/xarray/xarray/plot/dataarray_plot.py b/test/fixtures/whole_applications/xarray/xarray/plot/dataarray_plot.py new file mode 100644 index 0000000..ed752d3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/plot/dataarray_plot.py @@ -0,0 +1,2455 @@ +from __future__ import annotations + +import functools +import warnings +from collections.abc import Hashable, Iterable, MutableMapping +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload + +import numpy as np +import pandas as pd + +from xarray.core.alignment import broadcast +from xarray.core.concat import concat +from xarray.plot.facetgrid import _easy_facetgrid +from xarray.plot.utils import ( + _LINEWIDTH_RANGE, + _MARKERSIZE_RANGE, + _add_colorbar, + _add_legend, + _assert_valid_xy, + _determine_guide, + _ensure_plottable, + _guess_coords_to_plot, + _infer_interval_breaks, + _infer_xy_labels, + _Normalize, + _process_cmap_cbar_kwargs, + _rescale_imshow_rgb, + _resolve_intervals_1dplot, + _resolve_intervals_2dplot, + _set_concise_date, + _update_axes, + get_axis, + label_from_attrs, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import PathCollection, QuadMesh + from matplotlib.colors import Colormap, Normalize + from matplotlib.container import BarContainer + from matplotlib.contour import QuadContourSet + from matplotlib.image import AxesImage + from matplotlib.patches import Polygon + from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection + from numpy.typing import ArrayLike + + from xarray.core.dataarray import DataArray + from xarray.core.types import ( + AspectOptions, + ExtendOptions, + HueStyleOptions, + ScaleOptions, + T_DataArray, + ) + from xarray.plot.facetgrid import FacetGrid + +_styles: dict[str, Any] = { + # Add a white border to make it easier seeing overlapping markers: + "scatter.edgecolors": "w", +} + + +def _infer_line_data( + darray: DataArray, x: Hashable | None, y: Hashable | None, hue: Hashable | None +) -> tuple[DataArray, DataArray, DataArray | None, str]: + ndims = len(darray.dims) + + if x is not None and y is not None: + raise ValueError("Cannot specify both x and y kwargs for line plots.") + + if x is not None: + _assert_valid_xy(darray, x, "x") + + if y is not None: + _assert_valid_xy(darray, y, "y") + + if ndims == 1: + huename = None + hueplt = None + huelabel = "" + + if x is not None: + xplt = darray[x] + yplt = darray + + elif y is not None: + xplt = darray + yplt = darray[y] + + else: # Both x & y are None + dim = darray.dims[0] + xplt = darray[dim] + yplt = darray + + else: + if x is None and y is None and hue is None: + raise ValueError("For 2D inputs, please specify either hue, x or y.") + + if y is None: + if hue is not None: + _assert_valid_xy(darray, hue, "hue") + xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) + xplt = darray[xname] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(otherdim, huename, transpose_coords=False) + xplt = xplt.transpose(otherdim, huename, transpose_coords=False) + else: + raise ValueError( + "For 2D inputs, hue must be a dimension" + " i.e. one of " + repr(darray.dims) + ) + + else: + (xdim,) = darray[xname].dims + (huedim,) = darray[huename].dims + yplt = darray.transpose(xdim, huedim) + + else: + yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) + yplt = darray[yname] + if yplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + xplt = darray.transpose(otherdim, huename, transpose_coords=False) + yplt = yplt.transpose(otherdim, huename, transpose_coords=False) + else: + raise ValueError( + "For 2D inputs, hue must be a dimension" + " i.e. one of " + repr(darray.dims) + ) + + else: + (ydim,) = darray[yname].dims + (huedim,) = darray[huename].dims + xplt = darray.transpose(ydim, huedim) + + huelabel = label_from_attrs(darray[huename]) + hueplt = darray[huename] + + return xplt, yplt, hueplt, huelabel + + +def _prepare_plot1d_data( + darray: T_DataArray, + coords_to_plot: MutableMapping[str, Hashable], + plotfunc_name: str | None = None, + _is_facetgrid: bool = False, +) -> dict[str, T_DataArray]: + """ + Prepare data for usage with plt.scatter. + + Parameters + ---------- + darray : T_DataArray + Base DataArray. + coords_to_plot : MutableMapping[str, Hashable] + Coords that will be plotted. + plotfunc_name : str | None + Name of the plotting function that will be used. + + Returns + ------- + plts : dict[str, T_DataArray] + Dict of DataArrays that will be sent to matplotlib. + + Examples + -------- + >>> # Make sure int coords are plotted: + >>> a = xr.DataArray( + ... data=[1, 2], + ... coords={1: ("x", [0, 1], {"units": "s"})}, + ... dims=("x",), + ... name="a", + ... ) + >>> plts = xr.plot.dataarray_plot._prepare_plot1d_data( + ... a, coords_to_plot={"x": 1, "z": None, "hue": None, "size": None} + ... ) + >>> # Check which coords to plot: + >>> print({k: v.name for k, v in plts.items()}) + {'y': 'a', 'x': 1} + """ + # If there are more than 1 dimension in the array than stack all the + # dimensions so the plotter can plot anything: + if darray.ndim > 1: + # When stacking dims the lines will continue connecting. For floats + # this can be solved by adding a nan element in between the flattening + # points: + dims_T = [] + if np.issubdtype(darray.dtype, np.floating): + for v in ["z", "x"]: + dim = coords_to_plot.get(v, None) + if (dim is not None) and (dim in darray.dims): + darray_nan = np.nan * darray.isel({dim: -1}) + darray = concat([darray, darray_nan], dim=dim) + dims_T.append(coords_to_plot[v]) + + # Lines should never connect to the same coordinate when stacked, + # transpose to avoid this as much as possible: + darray = darray.transpose(..., *dims_T) + + # Array is now ready to be stacked: + darray = darray.stack(_stacked_dim=darray.dims) + + # Broadcast together all the chosen variables: + plts = dict(y=darray) + plts.update( + {k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None} + ) + plts = dict(zip(plts.keys(), broadcast(*(plts.values())))) + + return plts + + +# return type is Any due to the many different possibilities +def plot( + darray: DataArray, + *, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + subplot_kws: dict[str, Any] | None = None, + **kwargs: Any, +) -> Any: + """ + Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`. + + Calls xarray plotting function based on the dimensions of + the squeezed DataArray. + + =============== =========================== + Dimensions Plotting function + =============== =========================== + 1 :py:func:`xarray.plot.line` + 2 :py:func:`xarray.plot.pcolormesh` + Anything else :py:func:`xarray.plot.hist` + =============== =========================== + + Parameters + ---------- + darray : DataArray + row : Hashable or None, optional + If passed, make row faceted plots on this dimension name. + col : Hashable or None, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int or None, optional + Use together with ``col`` to wrap faceted plots. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size``, ``figsize`` and facets. + hue : Hashable or None, optional + If passed, make faceted line plots with hue on this dimension name. + subplot_kws : dict, optional + Dictionary of keyword arguments for Matplotlib subplots + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + **kwargs : optional + Additional keyword arguments for Matplotlib. + + See Also + -------- + xarray.DataArray.squeeze + """ + darray = darray.squeeze( + d for d, s in darray.sizes.items() if s == 1 and d not in (row, col, hue) + ).compute() + + plot_dims = set(darray.dims) + plot_dims.discard(row) + plot_dims.discard(col) + plot_dims.discard(hue) + + ndims = len(plot_dims) + + plotfunc: Callable + + if ndims == 0 or darray.size == 0: + raise TypeError("No numeric data to plot.") + if ndims in (1, 2): + if row or col: + kwargs["subplot_kws"] = subplot_kws + kwargs["row"] = row + kwargs["col"] = col + kwargs["col_wrap"] = col_wrap + if ndims == 1: + plotfunc = line + kwargs["hue"] = hue + elif ndims == 2: + if hue: + plotfunc = line + kwargs["hue"] = hue + else: + plotfunc = pcolormesh + kwargs["subplot_kws"] = subplot_kws + else: + if row or col or hue: + raise ValueError( + "Only 1d and 2d plots are supported for facets in xarray. " + "See the package `Seaborn` for more options." + ) + plotfunc = hist + + kwargs["ax"] = ax + + return plotfunc(darray, **kwargs) + + +@overload +def line( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + *args: Any, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> list[Line3D]: ... + + +@overload +def line( + darray: T_DataArray, + *args: Any, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def line( + darray: T_DataArray, + *args: Any, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +# This function signature should not change so that it can use +# matplotlib format strings +def line( + darray: T_DataArray, + *args: Any, + row: Hashable | None = None, + col: Hashable | None = None, + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> list[Line3D] | FacetGrid[T_DataArray]: + """ + Line plot of DataArray values. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`. + + Parameters + ---------- + darray : DataArray + Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided. + row : Hashable, optional + If passed, make row faceted plots on this dimension name. + col : Hashable, optional + If passed, make column faceted plots on this dimension name. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axes on which to plot. By default, the current is used. + Mutually exclusive with ``size`` and ``figsize``. + hue : Hashable, optional + Dimension or coordinate for which you want multiple lines plotted. + If plotting against a 2D coordinate, ``hue`` must be a dimension. + x, y : Hashable, optional + Dimension, coordinate or multi-index level for *x*, *y* axis. + Only one of these may be specified. + The other will be used for values from the DataArray on which this + plot method is called. + xincrease : bool or None, optional + Should the values on the *x* axis be increasing from left to right? + if ``None``, use the default for the Matplotlib function. + yincrease : bool or None, optional + Should the values on the *y* axis be increasing from top to bottom? + if ``None``, use the default for the Matplotlib function. + xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional + Specifies scaling for the *x*- and *y*-axis, respectively. + xticks, yticks : array-like, optional + Specify tick locations for *x*- and *y*-axis. + xlim, ylim : tuple[float, float], optional + Specify *x*- and *y*-axis limits. + add_legend : bool, default: True + Add legend with *y* axis coordinates (2D inputs only). + *args, **kwargs : optional + Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`. + + Returns + ------- + primitive : list of Line3D or FacetGrid + When either col or row is given, returns a FacetGrid, otherwise + a list of matplotlib Line3D objects. + """ + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + return _easy_facetgrid(darray, line, kind="line", **allargs) + + ndims = len(darray.dims) + if ndims == 0 or darray.size == 0: + # TypeError to be consistent with pandas + raise TypeError("No numeric data to plot.") + if ndims > 2: + raise ValueError( + "Line plots are for 1- or 2-dimensional DataArrays. " + f"Passed DataArray has {ndims} " + "dimensions" + ) + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + else: + assert "args" not in kwargs + + ax = get_axis(figsize, size, aspect, ax) + xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) + + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.to_numpy(), yplt.to_numpy(), kwargs + ) + xlabel = label_from_attrs(xplt, extra=x_suffix) + ylabel = label_from_attrs(yplt, extra=y_suffix) + + _ensure_plottable(xplt_val, yplt_val) + + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) + + if _labels: + if xlabel is not None: + ax.set_xlabel(xlabel) + + if ylabel is not None: + ax.set_ylabel(ylabel) + + ax.set_title(darray._title_for_slice()) + + if darray.ndim == 2 and add_legend: + assert hueplt is not None + ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) + + if np.issubdtype(xplt.dtype, np.datetime64): + _set_concise_date(ax, axis="x") + + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) + + return primitive + + +@overload +def step( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + **kwargs: Any, +) -> list[Line3D]: ... + + +@overload +def step( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: ... + + +@overload +def step( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + **kwargs: Any, +) -> FacetGrid[DataArray]: ... + + +def step( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + **kwargs: Any, +) -> list[Line3D] | FacetGrid[DataArray]: + """ + Step plot of DataArray values. + + Similar to :py:func:`matplotlib:matplotlib.pyplot.step`. + + Parameters + ---------- + where : {'pre', 'post', 'mid'}, default: 'pre' + Define where the steps should be placed: + + - ``'pre'``: The y value is continued constantly to the left from + every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the + value ``y[i]``. + - ``'post'``: The y value is continued constantly to the right from + every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the + value ``y[i]``. + - ``'mid'``: Steps occur half-way between the *x* positions. + + Note that this parameter is ignored if one coordinate consists of + :py:class:`pandas.Interval` values, e.g. as a result of + :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual + boundaries of the interval are used. + drawstyle, ds : str or None, optional + Additional drawstyle. Only use one of drawstyle and ds. + row : Hashable, optional + If passed, make row faceted plots on this dimension name. + col : Hashable, optional + If passed, make column faceted plots on this dimension name. + *args, **kwargs : optional + Additional arguments for :py:func:`xarray.plot.line`. + + Returns + ------- + primitive : list of Line3D or FacetGrid + When either col or row is given, returns a FacetGrid, otherwise + a list of matplotlib Line3D objects. + """ + if where not in {"pre", "post", "mid"}: + raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'") + + if ds is not None: + if drawstyle is None: + drawstyle = ds + else: + raise TypeError("ds and drawstyle are mutually exclusive") + if drawstyle is None: + drawstyle = "" + drawstyle = "steps-" + where + drawstyle + + return line(darray, *args, drawstyle=drawstyle, col=col, row=row, **kwargs) + + +def hist( + darray: DataArray, + *args: Any, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + **kwargs: Any, +) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: + """ + Histogram of DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`. + + Plots *N*-dimensional arrays by first flattening the array. + + Parameters + ---------- + darray : DataArray + Can have any number of dimensions. + figsize : Iterable of float, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size`` and ``figsize``. + xincrease : bool or None, optional + Should the values on the *x* axis be increasing from left to right? + if ``None``, use the default for the Matplotlib function. + yincrease : bool or None, optional + Should the values on the *y* axis be increasing from top to bottom? + if ``None``, use the default for the Matplotlib function. + xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional + Specifies scaling for the *x*- and *y*-axis, respectively. + xticks, yticks : array-like, optional + Specify tick locations for *x*- and *y*-axis. + xlim, ylim : tuple[float, float], optional + Specify *x*- and *y*-axis limits. + **kwargs : optional + Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. + + """ + assert len(args) == 0 + + if darray.ndim == 0 or darray.size == 0: + # TypeError to be consistent with pandas + raise TypeError("No numeric data to plot.") + + ax = get_axis(figsize, size, aspect, ax) + + no_nan = np.ravel(darray.to_numpy()) + no_nan = no_nan[pd.notnull(no_nan)] + + n, bins, patches = cast( + tuple[np.ndarray, np.ndarray, Union["BarContainer", "Polygon"]], + ax.hist(no_nan, **kwargs), + ) + + ax.set_title(darray._title_for_slice()) + ax.set_xlabel(label_from_attrs(darray)) + + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) + + return n, bins, patches + + +def _plot1d(plotfunc): + """Decorator for common 1d plotting logic.""" + commondoc = """ + Parameters + ---------- + darray : DataArray + Must be 2 dimensional, unless creating faceted plots. + x : Hashable or None, optional + Coordinate for x axis. If None use darray.dims[1]. + y : Hashable or None, optional + Coordinate for y axis. If None use darray.dims[0]. + z : Hashable or None, optional + If specified plot 3D and use this coordinate for *z* axis. + hue : Hashable or None, optional + Dimension or coordinate for which you want multiple lines plotted. + markersize: Hashable or None, optional + scatter only. Variable by which to vary size of scattered points. + linewidth: Hashable or None, optional + Variable by which to vary linewidth. + row : Hashable, optional + If passed, make row faceted plots on this dimension name. + col : Hashable, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes object, optional + If None, uses the current axis. Not applicable when using facets. + figsize : Iterable[float] or None, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + xincrease : bool or None, default: True + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : bool or None, default: True + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. + add_legend : bool or None, optional + If True use xarray metadata to add a legend. + add_colorbar : bool or None, optional + If True add a colorbar. + add_labels : bool or None, optional + If True use xarray metadata to label axes + add_title : bool or None, optional + If True use xarray metadata to add a title + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + xscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the x-axes. + yscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the y-axes. + xticks : ArrayLike or None, optional + Specify tick locations for x-axes. + yticks : ArrayLike or None, optional + Specify tick locations for y-axes. + xlim : tuple[float, float] or None, optional + Specify x-axes limits. + ylim : tuple[float, float] or None, optional + Specify y-axes limits. + cmap : matplotlib colormap name or colormap, optional + The mapping from data values to color space. Either a + Matplotlib colormap name or object. If not provided, this will + be either ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging + dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette, + ``levels`` must also be specified. + vmin : float or None, optional + Lower value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + vmax : float or None, optional + Upper value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + extend : {'neither', 'both', 'min', 'max'}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional arguments to wrapped matplotlib function + + Returns + ------- + artist : + The same type of primitive artist that the wrapped matplotlib + function returns + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @functools.wraps( + plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__") + ) + def newplotfunc( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, + ) -> Any: + # All 1d plots in xarray share this function signature. + # Method signature below should be consistent. + + import matplotlib.pyplot as plt + + if subplot_kws is None: + subplot_kws = dict() + + # Handle facetgrids first + if row or col: + if z is not None: + subplot_kws.update(projection="3d") + + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + allargs.pop("plt") + allargs["plotfunc"] = globals()[plotfunc.__name__] + + return _easy_facetgrid(darray, kind="plot1d", **allargs) + + if darray.ndim == 0 or darray.size == 0: + # TypeError to be consistent with pandas + raise TypeError("No numeric data to plot.") + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + + if args: + assert "args" not in kwargs + # TODO: Deprecated since 2022.10: + msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead." + assert x is None + x = args[0] + if len(args) > 1: + assert y is None + y = args[1] + if len(args) > 2: + assert z is None + z = args[2] + if len(args) > 3: + assert hue is None + hue = args[3] + if len(args) > 4: + raise ValueError(msg) + else: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + del args + + if hue_style is not None: + # TODO: Not used since 2022.10. Deprecated since 2023.07. + warnings.warn( + ( + "hue_style is no longer used for plot1d plots " + "and the argument will eventually be removed. " + "Convert numbers to string for a discrete hue " + "and use add_legend or add_colorbar to control which guide to display." + ), + DeprecationWarning, + stacklevel=2, + ) + + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + + if plotfunc.__name__ == "scatter": + size_ = kwargs.pop("_size", markersize) + size_r = _MARKERSIZE_RANGE + + # Remove any nulls, .where(m, drop=True) doesn't work when m is + # a dask array, so load the array to memory. + # It will have to be loaded to memory at some point anyway: + darray = darray.load() + darray = darray.where(darray.notnull(), drop=True) + else: + size_ = kwargs.pop("_size", linewidth) + size_r = _LINEWIDTH_RANGE + + # Get data to plot: + coords_to_plot: MutableMapping[str, Hashable | None] = dict( + x=x, z=z, hue=hue, size=size_ + ) + if not _is_facetgrid: + # Guess what coords to use if some of the values in coords_to_plot are None: + coords_to_plot = _guess_coords_to_plot(darray, coords_to_plot, kwargs) + plts = _prepare_plot1d_data(darray, coords_to_plot, plotfunc.__name__) + xplt = plts.pop("x", None) + yplt = plts.pop("y", None) + zplt = plts.pop("z", None) + kwargs.update(zplt=zplt) + hueplt = plts.pop("hue", None) + sizeplt = plts.pop("size", None) + + # Handle size and hue: + hueplt_norm = _Normalize(data=hueplt) + kwargs.update(hueplt=hueplt_norm.values) + sizeplt_norm = _Normalize( + data=sizeplt, width=size_r, _is_facetgrid=_is_facetgrid + ) + kwargs.update(sizeplt=sizeplt_norm.values) + cmap_params_subset = kwargs.pop("cmap_params_subset", {}) + cbar_kwargs = kwargs.pop("cbar_kwargs", {}) + + if hueplt_norm.data is not None: + if not hueplt_norm.data_is_numeric: + # Map hue values back to its original value: + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + levels = kwargs.get("levels", hueplt_norm.levels) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, + cast("DataArray", hueplt_norm.values).data, + **locals(), + ) + + # subset that can be passed to scatter, hist2d + if not cmap_params_subset: + ckw = {vv: cmap_params[vv] for vv in ("vmin", "vmax", "norm", "cmap")} + cmap_params_subset.update(**ckw) + + with plt.rc_context(_styles): + if z is not None: + import mpl_toolkits + + if ax is None: + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + + # Using 30, 30 minimizes rotation of the plot. Making it easier to + # build on your intuition from 2D plots: + ax.view_init(azim=30, elev=30, vertical_axis="y") + else: + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + primitive = plotfunc( + xplt, + yplt, + ax=ax, + add_labels=add_labels, + **cmap_params_subset, + **kwargs, + ) + + if np.any(np.asarray(add_labels)) and add_title: + ax.set_title(darray._title_for_slice()) + + add_colorbar_, add_legend_ = _determine_guide( + hueplt_norm, + sizeplt_norm, + add_colorbar, + add_legend, + plotfunc_name=plotfunc.__name__, + ) + + if add_colorbar_: + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + + _add_colorbar( + primitive, ax, kwargs.get("cbar_ax", None), cbar_kwargs, cmap_params + ) + + if add_legend_: + if plotfunc.__name__ in ["scatter", "line"]: + _add_legend( + ( + hueplt_norm + if add_legend or not add_colorbar_ + else _Normalize(None) + ), + sizeplt_norm, + primitive, + legend_ax=ax, + plotfunc=plotfunc.__name__, + ) + else: + hueplt_norm_values: list[np.ndarray | None] + if hueplt_norm.data is not None: + hueplt_norm_values = list(hueplt_norm.data.to_numpy()) + else: + hueplt_norm_values = [hueplt_norm.data] + + if plotfunc.__name__ == "hist": + ax.legend( + handles=primitive[-1], + labels=hueplt_norm_values, + title=label_from_attrs(hueplt_norm.data), + ) + else: + ax.legend( + handles=primitive, + labels=hueplt_norm_values, + title=label_from_attrs(hueplt_norm.data), + ) + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) + + return primitive + + # we want to actually expose the signature of newplotfunc + # and not the copied **kwargs from the plotfunc which + # functools.wraps adds, so delete the wrapped attr + del newplotfunc.__wrapped__ + + return newplotfunc + + +def _add_labels( + add_labels: bool | Iterable[bool], + darrays: Iterable[DataArray | None], + suffixes: Iterable[str], + ax: Axes, +) -> None: + """Set x, y, z labels.""" + add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels + axes: tuple[Literal["x", "y", "z"], ...] = ("x", "y", "z") + for axis, add_label, darray, suffix in zip(axes, add_labels, darrays, suffixes): + if darray is None: + continue + + if add_label: + label = label_from_attrs(darray, extra=suffix) + if label is not None: + getattr(ax, f"set_{axis}label")(label) + + if np.issubdtype(darray.dtype, np.datetime64): + _set_concise_date(ax, axis=axis) + + +@overload +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> PathCollection: ... + + +@overload +def scatter( + darray: T_DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def scatter( + darray: T_DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> FacetGrid[T_DataArray]: ... + + +@_plot1d +def scatter( + xplt: DataArray | None, + yplt: DataArray | None, + ax: Axes, + add_labels: bool | Iterable[bool] = True, + **kwargs, +) -> PathCollection: + """Scatter variables against each other. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`. + """ + if "u" in kwargs or "v" in kwargs: + raise ValueError("u, v are not allowed in scatter plots.") + + zplt: DataArray | None = kwargs.pop("zplt", None) + hueplt: DataArray | None = kwargs.pop("hueplt", None) + sizeplt: DataArray | None = kwargs.pop("sizeplt", None) + + if hueplt is not None: + kwargs.update(c=hueplt.to_numpy().ravel()) + + if sizeplt is not None: + kwargs.update(s=sizeplt.to_numpy().ravel()) + + plts_or_none = (xplt, yplt, zplt) + _add_labels(add_labels, plts_or_none, ("", "", ""), ax) + + xplt_np = None if xplt is None else xplt.to_numpy().ravel() + yplt_np = None if yplt is None else yplt.to_numpy().ravel() + zplt_np = None if zplt is None else zplt.to_numpy().ravel() + plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None) + + if len(plts_np) == 3: + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + return ax.scatter(xplt_np, yplt_np, zplt_np, **kwargs) + + if len(plts_np) == 2: + return ax.scatter(plts_np[0], plts_np[1], **kwargs) + + raise ValueError("At least two variables required for a scatter plot.") + + +def _plot2d(plotfunc): + """Decorator for common 2d plotting logic.""" + commondoc = """ + Parameters + ---------- + darray : DataArray + Must be two-dimensional, unless creating faceted plots. + x : Hashable or None, optional + Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``. + y : Hashable or None, optional + Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``. + figsize : Iterable or float or None, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + size : scalar, optional + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in + inches. Only used if a ``size`` is provided. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size`` and ``figsize``. + row : Hashable or None, optional + If passed, make row faceted plots on this dimension name. + col : Hashable or None, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots. + xincrease : None, True, or False, optional + Should the values on the *x* axis be increasing from left to right? + If ``None``, use the default for the Matplotlib function. + yincrease : None, True, or False, optional + Should the values on the *y* axis be increasing from top to bottom? + If ``None``, use the default for the Matplotlib function. + add_colorbar : bool, optional + Add colorbar to axes. + add_labels : bool, optional + Use xarray metadata to label axes. + vmin : float or None, optional + Lower value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + vmax : float or None, optional + Upper value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : matplotlib colormap name or colormap, optional + The mapping from data values to color space. If not provided, this + will be either be ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette and the plot type + is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. + center : float or False, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {'neither', 'both', 'min', 'max'}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + infer_intervals : bool, optional + Only applies to pcolormesh. If ``True``, the coordinate intervals are + passed to pcolormesh. If ``False``, the original coordinates are used + (this can be useful for certain map projections). The default is to + always infer intervals, unless the mesh is irregular and plotted on + a map projection. + colors : str or array-like of color-like, optional + A single color or a sequence of colors. If the plot type is not ``'contour'`` + or ``'contourf'``, the ``levels`` argument is required. + subplot_kws : dict, optional + Dictionary of keyword arguments for Matplotlib subplots. Only used + for 2D and faceted plots. + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + cbar_ax : matplotlib axes object, optional + Axes in which to draw the colorbar. + cbar_kwargs : dict, optional + Dictionary of keyword arguments to pass to the colorbar + (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). + xscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the x-axes. + yscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the y-axes. + xticks : ArrayLike or None, optional + Specify tick locations for x-axes. + yticks : ArrayLike or None, optional + Specify tick locations for y-axes. + xlim : tuple[float, float] or None, optional + Specify x-axes limits. + ylim : tuple[float, float] or None, optional + Specify y-axes limits. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + **kwargs : optional + Additional keyword arguments to wrapped Matplotlib function. + + Returns + ------- + artist : + The same type of primitive artist that the wrapped Matplotlib + function returns. + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @functools.wraps( + plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__") + ) + def newplotfunc( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> Any: + # All 2d plots in xarray share this function signature. + + if args: + # TODO: Deprecated since 2022.10: + msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead." + assert x is None + x = args[0] + if len(args) > 1: + assert y is None + y = args[1] + if len(args) > 2: + raise ValueError(msg) + else: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + del args + + # Decide on a default for the colorbar before facetgrids + if add_colorbar is None: + add_colorbar = True + if plotfunc.__name__ == "contour" or ( + plotfunc.__name__ == "surface" and cmap is None + ): + add_colorbar = False + imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( + 3 + (row is not None) + (col is not None) + ) + if imshow_rgb: + # Don't add a colorbar when showing an image with explicit colors + add_colorbar = False + # Matplotlib does not support normalising RGB data, so do it here. + # See eg. https://github.com/matplotlib/matplotlib/pull/10220 + if robust or vmax is not None or vmin is not None: + darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust) + vmin, vmax, robust = None, None, False + + if subplot_kws is None: + subplot_kws = dict() + + if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): + if ax is None: + # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # noqa: F401 + + # delete so it does not end up in locals() + del Axes3D + + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" + + # In facet grids, shared axis labels don't make sense for surface plots + sharex = False + sharey = False + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + del allargs["darray"] + del allargs["imshow_rgb"] + allargs.update(allargs.pop("kwargs")) + # Need the decorated plotting function + allargs["plotfunc"] = globals()[plotfunc.__name__] + return _easy_facetgrid(darray, kind="dataarray", **allargs) + + if darray.ndim == 0 or darray.size == 0: + # TypeError to be consistent with pandas + raise TypeError("No numeric data to plot.") + + if ( + plotfunc.__name__ == "surface" + and not kwargs.get("_is_facetgrid", False) + and ax is not None + ): + import mpl_toolkits + + if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): + raise ValueError( + "If ax is passed to surface(), it must be created with " + 'projection="3d"' + ) + + rgb = kwargs.pop("rgb", None) + if rgb is not None and plotfunc.__name__ != "imshow": + raise ValueError('The "rgb" keyword is only valid for imshow()') + elif rgb is not None and not imshow_rgb: + raise ValueError( + 'The "rgb" keyword is only valid for imshow()' + "with a three-dimensional array (per facet)" + ) + + xlab, ylab = _infer_xy_labels( + darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb + ) + + xval = darray[xlab] + yval = darray[ylab] + + if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": + # Passing 2d coordinate values, need to ensure they are transposed the same + # way as darray. + # Also surface plots always need 2d coordinates + xval = xval.broadcast_like(darray) + yval = yval.broadcast_like(darray) + dims = darray.dims + else: + dims = (yval.dims[0], xval.dims[0]) + + # May need to transpose for correct x, y labels + # xlab may be the name of a coord, we have to check for dim names + if imshow_rgb: + # For RGB[A] images, matplotlib requires the color dimension + # to be last. In Xarray the order should be unimportant, so + # we transpose to (y, x, color) to make this work. + yx_dims = (ylab, xlab) + dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) + + if dims != darray.dims: + darray = darray.transpose(*dims, transpose_coords=True) + + # better to pass the ndarrays directly to plotting functions + xvalnp = xval.to_numpy() + yvalnp = yval.to_numpy() + + # Pass the data as a masked ndarray too + zval = darray.to_masked_array(copy=False) + + # Replace pd.Intervals if contained in xval or yval. + xplt, xlab_extra = _resolve_intervals_2dplot(xvalnp, plotfunc.__name__) + yplt, ylab_extra = _resolve_intervals_2dplot(yvalnp, plotfunc.__name__) + + _ensure_plottable(xplt, yplt, zval) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, + zval.data, + **locals(), + _is_facetgrid=kwargs.pop("_is_facetgrid", False), + ) + + if "contour" in plotfunc.__name__: + # extend is a keyword argument only for contour and contourf, but + # passing it to the colorbar is sufficient for imshow and + # pcolormesh + kwargs["extend"] = cmap_params["extend"] + kwargs["levels"] = cmap_params["levels"] + # if colors == a single color, matplotlib draws dashed negative + # contours. we lose this feature if we pass cmap and not colors + if isinstance(colors, str): + cmap_params["cmap"] = None + kwargs["colors"] = colors + + if "pcolormesh" == plotfunc.__name__: + kwargs["infer_intervals"] = infer_intervals + kwargs["xscale"] = xscale + kwargs["yscale"] = yscale + + if "imshow" == plotfunc.__name__ and isinstance(aspect, str): + # forbid usage of mpl strings + raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") + + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + primitive = plotfunc( + xplt, + yplt, + zval, + ax=ax, + cmap=cmap_params["cmap"], + vmin=cmap_params["vmin"], + vmax=cmap_params["vmax"], + norm=cmap_params["norm"], + **kwargs, + ) + + # Label the plot with metadata + if add_labels: + ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) + ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) + ax.set_title(darray._title_for_slice()) + if plotfunc.__name__ == "surface": + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + ax.set_zlabel(label_from_attrs(darray)) + + if add_colorbar: + if add_labels and "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(darray) + cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + elif cbar_ax is not None or cbar_kwargs: + # inform the user about keywords which aren't used + raise ValueError( + "cbar_ax and cbar_kwargs can't be used with add_colorbar=False." + ) + + # origin kwarg overrides yincrease + if "origin" in kwargs: + yincrease = None + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) + + if np.issubdtype(xplt.dtype, np.datetime64): + _set_concise_date(ax, "x") + + return primitive + + # we want to actually expose the signature of newplotfunc + # and not the copied **kwargs from the plotfunc which + # functools.wraps adds, so delete the wrapped attr + del newplotfunc.__wrapped__ + + return newplotfunc + + +@overload +def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> AxesImage: ... + + +@overload +def imshow( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def imshow( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@_plot2d +def imshow( + x: np.ndarray, y: np.ndarray, z: np.ma.core.MaskedArray, ax: Axes, **kwargs: Any +) -> AxesImage: + """ + Image plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.imshow`. + + While other plot methods require the DataArray to be strictly + two-dimensional, ``imshow`` also accepts a 3D array where some + dimension can be interpreted as RGB or RGBA color channels and + allows this dimension to be specified via the kwarg ``rgb=``. + + Unlike :py:func:`matplotlib:matplotlib.pyplot.imshow`, which ignores ``vmin``/``vmax`` + for RGB(A) data, + xarray *will* use ``vmin`` and ``vmax`` for RGB(A) data + by applying a single scaling factor and offset to all bands. + Passing ``robust=True`` infers ``vmin`` and ``vmax`` + :ref:`in the usual way `. + Additionally the y-axis is not inverted by default, you can + restore the matplotlib behavior by setting `yincrease=False`. + + .. note:: + This function needs uniformly spaced coordinates to + properly label the axes. Call :py:meth:`DataArray.plot` to check. + + The pixels are centered on the coordinates. For example, if the coordinate + value is 3.2, then the pixels for those coordinates will be centered on 3.2. + """ + + if x.ndim != 1 or y.ndim != 1: + raise ValueError( + "imshow requires 1D coordinates, try using pcolormesh or contour(f)" + ) + + def _center_pixels(x): + """Center the pixels on the coordinates.""" + if np.issubdtype(x.dtype, str): + # When using strings as inputs imshow converts it to + # integers. Choose extent values which puts the indices in + # in the center of the pixels: + return 0 - 0.5, len(x) - 0.5 + + try: + # Center the pixels assuming uniform spacing: + xstep = 0.5 * (x[1] - x[0]) + except IndexError: + # Arbitrary default value, similar to matplotlib behaviour: + xstep = 0.1 + + return x[0] - xstep, x[-1] + xstep + + # Center the pixels: + left, right = _center_pixels(x) + top, bottom = _center_pixels(y) + + defaults: dict[str, Any] = {"origin": "upper", "interpolation": "nearest"} + + if not hasattr(ax, "projection"): + # not for cartopy geoaxes + defaults["aspect"] = "auto" + + # Allow user to override these defaults + defaults.update(kwargs) + + if defaults["origin"] == "upper": + defaults["extent"] = [left, right, bottom, top] + else: + defaults["extent"] = [left, right, top, bottom] + + if z.ndim == 3: + # matplotlib imshow uses black for missing data, but Xarray makes + # missing data transparent. We therefore add an alpha channel if + # there isn't one, and set it to transparent where data is masked. + if z.shape[-1] == 3: + safe_dtype = np.promote_types(z.dtype, np.uint8) + alpha = np.ma.ones(z.shape[:2] + (1,), dtype=safe_dtype) + if np.issubdtype(z.dtype, np.integer): + alpha[:] = 255 + z = np.ma.concatenate((z, alpha), axis=2) + else: + z = z.copy() + z[np.any(z.mask, axis=-1), -1] = 0 + + primitive = ax.imshow(z, **defaults) + + # If x or y are strings the ticklabels have been replaced with + # integer indices. Replace them back to strings: + for axis, v in [("x", x), ("y", y)]: + if np.issubdtype(v.dtype, str): + getattr(ax, f"set_{axis}ticks")(np.arange(len(v))) + getattr(ax, f"set_{axis}ticklabels")(v) + + return primitive + + +@overload +def contour( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> QuadContourSet: ... + + +@overload +def contour( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def contour( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@_plot2d +def contour( + x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any +) -> QuadContourSet: + """ + Contour plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`. + """ + primitive = ax.contour(x, y, z, **kwargs) + return primitive + + +@overload +def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> QuadContourSet: ... + + +@overload +def contourf( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def contourf( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@_plot2d +def contourf( + x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any +) -> QuadContourSet: + """ + Filled contour plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`. + """ + primitive = ax.contourf(x, y, z, **kwargs) + return primitive + + +@overload +def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> QuadMesh: ... + + +@overload +def pcolormesh( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def pcolormesh( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@_plot2d +def pcolormesh( + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + ax: Axes, + xscale: ScaleOptions | None = None, + yscale: ScaleOptions | None = None, + infer_intervals=None, + **kwargs: Any, +) -> QuadMesh: + """ + Pseudocolor plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`. + """ + + # decide on a default for infer_intervals (GH781) + x = np.asarray(x) + if infer_intervals is None: + if hasattr(ax, "projection"): + if len(x.shape) == 1: + infer_intervals = True + else: + infer_intervals = False + else: + infer_intervals = True + + if any(np.issubdtype(k.dtype, str) for k in (x, y)): + # do not infer intervals if any axis contains str ticks, see #6775 + infer_intervals = False + + if infer_intervals and ( + (np.shape(x)[0] == np.shape(z)[1]) + or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + ): + if x.ndim == 1: + x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale) + else: + # we have to infer the intervals on both axes + x = _infer_interval_breaks(x, axis=1, scale=xscale) + x = _infer_interval_breaks(x, axis=0, scale=xscale) + + if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]): + if y.ndim == 1: + y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale) + else: + # we have to infer the intervals on both axes + y = _infer_interval_breaks(y, axis=1, scale=yscale) + y = _infer_interval_breaks(y, axis=0, scale=yscale) + + ax.grid(False) + primitive = ax.pcolormesh(x, y, z, **kwargs) + + # by default, pcolormesh picks "round" values for bounds + # this results in ugly looking plots with lots of surrounding whitespace + if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1: + # not a cartopy geoaxis + ax.set_xlim(x[0], x[-1]) + ax.set_ylim(y[0], y[-1]) + + return primitive + + +@overload +def surface( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> Poly3DCollection: ... + + +@overload +def surface( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@overload +def surface( + darray: T_DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | Literal[False] | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArray]: ... + + +@_plot2d +def surface( + x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any +) -> Poly3DCollection: + """ + Surface plot of 2D DataArray. + + Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. + """ + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + primitive = ax.plot_surface(x, y, z, **kwargs) + return primitive diff --git a/test/fixtures/whole_applications/xarray/xarray/plot/dataset_plot.py b/test/fixtures/whole_applications/xarray/xarray/plot/dataset_plot.py new file mode 100644 index 0000000..edc2bf4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/plot/dataset_plot.py @@ -0,0 +1,913 @@ +from __future__ import annotations + +import functools +import inspect +import warnings +from collections.abc import Hashable, Iterable +from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload + +from xarray.core.alignment import broadcast +from xarray.plot import dataarray_plot +from xarray.plot.facetgrid import _easy_facetgrid +from xarray.plot.utils import ( + _add_colorbar, + _get_nice_quiver_magnitude, + _infer_meta_data, + _process_cmap_cbar_kwargs, + get_axis, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import LineCollection, PathCollection + from matplotlib.colors import Colormap, Normalize + from matplotlib.quiver import Quiver + from numpy.typing import ArrayLike + + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import ( + AspectOptions, + ExtendOptions, + HueStyleOptions, + ScaleOptions, + ) + from xarray.plot.facetgrid import FacetGrid + + +def _dsplot(plotfunc): + commondoc = """ + Parameters + ---------- + + ds : Dataset + x : Hashable or None, optional + Variable name for x-axis. + y : Hashable or None, optional + Variable name for y-axis. + u : Hashable or None, optional + Variable name for the *u* velocity (in *x* direction). + quiver/streamplot plots only. + v : Hashable or None, optional + Variable name for the *v* velocity (in *y* direction). + quiver/streamplot plots only. + hue: Hashable or None, optional + Variable by which to color scatter points or arrows. + hue_style: {'continuous', 'discrete'} or None, optional + How to use the ``hue`` variable: + + - ``'continuous'`` -- continuous color scale + (default for numeric ``hue`` variables) + - ``'discrete'`` -- a color for each unique value, using the default color cycle + (default for non-numeric ``hue`` variables) + + row : Hashable or None, optional + If passed, make row faceted plots on this dimension name. + col : Hashable or None, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots. + ax : matplotlib axes object or None, optional + If ``None``, use the current axes. Not applicable when using facets. + figsize : Iterable[float] or None, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + sharex : bool or None, optional + If True all subplots share the same x-axis. + sharey : bool or None, optional + If True all subplots share the same y-axis. + add_guide: bool or None, optional + Add a guide that depends on ``hue_style``: + + - ``'continuous'`` -- build a colorbar + - ``'discrete'`` -- build a legend + + subplot_kws : dict or None, optional + Dictionary of keyword arguments for Matplotlib subplots + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + Only applies to FacetGrid plotting. + cbar_kwargs : dict, optional + Dictionary of keyword arguments to pass to the colorbar + (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). + cbar_ax : matplotlib axes object, optional + Axes in which to draw the colorbar. + cmap : matplotlib colormap name or colormap, optional + The mapping from data values to color space. Either a + Matplotlib colormap name or object. If not provided, this will + be either ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging + dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette, + ``levels`` must also be specified. + vmin : float or None, optional + Lower value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + vmax : float or None, optional + Upper value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + infer_intervals: bool | None + If True the intervals are inferred. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + colors : str or array-like of color-like, optional + A single color or a list of colors. The ``levels`` argument + is required. + extend : {'neither', 'both', 'min', 'max'}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional keyword arguments to wrapped Matplotlib function. + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @functools.wraps( + plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__") + ) + def newplotfunc( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + sharex: bool = True, + sharey: bool = True, + add_guide: bool | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, + ) -> Any: + if args: + # TODO: Deprecated since 2022.10: + msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead." + assert x is None + x = args[0] + if len(args) > 1: + assert y is None + y = args[1] + if len(args) > 2: + assert u is None + u = args[2] + if len(args) > 3: + assert v is None + v = args[3] + if len(args) > 4: + assert hue is None + hue = args[4] + if len(args) > 5: + raise ValueError(msg) + else: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + del args + + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + if _is_facetgrid: # facetgrid call + meta_data = kwargs.pop("meta_data") + else: + meta_data = _infer_meta_data( + ds, x, y, hue, hue_style, add_guide, funcname=plotfunc.__name__ + ) + + hue_style = meta_data["hue_style"] + + # handle facetgrids first + if col or row: + allargs = locals().copy() + allargs["plotfunc"] = globals()[plotfunc.__name__] + allargs["data"] = ds + # remove kwargs to avoid passing the information twice + for arg in ["meta_data", "kwargs", "ds"]: + del allargs[arg] + + return _easy_facetgrid(kind="dataset", **allargs, **kwargs) + + figsize = kwargs.pop("figsize", None) + ax = get_axis(figsize, size, aspect, ax) + + if hue_style == "continuous" and hue is not None: + if _is_facetgrid: + cbar_kwargs = meta_data["cbar_kwargs"] + cmap_params = meta_data["cmap_params"] + else: + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, ds[hue].values, **locals() + ) + + # subset that can be passed to scatter, hist2d + cmap_params_subset = { + vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] + } + + else: + cmap_params_subset = {} + + if (u is not None or v is not None) and plotfunc.__name__ not in ( + "quiver", + "streamplot", + ): + raise ValueError("u, v are only allowed for quiver or streamplot plots.") + + primitive = plotfunc( + ds=ds, + x=x, + y=y, + ax=ax, + u=u, + v=v, + hue=hue, + hue_style=hue_style, + cmap_params=cmap_params_subset, + **kwargs, + ) + + if _is_facetgrid: # if this was called from Facetgrid.map_dataset, + return primitive # finish here. Else, make labels + + if meta_data.get("xlabel", None): + ax.set_xlabel(meta_data.get("xlabel")) + if meta_data.get("ylabel", None): + ax.set_ylabel(meta_data.get("ylabel")) + + if meta_data["add_legend"]: + ax.legend(handles=primitive, title=meta_data.get("hue_label", None)) + if meta_data["add_colorbar"]: + cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = meta_data.get("hue_label", None) + _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + + if meta_data["add_quiverkey"]: + magnitude = _get_nice_quiver_magnitude(ds[u], ds[v]) + units = ds[u].attrs.get("units", "") + ax.quiverkey( + primitive, + X=0.85, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + if plotfunc.__name__ in ("quiver", "streamplot"): + title = ds[u]._title_for_slice() + else: + title = ds[x]._title_for_slice() + ax.set_title(title) + + return primitive + + # we want to actually expose the signature of newplotfunc + # and not the copied **kwargs from the plotfunc which + # functools.wraps adds, so delete the wrapped attr + del newplotfunc.__wrapped__ + + return newplotfunc + + +@overload +def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> Quiver: ... + + +@overload +def quiver( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: ... + + +@overload +def quiver( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: ... + + +@_dsplot +def quiver( + ds: Dataset, + x: Hashable, + y: Hashable, + ax: Axes, + u: Hashable, + v: Hashable, + **kwargs: Any, +) -> Quiver: + """Quiver plot of Dataset variables. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. + """ + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for quiver plots.") + + dx, dy, du, dv = broadcast(ds[x], ds[y], ds[u], ds[v]) + + args = [dx.values, dy.values, du.values, dv.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + args.append(ds[hue].values) + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + kwargs.setdefault("pivot", "middle") + hdl = ax.quiver(*args, **kwargs, **cmap_params) + return hdl + + +@overload +def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> LineCollection: ... + + +@overload +def streamplot( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: ... + + +@overload +def streamplot( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: ... + + +@_dsplot +def streamplot( + ds: Dataset, + x: Hashable, + y: Hashable, + ax: Axes, + u: Hashable, + v: Hashable, + **kwargs: Any, +) -> LineCollection: + """Plot streamlines of Dataset variables. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. + """ + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for streamplot plots.") + + # Matplotlib's streamplot has strong restrictions on what x and y can be, so need to + # get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so + # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so + # the dimension of y must be the first dimension. If x and y are both 2d, assume the + # user has got them right already. + xdim = ds[x].dims[0] if len(ds[x].dims) == 1 else None + ydim = ds[y].dims[0] if len(ds[y].dims) == 1 else None + if xdim is not None and ydim is None: + ydims = set(ds[y].dims) - {xdim} + if len(ydims) == 1: + ydim = next(iter(ydims)) + if ydim is not None and xdim is None: + xdims = set(ds[x].dims) - {ydim} + if len(xdims) == 1: + xdim = next(iter(xdims)) + + dx, dy, du, dv = broadcast(ds[x], ds[y], ds[u], ds[v]) + + if xdim is not None and ydim is not None: + # Need to ensure the arrays are transposed correctly + dx = dx.transpose(ydim, xdim) + dy = dy.transpose(ydim, xdim) + du = du.transpose(ydim, xdim) + dv = dv.transpose(ydim, xdim) + + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + kwargs["color"] = ds[hue].values + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + hdl = ax.streamplot( + dx.values, dy.values, du.values, dv.values, **kwargs, **cmap_params + ) + + # Return .lines so colorbar creation works properly + return hdl.lines + + +F = TypeVar("F", bound=Callable) + + +def _update_doc_to_dataset(dataarray_plotfunc: Callable) -> Callable[[F], F]: + """ + Add a common docstring by re-using the DataArray one. + + TODO: Reduce code duplication. + + * The goal is to reduce code duplication by moving all Dataset + specific plots to the DataArray side and use this thin wrapper to + handle the conversion between Dataset and DataArray. + * Improve docstring handling, maybe reword the DataArray versions to + explain Datasets better. + + Parameters + ---------- + dataarray_plotfunc : Callable + Function that returns a finished plot primitive. + """ + + # Build on the original docstring + da_doc = dataarray_plotfunc.__doc__ + if da_doc is None: + raise NotImplementedError("DataArray plot method requires a docstring") + + da_str = """ + Parameters + ---------- + darray : DataArray + """ + ds_str = """ + + The `y` DataArray will be used as base, any other variables are added as coords. + + Parameters + ---------- + ds : Dataset + """ + # TODO: improve this? + if da_str in da_doc: + ds_doc = da_doc.replace(da_str, ds_str).replace("darray", "ds") + else: + ds_doc = da_doc + + @functools.wraps(dataarray_plotfunc) + def wrapper(dataset_plotfunc: F) -> F: + dataset_plotfunc.__doc__ = ds_doc + return dataset_plotfunc + + return wrapper + + +def _normalize_args( + plotmethod: str, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> dict[str, Any]: + from xarray.core.dataarray import DataArray + + # Determine positional arguments keyword by inspecting the + # signature of the plotmethod: + locals_ = dict( + inspect.signature(getattr(DataArray().plot, plotmethod)) + .bind(*args, **kwargs) + .arguments.items() + ) + locals_.update(locals_.pop("kwargs", {})) + + return locals_ + + +def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataArray: + """Create a temporary datarray with extra coords.""" + from xarray.core.dataarray import DataArray + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray from valid kwargs, if using all + # kwargs there is a risk that we add unnecessary dataarrays as + # coords straining RAM further for example: + # ds.both and extend="both" would add ds.both to the coords: + valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"} + coord_kwargs = locals_.keys() & valid_coord_kwargs + for k in coord_kwargs: + key = locals_[k] + if ds.data_vars.get(key) is not None: + coords[key] = ds[key] + + # The dataarray has to include all the dims. Broadcast to that shape + # and add the additional coords: + _y = ds[y].broadcast_like(ds) + + return DataArray(_y, coords=coords) + + +@overload +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> PathCollection: ... + + +@overload +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: ... + + +@overload +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: ... + + +@_update_doc_to_dataset(dataarray_plot.scatter) +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> PathCollection | FacetGrid[DataArray]: + """Scatter plot Dataset data variables against each other.""" + locals_ = locals() + del locals_["ds"] + locals_.update(locals_.pop("kwargs", {})) + da = _temp_dataarray(ds, y, locals_) + + return da.plot.scatter(*locals_.pop("args", ()), **locals_) diff --git a/test/fixtures/whole_applications/xarray/xarray/plot/facetgrid.py b/test/fixtures/whole_applications/xarray/xarray/plot/facetgrid.py new file mode 100644 index 0000000..faf809a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/plot/facetgrid.py @@ -0,0 +1,1074 @@ +from __future__ import annotations + +import functools +import itertools +import warnings +from collections.abc import Hashable, Iterable, MutableMapping +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, cast + +import numpy as np + +from xarray.core.formatting import format_item +from xarray.core.types import HueStyleOptions, T_DataArrayOrSet +from xarray.plot.utils import ( + _LINEWIDTH_RANGE, + _MARKERSIZE_RANGE, + _add_legend, + _determine_guide, + _get_nice_quiver_magnitude, + _guess_coords_to_plot, + _infer_xy_labels, + _Normalize, + _parse_size, + _process_cmap_cbar_kwargs, + label_from_attrs, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.cm import ScalarMappable + from matplotlib.colorbar import Colorbar + from matplotlib.figure import Figure + from matplotlib.legend import Legend + from matplotlib.quiver import QuiverKey + from matplotlib.text import Annotation + + from xarray.core.dataarray import DataArray + + +# Overrides axes.labelsize, xtick.major.size, ytick.major.size +# from mpl.rcParams +_FONTSIZE = "small" +# For major ticks on x, y axes +_NTICKS = 5 + + +def _nicetitle(coord, value, maxchar, template): + """ + Put coord, value in template and truncate at maxchar + """ + prettyvalue = format_item(value, quote_strings=False) + title = template.format(coord=coord, value=prettyvalue) + + if len(title) > maxchar: + title = title[: (maxchar - 3)] + "..." + + return title + + +T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") + + +class FacetGrid(Generic[T_DataArrayOrSet]): + """ + Initialize the Matplotlib figure and FacetGrid object. + + The :class:`FacetGrid` is an object that links a xarray DataArray to + a Matplotlib figure with a particular structure. + + In particular, :class:`FacetGrid` is used to draw plots with multiple + axes, where each axes shows the same relationship conditioned on + different levels of some dimension. It's possible to condition on up to + two variables by assigning variables to the rows and columns of the + grid. + + The general approach to plotting here is called "small multiples", + where the same kind of plot is repeated multiple times, and the + specific use of small multiples to display the same relationship + conditioned on one or more other variables is often called a "trellis + plot". + + The basic workflow is to initialize the :class:`FacetGrid` object with + the DataArray and the variable names that are used to structure the grid. + Then plotting functions can be applied to each subset by calling + :meth:`FacetGrid.map_dataarray` or :meth:`FacetGrid.map`. + + Attributes + ---------- + axs : ndarray of matplotlib.axes.Axes + Array containing axes in corresponding position, as returned from + :py:func:`matplotlib.pyplot.subplots`. + col_labels : list of matplotlib.text.Annotation + Column titles. + row_labels : list of matplotlib.text.Annotation + Row titles. + fig : matplotlib.figure.Figure + The figure containing all the axes. + name_dicts : ndarray of dict + Array containing dictionaries mapping coordinate names to values. ``None`` is + used as a sentinel value for axes that should remain empty, i.e., + sometimes the rightmost grid positions in the bottom row. + """ + + data: T_DataArrayOrSet + name_dicts: np.ndarray + fig: Figure + axs: np.ndarray + row_names: list[np.ndarray] + col_names: list[np.ndarray] + figlegend: Legend | None + quiverkey: QuiverKey | None + cbar: Colorbar | None + _single_group: bool | Hashable + _nrow: int + _row_var: Hashable | None + _ncol: int + _col_var: Hashable | None + _col_wrap: int | None + row_labels: list[Annotation | None] + col_labels: list[Annotation | None] + _x_var: None + _y_var: None + _cmap_extend: Any | None + _mappables: list[ScalarMappable] + _finalized: bool + + def __init__( + self, + data: T_DataArrayOrSet, + col: Hashable | None = None, + row: Hashable | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + figsize: Iterable[float] | None = None, + aspect: float = 1, + size: float = 3, + subplot_kws: dict[str, Any] | None = None, + ) -> None: + """ + Parameters + ---------- + data : DataArray or Dataset + DataArray or Dataset to be plotted. + row, col : str + Dimension names that define subsets of the data, which will be drawn + on separate facets in the grid. + col_wrap : int, optional + "Wrap" the grid the for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + sharex : bool, optional + If true, the facets will share *x* axes. + sharey : bool, optional + If true, the facets will share *y* axes. + figsize : Iterable of float or None, optional + A tuple (width, height) of the figure in inches. + If set, overrides ``size`` and ``aspect``. + aspect : scalar, default: 1 + Aspect ratio of each facet, so that ``aspect * size`` gives the + width of each facet in inches. + size : scalar, default: 3 + Height (in inches) of each facet. See also: ``aspect``. + subplot_kws : dict, optional + Dictionary of keyword arguments for Matplotlib subplots + (:py:func:`matplotlib.pyplot.subplots`). + + """ + + import matplotlib.pyplot as plt + + # Handle corner case of nonunique coordinates + rep_col = col is not None and not data[col].to_index().is_unique + rep_row = row is not None and not data[row].to_index().is_unique + if rep_col or rep_row: + raise ValueError( + "Coordinates used for faceting cannot " + "contain repeated (nonunique) values." + ) + + # single_group is the grouping variable, if there is exactly one + single_group: bool | Hashable + if col and row: + single_group = False + nrow = len(data[row]) + ncol = len(data[col]) + nfacet = nrow * ncol + if col_wrap is not None: + warnings.warn("Ignoring col_wrap since both col and row were passed") + elif row and not col: + single_group = row + elif not row and col: + single_group = col + else: + raise ValueError("Pass a coordinate name as an argument for row or col") + + # Compute grid shape + if single_group: + nfacet = len(data[single_group]) + if col: + # idea - could add heuristic for nice shapes like 3x4 + ncol = nfacet + if row: + ncol = 1 + if col_wrap is not None: + # Overrides previous settings + ncol = col_wrap + nrow = int(np.ceil(nfacet / ncol)) + + # Set the subplot kwargs + subplot_kws = {} if subplot_kws is None else subplot_kws + + if figsize is None: + # Calculate the base figure size with extra horizontal space for a + # colorbar + cbar_space = 1 + figsize = (ncol * size * aspect + cbar_space, nrow * size) + + fig, axs = plt.subplots( + nrow, + ncol, + sharex=sharex, + sharey=sharey, + squeeze=False, + figsize=figsize, + subplot_kw=subplot_kws, + ) + + # Set up the lists of names for the row and column facet variables + col_names = list(data[col].to_numpy()) if col else [] + row_names = list(data[row].to_numpy()) if row else [] + + if single_group: + full: list[dict[Hashable, Any] | None] = [ + {single_group: x} for x in data[single_group].to_numpy() + ] + empty: list[dict[Hashable, Any] | None] = [ + None for x in range(nrow * ncol - len(full)) + ] + name_dict_list = full + empty + else: + rowcols = itertools.product(row_names, col_names) + name_dict_list = [{row: r, col: c} for r, c in rowcols] + + name_dicts = np.array(name_dict_list).reshape(nrow, ncol) + + # Set up the class attributes + # --------------------------- + + # First the public API + self.data = data + self.name_dicts = name_dicts + self.fig = fig + self.axs = axs + self.row_names = row_names + self.col_names = col_names + + # guides + self.figlegend = None + self.quiverkey = None + self.cbar = None + + # Next the private variables + self._single_group = single_group + self._nrow = nrow + self._row_var = row + self._ncol = ncol + self._col_var = col + self._col_wrap = col_wrap + self.row_labels = [None] * nrow + self.col_labels = [None] * ncol + self._x_var = None + self._y_var = None + self._cmap_extend = None + self._mappables = [] + self._finalized = False + + @property + def axes(self) -> np.ndarray: + warnings.warn( + ( + "self.axes is deprecated since 2022.11 in order to align with " + "matplotlibs plt.subplots, use self.axs instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.axs + + @axes.setter + def axes(self, axs: np.ndarray) -> None: + warnings.warn( + ( + "self.axes is deprecated since 2022.11 in order to align with " + "matplotlibs plt.subplots, use self.axs instead." + ), + DeprecationWarning, + stacklevel=2, + ) + self.axs = axs + + @property + def _left_axes(self) -> np.ndarray: + return self.axs[:, 0] + + @property + def _bottom_axes(self) -> np.ndarray: + return self.axs[-1, :] + + def map_dataarray( + self: T_FacetGrid, + func: Callable, + x: Hashable | None, + y: Hashable | None, + **kwargs: Any, + ) -> T_FacetGrid: + """ + Apply a plotting function to a 2d facet's subset of the data. + + This is more convenient and less general than ``FacetGrid.map`` + + Parameters + ---------- + func : callable + A plotting function with the same signature as a 2d xarray + plotting method such as `xarray.plot.imshow` + x, y : string + Names of the coordinates to plot on x, y axes + **kwargs + additional keyword arguments to func + + Returns + ------- + self : FacetGrid object + + """ + + if kwargs.get("cbar_ax", None) is not None: + raise ValueError("cbar_ax not supported by FacetGrid.") + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, self.data.to_numpy(), **kwargs + ) + + self._cmap_extend = cmap_params.get("extend") + + # Order is important + func_kwargs = { + k: v + for k, v in kwargs.items() + if k not in {"cmap", "colors", "cbar_kwargs", "levels"} + } + func_kwargs.update(cmap_params) + func_kwargs["add_colorbar"] = False + if func.__name__ != "surface": + func_kwargs["add_labels"] = False + + # Get x, y labels for the first subplot + x, y = _infer_xy_labels( + darray=self.data.loc[self.name_dicts.flat[0]], + x=x, + y=y, + imshow=func.__name__ == "imshow", + rgb=kwargs.get("rgb", None), + ) + + for d, ax in zip(self.name_dicts.flat, self.axs.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + mappable = func( + subset, x=x, y=y, ax=ax, **func_kwargs, _is_facetgrid=True + ) + self._mappables.append(mappable) + + self._finalize_grid(x, y) + + if kwargs.get("add_colorbar", True): + self.add_colorbar(**cbar_kwargs) + + return self + + def map_plot1d( + self: T_FacetGrid, + func: Callable, + x: Hashable | None, + y: Hashable | None, + *, + z: Hashable | None = None, + hue: Hashable | None = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + **kwargs: Any, + ) -> T_FacetGrid: + """ + Apply a plotting function to a 1d facet's subset of the data. + + This is more convenient and less general than ``FacetGrid.map`` + + Parameters + ---------- + func : + A plotting function with the same signature as a 1d xarray + plotting method such as `xarray.plot.scatter` + x, y : + Names of the coordinates to plot on x, y axes + **kwargs + additional keyword arguments to func + + Returns + ------- + self : FacetGrid object + + """ + # Copy data to allow converting categoricals to integers and storing + # them in self.data. It is not possible to copy in the init + # unfortunately as there are tests that relies on self.data being + # mutable (test_names_appear_somewhere()). Maybe something to deprecate + # not sure how much that is used outside these tests. + self.data = self.data.copy() + + if kwargs.get("cbar_ax", None) is not None: + raise ValueError("cbar_ax not supported by FacetGrid.") + + if func.__name__ == "scatter": + size_ = kwargs.pop("_size", markersize) + size_r = _MARKERSIZE_RANGE + else: + size_ = kwargs.pop("_size", linewidth) + size_r = _LINEWIDTH_RANGE + + # Guess what coords to use if some of the values in coords_to_plot are None: + coords_to_plot: MutableMapping[str, Hashable | None] = dict( + x=x, z=z, hue=hue, size=size_ + ) + coords_to_plot = _guess_coords_to_plot(self.data, coords_to_plot, kwargs) + + # Handle hues: + hue = coords_to_plot["hue"] + hueplt = self.data.coords[hue] if hue else None # TODO: _infer_line_data2 ? + hueplt_norm = _Normalize(hueplt) + self._hue_var = hueplt + cbar_kwargs = kwargs.pop("cbar_kwargs", {}) + if hueplt_norm.data is not None: + if not hueplt_norm.data_is_numeric: + # TODO: Ticks seems a little too hardcoded, since it will always + # show all the values. But maybe it's ok, since plotting hundreds + # of categorical data isn't that meaningful anyway. + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + kwargs.update(levels=hueplt_norm.levels) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, + cast("DataArray", hueplt_norm.values).data, + cbar_kwargs=cbar_kwargs, + **kwargs, + ) + self._cmap_extend = cmap_params.get("extend") + else: + cmap_params = {} + + # Handle sizes: + size_ = coords_to_plot["size"] + sizeplt = self.data.coords[size_] if size_ else None + sizeplt_norm = _Normalize(data=sizeplt, width=size_r) + if sizeplt_norm.data is not None: + self.data[size_] = sizeplt_norm.values + + # Add kwargs that are sent to the plotting function, # order is important ??? + func_kwargs = { + k: v + for k, v in kwargs.items() + if k not in {"cmap", "colors", "cbar_kwargs", "levels"} + } + func_kwargs.update(cmap_params) + # Annotations will be handled later, skip those parts in the plotfunc: + func_kwargs["add_colorbar"] = False + func_kwargs["add_legend"] = False + func_kwargs["add_title"] = False + + add_labels_ = np.zeros(self.axs.shape + (3,), dtype=bool) + if kwargs.get("z") is not None: + # 3d plots looks better with all labels. 3d plots can't sharex either so it + # is easy to get lost while rotating the plots: + add_labels_[:] = True + else: + # Subplots should have labels on the left and bottom edges only: + add_labels_[-1, :, 0] = True # x + add_labels_[:, 0, 1] = True # y + # add_labels_[:, :, 2] = True # z + + # Set up the lists of names for the row and column facet variables: + if self._single_group: + full = tuple( + {self._single_group: x} + for x in range(0, self.data[self._single_group].size) + ) + empty = tuple(None for x in range(self._nrow * self._ncol - len(full))) + name_d = full + empty + else: + rowcols = itertools.product( + range(0, self.data[self._row_var].size), + range(0, self.data[self._col_var].size), + ) + name_d = tuple({self._row_var: r, self._col_var: c} for r, c in rowcols) + name_dicts = np.array(name_d).reshape(self._nrow, self._ncol) + + # Plot the data for each subplot: + for add_lbls, d, ax in zip( + add_labels_.reshape((self.axs.size, -1)), name_dicts.flat, self.axs.flat + ): + func_kwargs["add_labels"] = add_lbls + # None is the sentinel value + if d is not None: + subset = self.data.isel(d) + mappable = func( + subset, + x=x, + y=y, + ax=ax, + hue=hue, + _size=size_, + **func_kwargs, + _is_facetgrid=True, + ) + self._mappables.append(mappable) + + # Add titles and some touch ups: + self._finalize_grid() + self._set_lims() + + add_colorbar, add_legend = _determine_guide( + hueplt_norm, + sizeplt_norm, + kwargs.get("add_colorbar", None), + kwargs.get("add_legend", None), + # kwargs.get("add_guide", None), + # kwargs.get("hue_style", None), + ) + + if add_legend: + use_legend_elements = False if func.__name__ == "hist" else True + if use_legend_elements: + self.add_legend( + use_legend_elements=use_legend_elements, + hueplt_norm=hueplt_norm if not add_colorbar else _Normalize(None), + sizeplt_norm=sizeplt_norm, + primitive=self._mappables, + legend_ax=self.fig, + plotfunc=func.__name__, + ) + else: + self.add_legend(use_legend_elements=use_legend_elements) + + if add_colorbar: + # Colorbar is after legend so it correctly fits the plot: + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + + self.add_colorbar(**cbar_kwargs) + + return self + + def map_dataarray_line( + self: T_FacetGrid, + func: Callable, + x: Hashable | None, + y: Hashable | None, + hue: Hashable | None, + add_legend: bool = True, + _labels=None, + **kwargs: Any, + ) -> T_FacetGrid: + from xarray.plot.dataarray_plot import _infer_line_data + + for d, ax in zip(self.name_dicts.flat, self.axs.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + mappable = func( + subset, + x=x, + y=y, + ax=ax, + hue=hue, + add_legend=False, + _labels=False, + **kwargs, + ) + self._mappables.append(mappable) + + xplt, yplt, hueplt, huelabel = _infer_line_data( + darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue + ) + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + + self._hue_var = hueplt + self._finalize_grid(xlabel, ylabel) + + if add_legend and hueplt is not None and huelabel is not None: + self.add_legend(label=huelabel) + + return self + + def map_dataset( + self: T_FacetGrid, + func: Callable, + x: Hashable | None = None, + y: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + add_guide: bool | None = None, + **kwargs: Any, + ) -> T_FacetGrid: + from xarray.plot.dataset_plot import _infer_meta_data + + kwargs["add_guide"] = False + + if kwargs.get("markersize", None): + kwargs["size_mapping"] = _parse_size( + self.data[kwargs["markersize"]], kwargs.pop("size_norm", None) + ) + + meta_data = _infer_meta_data( + self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__ + ) + kwargs["meta_data"] = meta_data + + if hue and meta_data["hue_style"] == "continuous": + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, self.data[hue].to_numpy(), **kwargs + ) + kwargs["meta_data"]["cmap_params"] = cmap_params + kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs + + kwargs["_is_facetgrid"] = True + + if func.__name__ == "quiver" and "scale" not in kwargs: + raise ValueError("Please provide scale.") + # TODO: come up with an algorithm for reasonable scale choice + + for d, ax in zip(self.name_dicts.flat, self.axs.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + maybe_mappable = func( + ds=subset, x=x, y=y, hue=hue, hue_style=hue_style, ax=ax, **kwargs + ) + # TODO: this is needed to get legends to work. + # but maybe_mappable is a list in that case :/ + self._mappables.append(maybe_mappable) + + self._finalize_grid(meta_data["xlabel"], meta_data["ylabel"]) + + if hue: + hue_label = meta_data.pop("hue_label", None) + self._hue_label = hue_label + if meta_data["add_legend"]: + self._hue_var = meta_data["hue"] + self.add_legend(label=hue_label) + elif meta_data["add_colorbar"]: + self.add_colorbar(label=hue_label, **cbar_kwargs) + + if meta_data["add_quiverkey"]: + self.add_quiverkey(kwargs["u"], kwargs["v"]) + + return self + + def _finalize_grid(self, *axlabels: Hashable) -> None: + """Finalize the annotations and layout.""" + if not self._finalized: + self.set_axis_labels(*axlabels) + self.set_titles() + self.fig.tight_layout() + + for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): + if namedict is None: + ax.set_visible(False) + + self._finalized = True + + def _adjust_fig_for_guide(self, guide) -> None: + # Draw the plot to set the bounding boxes correctly + if hasattr(self.fig.canvas, "get_renderer"): + renderer = self.fig.canvas.get_renderer() + else: + raise RuntimeError("MPL backend has no renderer") + self.fig.draw(renderer) + + # Calculate and set the new width of the figure so the legend fits + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi + figure_width = self.fig.get_figwidth() + total_width = figure_width + guide_width + self.fig.set_figwidth(total_width) + + # Draw the plot again to get the new transformations + self.fig.draw(renderer) + + # Now calculate how much space we need on the right side + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi + space_needed = guide_width / total_width + 0.02 + # margin = .01 + # _space_needed = margin + space_needed + right = 1 - space_needed + + # Place the subplot axes to give space for the legend + self.fig.subplots_adjust(right=right) + + def add_legend( + self, + *, + label: str | None = None, + use_legend_elements: bool = False, + **kwargs: Any, + ) -> None: + if use_legend_elements: + self.figlegend = _add_legend(**kwargs) + else: + self.figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self._hue_var.to_numpy()), + title=label if label is not None else label_from_attrs(self._hue_var), + loc=kwargs.pop("loc", "center right"), + **kwargs, + ) + self._adjust_fig_for_guide(self.figlegend) + + def add_colorbar(self, **kwargs: Any) -> None: + """Draw a colorbar.""" + kwargs = kwargs.copy() + if self._cmap_extend is not None: + kwargs.setdefault("extend", self._cmap_extend) + # dont pass extend as kwarg if it is in the mappable + if hasattr(self._mappables[-1], "extend"): + kwargs.pop("extend", None) + if "label" not in kwargs: + from xarray import DataArray + + assert isinstance(self.data, DataArray) + kwargs.setdefault("label", label_from_attrs(self.data)) + self.cbar = self.fig.colorbar( + self._mappables[-1], ax=list(self.axs.flat), **kwargs + ) + + def add_quiverkey(self, u: Hashable, v: Hashable, **kwargs: Any) -> None: + kwargs = kwargs.copy() + + magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v]) + units = self.data[u].attrs.get("units", "") + self.quiverkey = self.axs.flat[-1].quiverkey( + self._mappables[-1], + X=0.8, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0 + # https://github.com/matplotlib/matplotlib/issues/18530 + # self._adjust_fig_for_guide(self.quiverkey.text) + + def _get_largest_lims(self) -> dict[str, tuple[float, float]]: + """ + Get largest limits in the facetgrid. + + Returns + ------- + lims_largest : dict[str, tuple[float, float]] + Dictionary with the largest limits along each axis. + + Examples + -------- + >>> ds = xr.tutorial.scatter_example_dataset(seed=42) + >>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w") + >>> round(fg._get_largest_lims()["x"][0], 3) + -0.334 + """ + lims_largest: dict[str, tuple[float, float]] = dict( + x=(np.inf, -np.inf), y=(np.inf, -np.inf), z=(np.inf, -np.inf) + ) + for axis in ("x", "y", "z"): + # Find the plot with the largest xlim values: + lower, upper = lims_largest[axis] + for ax in self.axs.flat: + get_lim: None | Callable[[], tuple[float, float]] = getattr( + ax, f"get_{axis}lim", None + ) + if get_lim: + lower_new, upper_new = get_lim() + lower, upper = (min(lower, lower_new), max(upper, upper_new)) + lims_largest[axis] = (lower, upper) + + return lims_largest + + def _set_lims( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + ) -> None: + """ + Set the same limits for all the subplots in the facetgrid. + + Parameters + ---------- + x : tuple[float, float] or None, optional + x axis limits. + y : tuple[float, float] or None, optional + y axis limits. + z : tuple[float, float] or None, optional + z axis limits. + + Examples + -------- + >>> ds = xr.tutorial.scatter_example_dataset(seed=42) + >>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w") + >>> fg._set_lims(x=(-0.3, 0.3), y=(0, 2), z=(0, 4)) + >>> fg.axs[0, 0].get_xlim(), fg.axs[0, 0].get_ylim() + ((-0.3, 0.3), (0.0, 2.0)) + """ + lims_largest = self._get_largest_lims() + + # Set limits: + for ax in self.axs.flat: + for (axis, data_limit), parameter_limit in zip( + lims_largest.items(), (x, y, z) + ): + set_lim = getattr(ax, f"set_{axis}lim", None) + if set_lim: + set_lim(data_limit if parameter_limit is None else parameter_limit) + + def set_axis_labels(self, *axlabels: Hashable) -> None: + """Set axis labels on the left column and bottom row of the grid.""" + from xarray.core.dataarray import DataArray + + for var, axis in zip(axlabels, ["x", "y", "z"]): + if var is not None: + if isinstance(var, DataArray): + getattr(self, f"set_{axis}labels")(label_from_attrs(var)) + else: + getattr(self, f"set_{axis}labels")(str(var)) + + def _set_labels( + self, axis: str, axes: Iterable, label: str | None = None, **kwargs + ) -> None: + if label is None: + label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")]) + for ax in axes: + getattr(ax, f"set_{axis}label")(label, **kwargs) + + def set_xlabels(self, label: None | str = None, **kwargs: Any) -> None: + """Label the x axis on the bottom row of the grid.""" + self._set_labels("x", self._bottom_axes, label, **kwargs) + + def set_ylabels(self, label: None | str = None, **kwargs: Any) -> None: + """Label the y axis on the left column of the grid.""" + self._set_labels("y", self._left_axes, label, **kwargs) + + def set_zlabels(self, label: None | str = None, **kwargs: Any) -> None: + """Label the z axis.""" + self._set_labels("z", self._left_axes, label, **kwargs) + + def set_titles( + self, + template: str = "{coord} = {value}", + maxchar: int = 30, + size=None, + **kwargs, + ) -> None: + """ + Draw titles either above each facet or on the grid margins. + + Parameters + ---------- + template : str, default: "{coord} = {value}" + Template for plot titles containing {coord} and {value} + maxchar : int, default: 30 + Truncate titles at maxchar + **kwargs : keyword args + additional arguments to matplotlib.text + + Returns + ------- + self: FacetGrid object + + """ + import matplotlib as mpl + + if size is None: + size = mpl.rcParams["axes.labelsize"] + + nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) + + if self._single_group: + for d, ax in zip(self.name_dicts.flat, self.axs.flat): + # Only label the ones with data + if d is not None: + coord, value = list(d.items()).pop() + title = nicetitle(coord, value, maxchar=maxchar) + ax.set_title(title, size=size, **kwargs) + else: + # The row titles on the right edge of the grid + for index, (ax, row_name, handle) in enumerate( + zip(self.axs[:, -1], self.row_names, self.row_labels) + ): + title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar) + if not handle: + self.row_labels[index] = ax.annotate( + title, + xy=(1.02, 0.5), + xycoords="axes fraction", + rotation=270, + ha="left", + va="center", + **kwargs, + ) + else: + handle.set_text(title) + handle.update(kwargs) + + # The column titles on the top row + for index, (ax, col_name, handle) in enumerate( + zip(self.axs[0, :], self.col_names, self.col_labels) + ): + title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar) + if not handle: + self.col_labels[index] = ax.set_title(title, size=size, **kwargs) + else: + handle.set_text(title) + handle.update(kwargs) + + def set_ticks( + self, + max_xticks: int = _NTICKS, + max_yticks: int = _NTICKS, + fontsize: str | int = _FONTSIZE, + ) -> None: + """ + Set and control tick behavior. + + Parameters + ---------- + max_xticks, max_yticks : int, optional + Maximum number of labeled ticks to plot on x, y axes + fontsize : string or int + Font size as used by matplotlib text + + Returns + ------- + self : FacetGrid object + + """ + from matplotlib.ticker import MaxNLocator + + # Both are necessary + x_major_locator = MaxNLocator(nbins=max_xticks) + y_major_locator = MaxNLocator(nbins=max_yticks) + + for ax in self.axs.flat: + ax.xaxis.set_major_locator(x_major_locator) + ax.yaxis.set_major_locator(y_major_locator) + for tick in itertools.chain( + ax.xaxis.get_major_ticks(), ax.yaxis.get_major_ticks() + ): + tick.label1.set_fontsize(fontsize) + + def map( + self: T_FacetGrid, func: Callable, *args: Hashable, **kwargs: Any + ) -> T_FacetGrid: + """ + Apply a plotting function to each facet's subset of the data. + + Parameters + ---------- + func : callable + A plotting function that takes data and keyword arguments. It + must plot to the currently active matplotlib Axes and take a + `color` keyword argument. If faceting on the `hue` dimension, + it must also take a `label` keyword argument. + *args : Hashable + Column names in self.data that identify variables with data to + plot. The data for each variable is passed to `func` in the + order the variables are specified in the call. + **kwargs : keyword arguments + All keyword arguments are passed to the plotting function. + + Returns + ------- + self : FacetGrid object + + """ + import matplotlib.pyplot as plt + + for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): + if namedict is not None: + data = self.data.loc[namedict] + plt.sca(ax) + innerargs = [data[a].to_numpy() for a in args] + maybe_mappable = func(*innerargs, **kwargs) + # TODO: better way to verify that an artist is mappable? + # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 + if maybe_mappable and hasattr(maybe_mappable, "autoscale_None"): + self._mappables.append(maybe_mappable) + + self._finalize_grid(*args[:2]) + + return self + + +def _easy_facetgrid( + data: T_DataArrayOrSet, + plotfunc: Callable, + kind: Literal["line", "dataarray", "dataset", "plot1d"], + x: Hashable | None = None, + y: Hashable | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: float | None = None, + size: float | None = None, + subplot_kws: dict[str, Any] | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + **kwargs: Any, +) -> FacetGrid[T_DataArrayOrSet]: + """ + Convenience method to call xarray.plot.FacetGrid from 2d plotting methods + + kwargs are the arguments to 2d plotting method + """ + if ax is not None: + raise ValueError("Can't use axes when making faceted plots.") + if aspect is None: + aspect = 1 + if size is None: + size = 3 + elif figsize is not None: + raise ValueError("cannot provide both `figsize` and `size` arguments") + if kwargs.get("z") is not None: + # 3d plots doesn't support sharex, sharey, reset to mpl defaults: + sharex = False + sharey = False + + g = FacetGrid( + data=data, + col=col, + row=row, + col_wrap=col_wrap, + sharex=sharex, + sharey=sharey, + figsize=figsize, + aspect=aspect, + size=size, + subplot_kws=subplot_kws, + ) + + if kind == "line": + return g.map_dataarray_line(plotfunc, x, y, **kwargs) + + if kind == "dataarray": + return g.map_dataarray(plotfunc, x, y, **kwargs) + + if kind == "plot1d": + return g.map_plot1d(plotfunc, x, y, **kwargs) + + if kind == "dataset": + return g.map_dataset(plotfunc, x, y, **kwargs) + + raise ValueError( + f"kind must be one of `line`, `dataarray`, `dataset` or `plot1d`, got {kind}" + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/plot/utils.py b/test/fixtures/whole_applications/xarray/xarray/plot/utils.py new file mode 100644 index 0000000..8789bc2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/plot/utils.py @@ -0,0 +1,1847 @@ +from __future__ import annotations + +import itertools +import textwrap +import warnings +from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence +from datetime import date, datetime +from inspect import getfullargspec +from typing import TYPE_CHECKING, Any, Callable, Literal, overload + +import numpy as np +import pandas as pd + +from xarray.core.indexes import PandasMultiIndex +from xarray.core.options import OPTIONS +from xarray.core.utils import is_scalar, module_available +from xarray.namedarray.pycompat import DuckArrayModule + +nc_time_axis_available = module_available("nc_time_axis") + + +try: + import cftime +except ImportError: + cftime = None + + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.colors import Normalize + from matplotlib.ticker import FuncFormatter + from numpy.typing import ArrayLike + + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import AspectOptions, ScaleOptions + + try: + import matplotlib.pyplot as plt + except ImportError: + plt: Any = None # type: ignore + +ROBUST_PERCENTILE = 2.0 + +# copied from seaborn +_MARKERSIZE_RANGE = (18.0, 36.0, 72.0) +_LINEWIDTH_RANGE = (1.5, 1.5, 6.0) + + +def _determine_extend(calc_data, vmin, vmax): + extend_min = calc_data.min() < vmin + extend_max = calc_data.max() > vmax + if extend_min and extend_max: + return "both" + elif extend_min: + return "min" + elif extend_max: + return "max" + else: + return "neither" + + +def _build_discrete_cmap(cmap, levels, extend, filled): + """ + Build a discrete colormap and normalization of the data. + """ + import matplotlib as mpl + + if len(levels) == 1: + levels = [levels[0], levels[0]] + + if not filled: + # non-filled contour plots + extend = "max" + + if extend == "both": + ext_n = 2 + elif extend in ["min", "max"]: + ext_n = 1 + else: + ext_n = 0 + + n_colors = len(levels) + ext_n - 1 + pal = _color_palette(cmap, n_colors) + + new_cmap, cnorm = mpl.colors.from_levels_and_colors(levels, pal, extend=extend) + # copy the old cmap name, for easier testing + new_cmap.name = getattr(cmap, "name", cmap) + + # copy colors to use for bad, under, and over values in case they have been + # set to non-default values + try: + # matplotlib<3.2 only uses bad color for masked values + bad = cmap(np.ma.masked_invalid([np.nan]))[0] + except TypeError: + # cmap was a str or list rather than a color-map object, so there are + # no bad, under or over values to check or copy + pass + else: + under = cmap(-np.inf) + over = cmap(np.inf) + + new_cmap.set_bad(bad) + + # Only update under and over if they were explicitly changed by the user + # (i.e. are different from the lowest or highest values in cmap). Otherwise + # leave unchanged so new_cmap uses its default values (its own lowest and + # highest values). + if under != cmap(0): + new_cmap.set_under(under) + if over != cmap(cmap.N - 1): + new_cmap.set_over(over) + + return new_cmap, cnorm + + +def _color_palette(cmap, n_colors): + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap + + colors_i = np.linspace(0, 1.0, n_colors) + if isinstance(cmap, (list, tuple)): + # we have a list of colors + cmap = ListedColormap(cmap, N=n_colors) + pal = cmap(colors_i) + elif isinstance(cmap, str): + # we have some sort of named palette + try: + # is this a matplotlib cmap? + cmap = plt.get_cmap(cmap) + pal = cmap(colors_i) + except ValueError: + # ValueError happens when mpl doesn't like a colormap, try seaborn + try: + from seaborn import color_palette + + pal = color_palette(cmap, n_colors=n_colors) + except (ValueError, ImportError): + # or maybe we just got a single color as a string + cmap = ListedColormap([cmap], N=n_colors) + pal = cmap(colors_i) + else: + # cmap better be a LinearSegmentedColormap (e.g. viridis) + pal = cmap(colors_i) + + return pal + + +# _determine_cmap_params is adapted from Seaborn: +# https://github.com/mwaskom/seaborn/blob/v0.6/seaborn/matrix.py#L158 +# Used under the terms of Seaborn's license, see licenses/SEABORN_LICENSE. + + +def _determine_cmap_params( + plot_data, + vmin=None, + vmax=None, + cmap=None, + center=None, + robust=False, + extend=None, + levels=None, + filled=True, + norm=None, + _is_facetgrid=False, +): + """ + Use some heuristics to set good defaults for colorbar and range. + + Parameters + ---------- + plot_data : Numpy array + Doesn't handle xarray objects + + Returns + ------- + cmap_params : dict + Use depends on the type of the plotting function + """ + import matplotlib as mpl + + if isinstance(levels, Iterable): + levels = sorted(levels) + + calc_data = np.ravel(plot_data[np.isfinite(plot_data)]) + + # Handle all-NaN input data gracefully + if calc_data.size == 0: + # Arbitrary default for when all values are NaN + calc_data = np.array(0.0) + + # Setting center=False prevents a divergent cmap + possibly_divergent = center is not False + + # Set center to 0 so math below makes sense but remember its state + center_is_none = False + if center is None: + center = 0 + center_is_none = True + + # Setting both vmin and vmax prevents a divergent cmap + if (vmin is not None) and (vmax is not None): + possibly_divergent = False + + # Setting vmin or vmax implies linspaced levels + user_minmax = (vmin is not None) or (vmax is not None) + + # vlim might be computed below + vlim = None + + # save state; needed later + vmin_was_none = vmin is None + vmax_was_none = vmax is None + + if vmin is None: + if robust: + vmin = np.percentile(calc_data, ROBUST_PERCENTILE) + else: + vmin = calc_data.min() + elif possibly_divergent: + vlim = abs(vmin - center) + + if vmax is None: + if robust: + vmax = np.percentile(calc_data, 100 - ROBUST_PERCENTILE) + else: + vmax = calc_data.max() + elif possibly_divergent: + vlim = abs(vmax - center) + + if possibly_divergent: + levels_are_divergent = ( + isinstance(levels, Iterable) and levels[0] * levels[-1] < 0 + ) + # kwargs not specific about divergent or not: infer defaults from data + divergent = ( + ((vmin < 0) and (vmax > 0)) or not center_is_none or levels_are_divergent + ) + else: + divergent = False + + # A divergent map should be symmetric around the center value + if divergent: + if vlim is None: + vlim = max(abs(vmin - center), abs(vmax - center)) + vmin, vmax = -vlim, vlim + + # Now add in the centering value and set the limits + vmin += center + vmax += center + + # now check norm and harmonize with vmin, vmax + if norm is not None: + if norm.vmin is None: + norm.vmin = vmin + else: + if not vmin_was_none and vmin != norm.vmin: + raise ValueError("Cannot supply vmin and a norm with a different vmin.") + vmin = norm.vmin + + if norm.vmax is None: + norm.vmax = vmax + else: + if not vmax_was_none and vmax != norm.vmax: + raise ValueError("Cannot supply vmax and a norm with a different vmax.") + vmax = norm.vmax + + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries + + # Choose default colormaps if not provided + if cmap is None: + if divergent: + cmap = OPTIONS["cmap_divergent"] + else: + cmap = OPTIONS["cmap_sequential"] + + # Handle discrete levels + if levels is not None: + if is_scalar(levels): + if user_minmax: + levels = np.linspace(vmin, vmax, levels) + elif levels == 1: + levels = np.asarray([(vmin + vmax) / 2]) + else: + # N in MaxNLocator refers to bins, not ticks + ticker = mpl.ticker.MaxNLocator(levels - 1) + levels = ticker.tick_values(vmin, vmax) + vmin, vmax = levels[0], levels[-1] + + # GH3734 + if vmin == vmax: + vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax) + + if extend is None: + extend = _determine_extend(calc_data, vmin, vmax) + + if (levels is not None) and (not isinstance(norm, mpl.colors.BoundaryNorm)): + cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) + norm = newnorm if norm is None else norm + + # vmin & vmax needs to be None if norm is passed + # TODO: always return a norm with vmin and vmax + if norm is not None: + vmin = None + vmax = None + + return dict( + vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm + ) + + +def _infer_xy_labels_3d( + darray: DataArray | Dataset, + x: Hashable | None, + y: Hashable | None, + rgb: Hashable | None, +) -> tuple[Hashable, Hashable]: + """ + Determine x and y labels for showing RGB images. + + Attempts to infer which dimension is RGB/RGBA by size and order of dims. + + """ + assert rgb is None or rgb != x + assert rgb is None or rgb != y + # Start by detecting and reporting invalid combinations of arguments + assert darray.ndim == 3 + not_none = [a for a in (x, y, rgb) if a is not None] + if len(set(not_none)) < len(not_none): + raise ValueError( + "Dimension names must be None or unique strings, but imshow was " + f"passed x={x!r}, y={y!r}, and rgb={rgb!r}." + ) + for label in not_none: + if label not in darray.dims: + raise ValueError(f"{label!r} is not a dimension") + + # Then calculate rgb dimension if certain and check validity + could_be_color = [ + label + for label in darray.dims + if darray[label].size in (3, 4) and label not in (x, y) + ] + if rgb is None and not could_be_color: + raise ValueError( + "A 3-dimensional array was passed to imshow(), but there is no " + "dimension that could be color. At least one dimension must be " + "of size 3 (RGB) or 4 (RGBA), and not given as x or y." + ) + if rgb is None and len(could_be_color) == 1: + rgb = could_be_color[0] + if rgb is not None and darray[rgb].size not in (3, 4): + raise ValueError( + f"Cannot interpret dim {rgb!r} of size {darray[rgb].size} as RGB or RGBA." + ) + + # If rgb dimension is still unknown, there must be two or three dimensions + # in could_be_color. We therefore warn, and use a heuristic to break ties. + if rgb is None: + assert len(could_be_color) in (2, 3) + rgb = could_be_color[-1] + warnings.warn( + "Several dimensions of this array could be colors. Xarray " + f"will use the last possible dimension ({rgb!r}) to match " + "matplotlib.pyplot.imshow. You can pass names of x, y, " + "and/or rgb dimensions to override this guess." + ) + assert rgb is not None + + # Finally, we pick out the red slice and delegate to the 2D version: + return _infer_xy_labels(darray.isel({rgb: 0}), x, y) + + +def _infer_xy_labels( + darray: DataArray | Dataset, + x: Hashable | None, + y: Hashable | None, + imshow: bool = False, + rgb: Hashable | None = None, +) -> tuple[Hashable, Hashable]: + """ + Determine x and y labels. For use in _plot2d + + darray must be a 2 dimensional data array, or 3d for imshow only. + """ + if (x is not None) and (x == y): + raise ValueError("x and y cannot be equal.") + + if imshow and darray.ndim == 3: + return _infer_xy_labels_3d(darray, x, y, rgb) + + if x is None and y is None: + if darray.ndim != 2: + raise ValueError("DataArray must be 2d") + y, x = darray.dims + elif x is None: + _assert_valid_xy(darray, y, "y") + x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] + elif y is None: + _assert_valid_xy(darray, x, "x") + y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] + else: + _assert_valid_xy(darray, x, "x") + _assert_valid_xy(darray, y, "y") + + if darray._indexes.get(x, 1) is darray._indexes.get(y, 2): + if isinstance(darray._indexes[x], PandasMultiIndex): + raise ValueError("x and y cannot be levels of the same MultiIndex") + + return x, y + + +# TODO: Can by used to more than x or y, rename? +def _assert_valid_xy( + darray: DataArray | Dataset, xy: Hashable | None, name: str +) -> None: + """ + make sure x and y passed to plotting functions are valid + """ + + # MultiIndex cannot be plotted; no point in allowing them here + multiindex_dims = { + idx.dim + for idx in darray.xindexes.get_unique() + if isinstance(idx, PandasMultiIndex) + } + + valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims + + if (xy is not None) and (xy not in valid_xy): + valid_xy_str = "', '".join(sorted(tuple(str(v) for v in valid_xy))) + raise ValueError( + f"{name} must be one of None, '{valid_xy_str}'. Received '{xy}' instead." + ) + + +def get_axis( + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + **subplot_kws: Any, +) -> Axes: + try: + import matplotlib as mpl + import matplotlib.pyplot as plt + except ImportError: + raise ImportError("matplotlib is required for plot.utils.get_axis") + + if figsize is not None: + if ax is not None: + raise ValueError("cannot provide both `figsize` and `ax` arguments") + if size is not None: + raise ValueError("cannot provide both `figsize` and `size` arguments") + _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) + return ax + + if size is not None: + if ax is not None: + raise ValueError("cannot provide both `size` and `ax` arguments") + if aspect is None or aspect == "auto": + width, height = mpl.rcParams["figure.figsize"] + faspect = width / height + elif aspect == "equal": + faspect = 1 + else: + faspect = aspect + figsize = (size * faspect, size) + _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) + return ax + + if aspect is not None: + raise ValueError("cannot provide `aspect` argument without `size`") + + if subplot_kws and ax is not None: + raise ValueError("cannot use subplot_kws with existing ax") + + if ax is None: + ax = _maybe_gca(**subplot_kws) + + return ax + + +def _maybe_gca(**subplot_kws: Any) -> Axes: + import matplotlib.pyplot as plt + + # can call gcf unconditionally: either it exists or would be created by plt.axes + f = plt.gcf() + + # only call gca if an active axes exists + if f.axes: + # can not pass kwargs to active axes + return plt.gca() + + return plt.axes(**subplot_kws) + + +def _get_units_from_attrs(da: DataArray) -> str: + """Extracts and formats the unit/units from a attributes.""" + pint_array_type = DuckArrayModule("pint").type + units = " [{}]" + if isinstance(da.data, pint_array_type): + return units.format(str(da.data.units)) + if "units" in da.attrs: + return units.format(da.attrs["units"]) + if "unit" in da.attrs: + return units.format(da.attrs["unit"]) + return "" + + +def label_from_attrs(da: DataArray | None, extra: str = "") -> str: + """Makes informative labels if variable metadata (attrs) follows + CF conventions.""" + if da is None: + return "" + + name: str = "{}" + if "long_name" in da.attrs: + name = name.format(da.attrs["long_name"]) + elif "standard_name" in da.attrs: + name = name.format(da.attrs["standard_name"]) + elif da.name is not None: + name = name.format(da.name) + else: + name = "" + + units = _get_units_from_attrs(da) + + # Treat `name` differently if it's a latex sequence + if name.startswith("$") and (name.count("$") % 2 == 0): + return "$\n$".join( + textwrap.wrap(name + extra + units, 60, break_long_words=False) + ) + else: + return "\n".join(textwrap.wrap(name + extra + units, 30)) + + +def _interval_to_mid_points(array: Iterable[pd.Interval]) -> np.ndarray: + """ + Helper function which returns an array + with the Intervals' mid points. + """ + + return np.array([x.mid for x in array]) + + +def _interval_to_bound_points(array: Sequence[pd.Interval]) -> np.ndarray: + """ + Helper function which returns an array + with the Intervals' boundaries. + """ + + array_boundaries = np.array([x.left for x in array]) + array_boundaries = np.concatenate((array_boundaries, np.array([array[-1].right]))) + + return array_boundaries + + +def _interval_to_double_bound_points( + xarray: Iterable[pd.Interval], yarray: Iterable +) -> tuple[np.ndarray, np.ndarray]: + """ + Helper function to deal with a xarray consisting of pd.Intervals. Each + interval is replaced with both boundaries. I.e. the length of xarray + doubles. yarray is modified so it matches the new shape of xarray. + """ + + xarray1 = np.array([x.left for x in xarray]) + xarray2 = np.array([x.right for x in xarray]) + + xarray_out = np.array(list(itertools.chain.from_iterable(zip(xarray1, xarray2)))) + yarray_out = np.array(list(itertools.chain.from_iterable(zip(yarray, yarray)))) + + return xarray_out, yarray_out + + +def _resolve_intervals_1dplot( + xval: np.ndarray, yval: np.ndarray, kwargs: dict +) -> tuple[np.ndarray, np.ndarray, str, str, dict]: + """ + Helper function to replace the values of x and/or y coordinate arrays + containing pd.Interval with their mid-points or - for step plots - double + points which double the length. + """ + x_suffix = "" + y_suffix = "" + + # Is it a step plot? (see matplotlib.Axes.step) + if kwargs.get("drawstyle", "").startswith("steps-"): + remove_drawstyle = False + + # Convert intervals to double points + x_is_interval = _valid_other_type(xval, pd.Interval) + y_is_interval = _valid_other_type(yval, pd.Interval) + if x_is_interval and y_is_interval: + raise TypeError("Can't step plot intervals against intervals.") + elif x_is_interval: + xval, yval = _interval_to_double_bound_points(xval, yval) + remove_drawstyle = True + elif y_is_interval: + yval, xval = _interval_to_double_bound_points(yval, xval) + remove_drawstyle = True + + # Remove steps-* to be sure that matplotlib is not confused + if remove_drawstyle: + del kwargs["drawstyle"] + + # Is it another kind of plot? + else: + # Convert intervals to mid points and adjust labels + if _valid_other_type(xval, pd.Interval): + xval = _interval_to_mid_points(xval) + x_suffix = "_center" + if _valid_other_type(yval, pd.Interval): + yval = _interval_to_mid_points(yval) + y_suffix = "_center" + + # return converted arguments + return xval, yval, x_suffix, y_suffix, kwargs + + +def _resolve_intervals_2dplot(val, func_name): + """ + Helper function to replace the values of a coordinate array containing + pd.Interval with their mid-points or - for pcolormesh - boundaries which + increases length by 1. + """ + label_extra = "" + if _valid_other_type(val, pd.Interval): + if func_name == "pcolormesh": + val = _interval_to_bound_points(val) + else: + val = _interval_to_mid_points(val) + label_extra = "_center" + + return val, label_extra + + +def _valid_other_type( + x: ArrayLike, types: type[object] | tuple[type[object], ...] +) -> bool: + """ + Do all elements of x have a type from types? + """ + return all(isinstance(el, types) for el in np.ravel(x)) + + +def _valid_numpy_subdtype(x, numpy_types): + """ + Is any dtype from numpy_types superior to the dtype of x? + """ + # If any of the types given in numpy_types is understood as numpy.generic, + # all possible x will be considered valid. This is probably unwanted. + for t in numpy_types: + assert not np.issubdtype(np.generic, t) + + return any(np.issubdtype(x.dtype, t) for t in numpy_types) + + +def _ensure_plottable(*args) -> None: + """ + Raise exception if there is anything in args that can't be plotted on an + axis by matplotlib. + """ + numpy_types: tuple[type[object], ...] = ( + np.floating, + np.integer, + np.timedelta64, + np.datetime64, + np.bool_, + np.str_, + ) + other_types: tuple[type[object], ...] = (datetime, date) + cftime_datetime_types: tuple[type[object], ...] = ( + () if cftime is None else (cftime.datetime,) + ) + other_types += cftime_datetime_types + + for x in args: + if not ( + _valid_numpy_subdtype(np.asarray(x), numpy_types) + or _valid_other_type(np.asarray(x), other_types) + ): + raise TypeError( + "Plotting requires coordinates to be numeric, boolean, " + "or dates of type numpy.datetime64, " + "datetime.datetime, cftime.datetime or " + f"pandas.Interval. Received data of type {np.asarray(x).dtype} instead." + ) + if _valid_other_type(np.asarray(x), cftime_datetime_types): + if nc_time_axis_available: + # Register cftime datetypes to matplotlib.units.registry, + # otherwise matplotlib will raise an error: + import nc_time_axis # noqa: F401 + else: + raise ImportError( + "Plotting of arrays of cftime.datetime " + "objects or arrays indexed by " + "cftime.datetime objects requires the " + "optional `nc-time-axis` (v1.2.0 or later) " + "package." + ) + + +def _is_numeric(arr): + numpy_types = [np.floating, np.integer] + return _valid_numpy_subdtype(arr, numpy_types) + + +def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): + cbar_kwargs.setdefault("extend", cmap_params["extend"]) + if cbar_ax is None: + cbar_kwargs.setdefault("ax", ax) + else: + cbar_kwargs.setdefault("cax", cbar_ax) + + # dont pass extend as kwarg if it is in the mappable + if hasattr(primitive, "extend"): + cbar_kwargs.pop("extend") + + fig = ax.get_figure() + cbar = fig.colorbar(primitive, **cbar_kwargs) + + return cbar + + +def _rescale_imshow_rgb(darray, vmin, vmax, robust): + assert robust or vmin is not None or vmax is not None + + # Calculate vmin and vmax automatically for `robust=True` + if robust: + if vmax is None: + vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE) + if vmin is None: + vmin = np.nanpercentile(darray, ROBUST_PERCENTILE) + # If not robust and one bound is None, calculate the default other bound + # and check that an interval between them exists. + elif vmax is None: + vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 + if vmax < vmin: + raise ValueError( + f"vmin={vmin!r} is less than the default vmax ({vmax!r}) - you must supply " + "a vmax > vmin in this case." + ) + elif vmin is None: + vmin = 0 + if vmin > vmax: + raise ValueError( + f"vmax={vmax!r} is less than the default vmin (0) - you must supply " + "a vmin < vmax in this case." + ) + # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float + # to avoid precision loss, integer over/underflow, etc with extreme inputs. + # After scaling, downcast to 32-bit float. This substantially reduces + # memory usage after we hand `darray` off to matplotlib. + darray = ((darray.astype("f8") - vmin) / (vmax - vmin)).astype("f4") + return np.minimum(np.maximum(darray, 0), 1) + + +def _update_axes( + ax: Axes, + xincrease: bool | None, + yincrease: bool | None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, +) -> None: + """ + Update axes with provided parameters + """ + if xincrease is None: + pass + elif xincrease and ax.xaxis_inverted(): + ax.invert_xaxis() + elif not xincrease and not ax.xaxis_inverted(): + ax.invert_xaxis() + + if yincrease is None: + pass + elif yincrease and ax.yaxis_inverted(): + ax.invert_yaxis() + elif not yincrease and not ax.yaxis_inverted(): + ax.invert_yaxis() + + # The default xscale, yscale needs to be None. + # If we set a scale it resets the axes formatters, + # This means that set_xscale('linear') on a datetime axis + # will remove the date labels. So only set the scale when explicitly + # asked to. https://github.com/matplotlib/matplotlib/issues/8740 + if xscale is not None: + ax.set_xscale(xscale) + if yscale is not None: + ax.set_yscale(yscale) + + if xticks is not None: + ax.set_xticks(xticks) + if yticks is not None: + ax.set_yticks(yticks) + + if xlim is not None: + ax.set_xlim(xlim) + if ylim is not None: + ax.set_ylim(ylim) + + +def _is_monotonic(coord, axis=0): + """ + >>> _is_monotonic(np.array([0, 1, 2])) + True + >>> _is_monotonic(np.array([2, 1, 0])) + True + >>> _is_monotonic(np.array([0, 2, 1])) + False + """ + if coord.shape[axis] < 3: + return True + else: + n = coord.shape[axis] + delta_pos = coord.take(np.arange(1, n), axis=axis) >= coord.take( + np.arange(0, n - 1), axis=axis + ) + delta_neg = coord.take(np.arange(1, n), axis=axis) <= coord.take( + np.arange(0, n - 1), axis=axis + ) + return np.all(delta_pos) or np.all(delta_neg) + + +def _infer_interval_breaks(coord, axis=0, scale=None, check_monotonic=False): + """ + >>> _infer_interval_breaks(np.arange(5)) + array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) + >>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1) + array([[-0.5, 0.5, 1.5], + [ 2.5, 3.5, 4.5]]) + >>> _infer_interval_breaks(np.logspace(-2, 2, 5), scale="log") + array([3.16227766e-03, 3.16227766e-02, 3.16227766e-01, 3.16227766e+00, + 3.16227766e+01, 3.16227766e+02]) + """ + coord = np.asarray(coord) + + if check_monotonic and not _is_monotonic(coord, axis=axis): + raise ValueError( + "The input coordinate is not sorted in increasing " + "order along axis %d. This can lead to unexpected " + "results. Consider calling the `sortby` method on " + "the input DataArray. To plot data with categorical " + "axes, consider using the `heatmap` function from " + "the `seaborn` statistical plotting library." % axis + ) + + # If logscale, compute the intervals in the logarithmic space + if scale == "log": + if (coord <= 0).any(): + raise ValueError( + "Found negative or zero value in coordinates. " + + "Coordinates must be positive on logscale plots." + ) + coord = np.log10(coord) + + deltas = 0.5 * np.diff(coord, axis=axis) + if deltas.size == 0: + deltas = np.array(0.0) + first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis) + last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis) + trim_last = tuple( + slice(None, -1) if n == axis else slice(None) for n in range(coord.ndim) + ) + interval_breaks = np.concatenate( + [first, coord[trim_last] + deltas, last], axis=axis + ) + if scale == "log": + # Recovert the intervals into the linear space + return np.power(10, interval_breaks) + return interval_breaks + + +def _process_cmap_cbar_kwargs( + func, + data, + cmap=None, + colors=None, + cbar_kwargs: Iterable[tuple[str, Any]] | Mapping[str, Any] | None = None, + levels=None, + _is_facetgrid=False, + **kwargs, +) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Parameters + ---------- + func : plotting function + data : ndarray, + Data values + + Returns + ------- + cmap_params : dict + cbar_kwargs : dict + """ + if func.__name__ == "surface": + # Leave user to specify cmap settings for surface plots + kwargs["cmap"] = cmap + return { + k: kwargs.get(k, None) + for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] + }, {} + + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) + + if "contour" in func.__name__ and levels is None: + levels = 7 # this is the matplotlib default + + # colors is mutually exclusive with cmap + if cmap and colors: + raise ValueError("Can't specify both cmap and colors.") + + # colors is only valid when levels is supplied or the plot is of type + # contour or contourf + if colors and (("contour" not in func.__name__) and (levels is None)): + raise ValueError("Can only specify colors with contour or levels") + + # we should not be getting a list of colors in cmap anymore + # is there a better way to do this test? + if isinstance(cmap, (list, tuple)): + raise ValueError( + "Specifying a list of colors in cmap is deprecated. " + "Use colors keyword instead." + ) + + cmap_kwargs = { + "plot_data": data, + "levels": levels, + "cmap": colors if colors else cmap, + "filled": func.__name__ != "contour", + } + + cmap_args = getfullargspec(_determine_cmap_params).args + cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs) + if not _is_facetgrid: + cmap_params = _determine_cmap_params(**cmap_kwargs) + else: + cmap_params = { + k: cmap_kwargs[k] + for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] + } + + return cmap_params, cbar_kwargs + + +def _get_nice_quiver_magnitude(u, v): + import matplotlib as mpl + + ticker = mpl.ticker.MaxNLocator(3) + mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) + magnitude = ticker.tick_values(0, mean)[-2] + return magnitude + + +# Copied from matplotlib, tweaked so func can return strings. +# https://github.com/matplotlib/matplotlib/issues/19555 +def legend_elements( + self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs +): + """ + Create legend handles and labels for a PathCollection. + + Each legend handle is a `.Line2D` representing the Path that was drawn, + and each label is a string what each Path represents. + + This is useful for obtaining a legend for a `~.Axes.scatter` plot; + e.g.:: + + scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) + plt.legend(*scatter.legend_elements()) + + creates three legend elements, one for each color with the numerical + values passed to *c* as the labels. + + Also see the :ref:`automatedlegendcreation` example. + + + Parameters + ---------- + prop : {"colors", "sizes"}, default: "colors" + If "colors", the legend handles will show the different colors of + the collection. If "sizes", the legend will show the different + sizes. To set both, use *kwargs* to directly edit the `.Line2D` + properties. + num : int, None, "auto" (default), array-like, or `~.ticker.Locator` + Target number of elements to create. + If None, use all unique elements of the mappable array. If an + integer, target to use *num* elements in the normed range. + If *"auto"*, try to determine which option better suits the nature + of the data. + The number of created elements may slightly deviate from *num* due + to a `~.ticker.Locator` being used to find useful locations. + If a list or array, use exactly those elements for the legend. + Finally, a `~.ticker.Locator` can be provided. + fmt : str, `~matplotlib.ticker.Formatter`, or None (default) + The format or formatter to use for the labels. If a string must be + a valid input for a `~.StrMethodFormatter`. If None (the default), + use a `~.ScalarFormatter`. + func : function, default: ``lambda x: x`` + Function to calculate the labels. Often the size (or color) + argument to `~.Axes.scatter` will have been pre-processed by the + user using a function ``s = f(x)`` to make the markers visible; + e.g. ``size = np.log10(x)``. Providing the inverse of this + function here allows that pre-processing to be inverted, so that + the legend labels have the correct values; e.g. ``func = lambda + x: 10**x``. + **kwargs + Allowed keyword arguments are *color* and *size*. E.g. it may be + useful to set the color of the markers if *prop="sizes"* is used; + similarly to set the size of the markers if *prop="colors"* is + used. Any further parameters are passed onto the `.Line2D` + instance. This may be useful to e.g. specify a different + *markeredgecolor* or *alpha* for the legend handles. + + Returns + ------- + handles : list of `.Line2D` + Visual representation of each element of the legend. + labels : list of str + The string labels for elements of the legend. + """ + import warnings + + import matplotlib as mpl + + mlines = mpl.lines + + handles = [] + labels = [] + + if prop == "colors": + arr = self.get_array() + if arr is None: + warnings.warn( + "Collection without array used. Make sure to " + "specify the values to be colormapped via the " + "`c` argument." + ) + return handles, labels + _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + + def _get_color_and_size(value): + return self.cmap(self.norm(value)), _size + + elif prop == "sizes": + if isinstance(self, mpl.collections.LineCollection): + arr = self.get_linewidths() + else: + arr = self.get_sizes() + _color = kwargs.pop("color", "k") + + def _get_color_and_size(value): + return _color, np.sqrt(value) + + else: + raise ValueError( + "Valid values for `prop` are 'colors' or " + f"'sizes'. You supplied '{prop}' instead." + ) + + # Get the unique values and their labels: + values = np.unique(arr) + label_values = np.asarray(func(values)) + label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) + + # Handle the label format: + if fmt is None and label_values_are_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif fmt is None and not label_values_are_numeric: + fmt = mpl.ticker.StrMethodFormatter("{x}") + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() + + if num == "auto": + num = 9 + if len(values) <= num: + num = None + + if label_values_are_numeric: + label_values_min = label_values.min() + label_values_max = label_values.max() + fmt.axis.set_view_interval(label_values_min, label_values_max) + fmt.axis.set_data_interval(label_values_min, label_values_max) + + if num is not None: + # Labels are numerical but larger than the target + # number of elements, reduce to target using matplotlibs + # ticker classes: + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + + # Get nicely spaced label_values: + label_values = loc.tick_values(label_values_min, label_values_max) + + # Remove extrapolated label_values: + cond = (label_values >= label_values_min) & ( + label_values <= label_values_max + ) + label_values = label_values[cond] + + # Get the corresponding values by creating a linear interpolant + # with small step size: + values_interp = np.linspace(values.min(), values.max(), 256) + label_values_interp = func(values_interp) + ix = np.argsort(label_values_interp) + values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) + elif num is not None and not label_values_are_numeric: + # Labels are not numerical so modifying label_values is not + # possible, instead filter the array with nicely distributed + # indexes: + if type(num) == int: # noqa: E721 + loc = mpl.ticker.LinearLocator(num) + else: + raise ValueError("`num` only supports integers for non-numeric labels.") + + ind = loc.tick_values(0, len(label_values) - 1).astype(int) + label_values = label_values[ind] + values = values[ind] + + # Some formatters requires set_locs: + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + + # Default settings for handles, add or override with kwargs: + kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) + kw.update(kwargs) + + for val, lab in zip(values, label_values): + color, size = _get_color_and_size(val) + + if isinstance(self, mpl.collections.PathCollection): + kw.update(linestyle="", marker=self.get_paths()[0], markersize=size) + elif isinstance(self, mpl.collections.LineCollection): + kw.update(linestyle=self.get_linestyle()[0], linewidth=size) + + h = mlines.Line2D([0], [0], color=color, **kw) + + handles.append(h) + labels.append(fmt(lab)) + + return handles, labels + + +def _legend_add_subtitle(handles, labels, text): + """Add a subtitle to legend handles.""" + import matplotlib.pyplot as plt + + if text and len(handles) > 1: + # Create a blank handle that's not visible, the + # invisibillity will be used to discern which are subtitles + # or not: + blank_handle = plt.Line2D([], [], label=text) + blank_handle.set_visible(False) + + # Subtitles are shown first: + handles = [blank_handle] + handles + labels = [text] + labels + + return handles, labels + + +def _adjust_legend_subtitles(legend): + """Make invisible-handle "subtitles" entries look more like titles.""" + import matplotlib.pyplot as plt + + # Legend title not in rcParams until 3.0 + font_size = plt.rcParams.get("legend.title_fontsize", None) + hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + hpackers = [v for v in hpackers if isinstance(v, plt.matplotlib.offsetbox.HPacker)] + for hpack in hpackers: + areas = hpack.get_children() + if len(areas) < 2: + continue + draw_area, text_area = areas + + handles = draw_area.get_children() + + # Assume that all artists that are not visible are + # subtitles: + if not all(artist.get_visible() for artist in handles): + # Remove the dummy marker which will bring the text + # more to the center: + draw_area.set_width(0) + for text in text_area.get_children(): + if font_size is not None: + # The sutbtitles should have the same font size + # as normal legend titles: + text.set_size(font_size) + + +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): + dvars = set(ds.variables.keys()) + error_msg = f" must be one of ({', '.join(sorted(tuple(str(v) for v in dvars)))})" + + if x not in dvars: + raise ValueError(f"Expected 'x' {error_msg}. Received {x} instead.") + + if y not in dvars: + raise ValueError(f"Expected 'y' {error_msg}. Received {y} instead.") + + if hue is not None and hue not in dvars: + raise ValueError(f"Expected 'hue' {error_msg}. Received {hue} instead.") + + if hue: + hue_is_numeric = _is_numeric(ds[hue].values) + + if hue_style is None: + hue_style = "continuous" if hue_is_numeric else "discrete" + + if not hue_is_numeric and (hue_style == "continuous"): + raise ValueError( + f"Cannot create a colorbar for a non numeric coordinate: {hue}" + ) + + if add_guide is None or add_guide is True: + add_colorbar = True if hue_style == "continuous" else False + add_legend = True if hue_style == "discrete" else False + else: + add_colorbar = False + add_legend = False + else: + if add_guide is True and funcname not in ("quiver", "streamplot"): + raise ValueError("Cannot set add_guide when hue is None.") + add_legend = False + add_colorbar = False + + if (add_guide or add_guide is None) and funcname == "quiver": + add_quiverkey = True + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + else: + add_quiverkey = False + + if (add_guide or add_guide is None) and funcname == "streamplot": + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + + if hue_style is not None and hue_style not in ["discrete", "continuous"]: + raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") + + if hue: + hue_label = label_from_attrs(ds[hue]) + hue = ds[hue] + else: + hue_label = None + hue = None + + return { + "add_colorbar": add_colorbar, + "add_legend": add_legend, + "add_quiverkey": add_quiverkey, + "hue_label": hue_label, + "hue_style": hue_style, + "xlabel": label_from_attrs(ds[x]), + "ylabel": label_from_attrs(ds[y]), + "hue": hue, + } + + +@overload +def _parse_size( + data: None, + norm: tuple[float | None, float | None, bool] | Normalize | None, +) -> None: ... + + +@overload +def _parse_size( + data: DataArray, + norm: tuple[float | None, float | None, bool] | Normalize | None, +) -> pd.Series: ... + + +# copied from seaborn +def _parse_size( + data: DataArray | None, + norm: tuple[float | None, float | None, bool] | Normalize | None, +) -> None | pd.Series: + import matplotlib as mpl + + if data is None: + return None + + flatdata = data.values.flatten() + + if not _is_numeric(flatdata): + levels = np.unique(flatdata) + numbers = np.arange(1, 1 + len(levels))[::-1] + else: + levels = numbers = np.sort(np.unique(flatdata)) + + min_width, default_width, max_width = _MARKERSIZE_RANGE + # width_range = min_width, max_width + + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + assert isinstance(norm, mpl.colors.Normalize) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + + +class _Normalize(Sequence): + """ + Normalize numerical or categorical values to numerical values. + + The class includes helper methods that simplifies transforming to + and from normalized values. + + Parameters + ---------- + data : DataArray + DataArray to normalize. + width : Sequence of three numbers, optional + Normalize the data to these (min, default, max) values. + The default is None. + """ + + _data: DataArray | None + _data_unique: np.ndarray + _data_unique_index: np.ndarray + _data_unique_inverse: np.ndarray + _data_is_numeric: bool + _width: tuple[float, float, float] | None + + __slots__ = ( + "_data", + "_data_unique", + "_data_unique_index", + "_data_unique_inverse", + "_data_is_numeric", + "_width", + ) + + def __init__( + self, + data: DataArray | None, + width: tuple[float, float, float] | None = None, + _is_facetgrid: bool = False, + ) -> None: + self._data = data + self._width = width if not _is_facetgrid else None + + pint_array_type = DuckArrayModule("pint").type + to_unique = ( + data.to_numpy() # type: ignore[union-attr] + if isinstance(data if data is None else data.data, pint_array_type) + else data + ) + data_unique, data_unique_inverse = np.unique(to_unique, return_inverse=True) # type: ignore[call-overload] + self._data_unique = data_unique + self._data_unique_index = np.arange(0, data_unique.size) + self._data_unique_inverse = data_unique_inverse + self._data_is_numeric = False if data is None else _is_numeric(data) + + def __repr__(self) -> str: + with np.printoptions(precision=4, suppress=True, threshold=5): + return ( + f"<_Normalize(data, width={self._width})>\n" + f"{self._data_unique} -> {self._values_unique}" + ) + + def __len__(self) -> int: + return len(self._data_unique) + + def __getitem__(self, key): + return self._data_unique[key] + + @property + def data(self) -> DataArray | None: + return self._data + + @property + def data_is_numeric(self) -> bool: + """ + Check if data is numeric. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).data_is_numeric + False + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a).data_is_numeric + True + + >>> # TODO: Datetime should be numeric right? + >>> a = xr.DataArray(pd.date_range("2000-1-1", periods=4)) + >>> _Normalize(a).data_is_numeric + False + + # TODO: Timedelta should be numeric right? + >>> a = xr.DataArray(pd.timedelta_range("-1D", periods=4, freq="D")) + >>> _Normalize(a).data_is_numeric + True + """ + return self._data_is_numeric + + @overload + def _calc_widths(self, y: np.ndarray) -> np.ndarray: ... + + @overload + def _calc_widths(self, y: DataArray) -> DataArray: ... + + def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: + """ + Normalize the values so they're in between self._width. + """ + if self._width is None: + return y + + xmin, xdefault, xmax = self._width + + diff_maxy_miny = np.max(y) - np.min(y) + if diff_maxy_miny == 0: + # Use default with if y is constant: + widths = xdefault + 0 * y + else: + # Normalize in between xmin and xmax: + k = (y - np.min(y)) / diff_maxy_miny + widths = xmin + k * (xmax - xmin) + return widths + + @overload + def _indexes_centered(self, x: np.ndarray) -> np.ndarray: ... + + @overload + def _indexes_centered(self, x: DataArray) -> DataArray: ... + + def _indexes_centered(self, x: np.ndarray | DataArray) -> np.ndarray | DataArray: + """ + Offset indexes to make sure being in the center of self.levels. + ["a", "b", "c"] -> [1, 3, 5] + """ + return x * 2 + 1 + + @property + def values(self) -> DataArray | None: + """ + Return a normalized number array for the unique levels. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).values + Size: 40B + array([3, 1, 1, 3, 5]) + Dimensions without coordinates: dim_0 + + >>> _Normalize(a, width=(18, 36, 72)).values + Size: 40B + array([45., 18., 18., 45., 72.]) + Dimensions without coordinates: dim_0 + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a).values + Size: 48B + array([0.5, 0. , 0. , 0.5, 2. , 3. ]) + Dimensions without coordinates: dim_0 + + >>> _Normalize(a, width=(18, 36, 72)).values + Size: 48B + array([27., 18., 18., 27., 54., 72.]) + Dimensions without coordinates: dim_0 + + >>> _Normalize(a * 0, width=(18, 36, 72)).values + Size: 48B + array([36., 36., 36., 36., 36., 36.]) + Dimensions without coordinates: dim_0 + + """ + if self.data is None: + return None + + val: DataArray + if self.data_is_numeric: + val = self.data + else: + arr = self._indexes_centered(self._data_unique_inverse) + val = self.data.copy(data=arr.reshape(self.data.shape)) + + return self._calc_widths(val) + + @property + def _values_unique(self) -> np.ndarray | None: + """ + Return unique values. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a)._values_unique + array([1, 3, 5]) + + >>> _Normalize(a, width=(18, 36, 72))._values_unique + array([18., 45., 72.]) + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a)._values_unique + array([0. , 0.5, 2. , 3. ]) + + >>> _Normalize(a, width=(18, 36, 72))._values_unique + array([18., 27., 54., 72.]) + """ + if self.data is None: + return None + + val: np.ndarray + if self.data_is_numeric: + val = self._data_unique + else: + val = self._indexes_centered(self._data_unique_index) + + return self._calc_widths(val) + + @property + def ticks(self) -> np.ndarray | None: + """ + Return ticks for plt.colorbar if the data is not numeric. + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).ticks + array([1, 3, 5]) + """ + val: None | np.ndarray + if self.data_is_numeric: + val = None + else: + val = self._indexes_centered(self._data_unique_index) + + return val + + @property + def levels(self) -> np.ndarray: + """ + Return discrete levels that will evenly bound self.values. + ["a", "b", "c"] -> [0, 2, 4, 6] + + Examples + -------- + >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) + >>> _Normalize(a).levels + array([0, 2, 4, 6]) + """ + return ( + np.append(self._data_unique_index, np.max(self._data_unique_index) + 1) * 2 + ) + + @property + def _lookup(self) -> pd.Series: + if self._values_unique is None: + raise ValueError("self.data can't be None.") + + return pd.Series(dict(zip(self._values_unique, self._data_unique))) + + def _lookup_arr(self, x) -> np.ndarray: + # Use reindex to be less sensitive to float errors. reindex only + # works with sorted index. + # Return as numpy array since legend_elements + # seems to require that: + return self._lookup.sort_index().reindex(x, method="nearest").to_numpy() + + @property + def format(self) -> FuncFormatter: + """ + Return a FuncFormatter that maps self.values elements back to + the original value as a string. Useful with plt.colorbar. + + Examples + -------- + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> aa = _Normalize(a, width=(0, 0.5, 1)) + >>> aa._lookup + 0.000000 0.0 + 0.166667 0.5 + 0.666667 2.0 + 1.000000 3.0 + dtype: float64 + >>> aa.format(1) + '3.0' + """ + import matplotlib.pyplot as plt + + def _func(x: Any, pos: None | Any = None): + return f"{self._lookup_arr([x])[0]}" + + return plt.FuncFormatter(_func) + + @property + def func(self) -> Callable[[Any, None | Any], Any]: + """ + Return a lambda function that maps self.values elements back to + the original value as a numpy array. Useful with ax.legend_elements. + + Examples + -------- + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> aa = _Normalize(a, width=(0, 0.5, 1)) + >>> aa._lookup + 0.000000 0.0 + 0.166667 0.5 + 0.666667 2.0 + 1.000000 3.0 + dtype: float64 + >>> aa.func([0.16, 1]) + array([0.5, 3. ]) + """ + + def _func(x: Any, pos: None | Any = None): + return self._lookup_arr(x) + + return _func + + +def _determine_guide( + hueplt_norm: _Normalize, + sizeplt_norm: _Normalize, + add_colorbar: None | bool = None, + add_legend: None | bool = None, + plotfunc_name: str | None = None, +) -> tuple[bool, bool]: + if plotfunc_name == "hist": + return False, False + + if (add_colorbar) and hueplt_norm.data is None: + raise KeyError("Cannot create a colorbar when hue is None.") + if add_colorbar is None: + if hueplt_norm.data is not None: + add_colorbar = True + else: + add_colorbar = False + + if add_legend and hueplt_norm.data is None and sizeplt_norm.data is None: + raise KeyError("Cannot create a legend when hue and markersize is None.") + if add_legend is None: + if ( + not add_colorbar + and (hueplt_norm.data is not None and hueplt_norm.data_is_numeric is False) + or sizeplt_norm.data is not None + ): + add_legend = True + else: + add_legend = False + + return add_colorbar, add_legend + + +def _add_legend( + hueplt_norm: _Normalize, + sizeplt_norm: _Normalize, + primitive, + legend_ax, + plotfunc: str, +): + primitive = primitive if isinstance(primitive, list) else [primitive] + + handles, labels = [], [] + for huesizeplt, prop in [ + (hueplt_norm, "colors"), + (sizeplt_norm, "sizes"), + ]: + if huesizeplt.data is not None: + # Get legend handles and labels that displays the + # values correctly. Order might be different because + # legend_elements uses np.unique instead of pd.unique, + # FacetGrid.add_legend might have troubles with this: + hdl, lbl = [], [] + for p in primitive: + hdl_, lbl_ = legend_elements(p, prop, num="auto", func=huesizeplt.func) + hdl += hdl_ + lbl += lbl_ + + # Only save unique values: + u, ind = np.unique(lbl, return_index=True) + ind = np.argsort(ind) + lbl = u[ind].tolist() + hdl = np.array(hdl)[ind].tolist() + + # Add a subtitle: + hdl, lbl = _legend_add_subtitle(hdl, lbl, label_from_attrs(huesizeplt.data)) + handles += hdl + labels += lbl + legend = legend_ax.legend(handles, labels, framealpha=0.5) + _adjust_legend_subtitles(legend) + + return legend + + +def _guess_coords_to_plot( + darray: DataArray, + coords_to_plot: MutableMapping[str, Hashable | None], + kwargs: dict, + default_guess: tuple[str, ...] = ("x",), + # TODO: Can this be normalized, plt.cbook.normalize_kwargs? + ignore_guess_kwargs: tuple[tuple[str, ...], ...] = ((),), +) -> MutableMapping[str, Hashable]: + """ + Guess what coords to plot if some of the values in coords_to_plot are None which + happens when the user has not defined all available ways of visualizing + the data. + + Parameters + ---------- + darray : DataArray + The DataArray to check for available coords. + coords_to_plot : MutableMapping[str, Hashable] + Coords defined by the user to plot. + kwargs : dict + Extra kwargs that will be sent to matplotlib. + default_guess : Iterable[str], optional + Default values and order to retrieve dims if values in dims_plot is + missing, default: ("x", "hue", "size"). + ignore_guess_kwargs : tuple[tuple[str, ...], ...] + Matplotlib arguments to ignore. + + Examples + -------- + >>> ds = xr.tutorial.scatter_example_dataset(seed=42) + >>> # Only guess x by default: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": None}, + ... kwargs={}, + ... ) + {'x': 'x', 'z': None, 'hue': None, 'size': None} + + >>> # Guess all plot dims with other default values: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": None}, + ... kwargs={}, + ... default_guess=("x", "hue", "size"), + ... ignore_guess_kwargs=((), ("c", "color"), ("s",)), + ... ) + {'x': 'x', 'z': None, 'hue': 'y', 'size': 'z'} + + >>> # Don't guess ´size´, since the matplotlib kwarg ´s´ has been defined: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": None}, + ... kwargs={"s": 5}, + ... default_guess=("x", "hue", "size"), + ... ignore_guess_kwargs=((), ("c", "color"), ("s",)), + ... ) + {'x': 'x', 'z': None, 'hue': 'y', 'size': None} + + >>> # Prioritize ´size´ over ´s´: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": "x"}, + ... kwargs={"s": 5}, + ... default_guess=("x", "hue", "size"), + ... ignore_guess_kwargs=((), ("c", "color"), ("s",)), + ... ) + {'x': 'y', 'z': None, 'hue': 'z', 'size': 'x'} + """ + coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None} + available_coords = tuple( + k for k in darray.coords.keys() if k not in coords_to_plot_exist.values() + ) + + # If dims_plot[k] isn't defined then fill with one of the available dims, unless + # one of related mpl kwargs has been used. This should have similar behaviour as + # * plt.plot(x, y) -> Multiple lines with different colors if y is 2d. + # * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d. + for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs): + if coords_to_plot.get(k, None) is None and all( + kwargs.get(ign_kw, None) is None for ign_kw in ign_kws + ): + coords_to_plot[k] = dim + + for k, dim in coords_to_plot.items(): + _assert_valid_xy(darray, dim, k) + + return coords_to_plot + + +def _set_concise_date(ax: Axes, axis: Literal["x", "y", "z"] = "x") -> None: + """ + Use ConciseDateFormatter which is meant to improve the + strings chosen for the ticklabels, and to minimize the + strings used in those tick labels as much as possible. + + https://matplotlib.org/stable/gallery/ticks/date_concise_formatter.html + + Parameters + ---------- + ax : Axes + Figure axes. + axis : Literal["x", "y", "z"], optional + Which axis to make concise. The default is "x". + """ + import matplotlib.dates as mdates + + locator = mdates.AutoDateLocator() + formatter = mdates.ConciseDateFormatter(locator) + _axis = getattr(ax, f"{axis}axis") + _axis.set_major_locator(locator) + _axis.set_major_formatter(formatter) diff --git a/test/fixtures/whole_applications/xarray/xarray/py.typed b/test/fixtures/whole_applications/xarray/xarray/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/static/__init__.py b/test/fixtures/whole_applications/xarray/xarray/static/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/static/css/__init__.py b/test/fixtures/whole_applications/xarray/xarray/static/css/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/static/css/style.css b/test/fixtures/whole_applications/xarray/xarray/static/css/style.css new file mode 100644 index 0000000..e0a5131 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/static/css/style.css @@ -0,0 +1,348 @@ +/* CSS stylesheet for displaying xarray objects in jupyterlab. + * + */ + +:root { + --xr-font-color0: var(--jp-content-font-color0, rgba(0, 0, 0, 1)); + --xr-font-color2: var(--jp-content-font-color2, rgba(0, 0, 0, 0.54)); + --xr-font-color3: var(--jp-content-font-color3, rgba(0, 0, 0, 0.38)); + --xr-border-color: var(--jp-border-color2, #e0e0e0); + --xr-disabled-color: var(--jp-layout-color3, #bdbdbd); + --xr-background-color: var(--jp-layout-color0, white); + --xr-background-color-row-even: var(--jp-layout-color1, white); + --xr-background-color-row-odd: var(--jp-layout-color2, #eeeeee); +} + +html[theme=dark], +body[data-theme=dark], +body.vscode-dark { + --xr-font-color0: rgba(255, 255, 255, 1); + --xr-font-color2: rgba(255, 255, 255, 0.54); + --xr-font-color3: rgba(255, 255, 255, 0.38); + --xr-border-color: #1F1F1F; + --xr-disabled-color: #515151; + --xr-background-color: #111111; + --xr-background-color-row-even: #111111; + --xr-background-color-row-odd: #313131; +} + +.xr-wrap { + display: block !important; + min-width: 300px; + max-width: 700px; +} + +.xr-text-repr-fallback { + /* fallback to plain text repr when CSS is not injected (untrusted notebook) */ + display: none; +} + +.xr-header { + padding-top: 6px; + padding-bottom: 6px; + margin-bottom: 4px; + border-bottom: solid 1px var(--xr-border-color); +} + +.xr-header > div, +.xr-header > ul { + display: inline; + margin-top: 0; + margin-bottom: 0; +} + +.xr-obj-type, +.xr-array-name { + margin-left: 2px; + margin-right: 10px; +} + +.xr-obj-type { + color: var(--xr-font-color2); +} + +.xr-sections { + padding-left: 0 !important; + display: grid; + grid-template-columns: 150px auto auto 1fr 20px 20px; +} + +.xr-section-item { + display: contents; +} + +.xr-section-item input { + display: none; +} + +.xr-section-item input + label { + color: var(--xr-disabled-color); +} + +.xr-section-item input:enabled + label { + cursor: pointer; + color: var(--xr-font-color2); +} + +.xr-section-item input:enabled + label:hover { + color: var(--xr-font-color0); +} + +.xr-section-summary { + grid-column: 1; + color: var(--xr-font-color2); + font-weight: 500; +} + +.xr-section-summary > span { + display: inline-block; + padding-left: 0.5em; +} + +.xr-section-summary-in:disabled + label { + color: var(--xr-font-color2); +} + +.xr-section-summary-in + label:before { + display: inline-block; + content: '►'; + font-size: 11px; + width: 15px; + text-align: center; +} + +.xr-section-summary-in:disabled + label:before { + color: var(--xr-disabled-color); +} + +.xr-section-summary-in:checked + label:before { + content: '▼'; +} + +.xr-section-summary-in:checked + label > span { + display: none; +} + +.xr-section-summary, +.xr-section-inline-details { + padding-top: 4px; + padding-bottom: 4px; +} + +.xr-section-inline-details { + grid-column: 2 / -1; +} + +.xr-section-details { + display: none; + grid-column: 1 / -1; + margin-bottom: 5px; +} + +.xr-section-summary-in:checked ~ .xr-section-details { + display: contents; +} + +.xr-array-wrap { + grid-column: 1 / -1; + display: grid; + grid-template-columns: 20px auto; +} + +.xr-array-wrap > label { + grid-column: 1; + vertical-align: top; +} + +.xr-preview { + color: var(--xr-font-color3); +} + +.xr-array-preview, +.xr-array-data { + padding: 0 5px !important; + grid-column: 2; +} + +.xr-array-data, +.xr-array-in:checked ~ .xr-array-preview { + display: none; +} + +.xr-array-in:checked ~ .xr-array-data, +.xr-array-preview { + display: inline-block; +} + +.xr-dim-list { + display: inline-block !important; + list-style: none; + padding: 0 !important; + margin: 0; +} + +.xr-dim-list li { + display: inline-block; + padding: 0; + margin: 0; +} + +.xr-dim-list:before { + content: '('; +} + +.xr-dim-list:after { + content: ')'; +} + +.xr-dim-list li:not(:last-child):after { + content: ','; + padding-right: 5px; +} + +.xr-has-index { + font-weight: bold; +} + +.xr-var-list, +.xr-var-item { + display: contents; +} + +.xr-var-item > div, +.xr-var-item label, +.xr-var-item > .xr-var-name span { + background-color: var(--xr-background-color-row-even); + margin-bottom: 0; +} + +.xr-var-item > .xr-var-name:hover span { + padding-right: 5px; +} + +.xr-var-list > li:nth-child(odd) > div, +.xr-var-list > li:nth-child(odd) > label, +.xr-var-list > li:nth-child(odd) > .xr-var-name span { + background-color: var(--xr-background-color-row-odd); +} + +.xr-var-name { + grid-column: 1; +} + +.xr-var-dims { + grid-column: 2; +} + +.xr-var-dtype { + grid-column: 3; + text-align: right; + color: var(--xr-font-color2); +} + +.xr-var-preview { + grid-column: 4; +} + +.xr-index-preview { + grid-column: 2 / 5; + color: var(--xr-font-color2); +} + +.xr-var-name, +.xr-var-dims, +.xr-var-dtype, +.xr-preview, +.xr-attrs dt { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + padding-right: 10px; +} + +.xr-var-name:hover, +.xr-var-dims:hover, +.xr-var-dtype:hover, +.xr-attrs dt:hover { + overflow: visible; + width: auto; + z-index: 1; +} + +.xr-var-attrs, +.xr-var-data, +.xr-index-data { + display: none; + background-color: var(--xr-background-color) !important; + padding-bottom: 5px !important; +} + +.xr-var-attrs-in:checked ~ .xr-var-attrs, +.xr-var-data-in:checked ~ .xr-var-data, +.xr-index-data-in:checked ~ .xr-index-data { + display: block; +} + +.xr-var-data > table { + float: right; +} + +.xr-var-name span, +.xr-var-data, +.xr-index-name div, +.xr-index-data, +.xr-attrs { + padding-left: 25px !important; +} + +.xr-attrs, +.xr-var-attrs, +.xr-var-data, +.xr-index-data { + grid-column: 1 / -1; +} + +dl.xr-attrs { + padding: 0; + margin: 0; + display: grid; + grid-template-columns: 125px auto; +} + +.xr-attrs dt, +.xr-attrs dd { + padding: 0; + margin: 0; + float: left; + padding-right: 10px; + width: auto; +} + +.xr-attrs dt { + font-weight: normal; + grid-column: 1; +} + +.xr-attrs dt:hover span { + display: inline-block; + background: var(--xr-background-color); + padding-right: 10px; +} + +.xr-attrs dd { + grid-column: 2; + white-space: pre-wrap; + word-break: break-all; +} + +.xr-icon-database, +.xr-icon-file-text2, +.xr-no-icon { + display: inline-block; + vertical-align: middle; + width: 1em; + height: 1.5em !important; + stroke-width: 0; + stroke: currentColor; + fill: currentColor; +} diff --git a/test/fixtures/whole_applications/xarray/xarray/static/html/__init__.py b/test/fixtures/whole_applications/xarray/xarray/static/html/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/static/html/icons-svg-inline.html b/test/fixtures/whole_applications/xarray/xarray/static/html/icons-svg-inline.html new file mode 100644 index 0000000..b0e837a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/static/html/icons-svg-inline.html @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/test/fixtures/whole_applications/xarray/xarray/testing/__init__.py b/test/fixtures/whole_applications/xarray/xarray/testing/__init__.py new file mode 100644 index 0000000..316b0ea --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/testing/__init__.py @@ -0,0 +1,24 @@ +# TODO: Add assert_isomorphic when making DataTree API public +from xarray.testing.assertions import ( # noqa: F401 + _assert_dataarray_invariants, + _assert_dataset_invariants, + _assert_indexes_invariants_checks, + _assert_internal_invariants, + _assert_variable_invariants, + _data_allclose_or_equiv, + assert_allclose, + assert_chunks_equal, + assert_duckarray_allclose, + assert_duckarray_equal, + assert_equal, + assert_identical, +) + +__all__ = [ + "assert_allclose", + "assert_chunks_equal", + "assert_duckarray_equal", + "assert_duckarray_allclose", + "assert_equal", + "assert_identical", +] diff --git a/test/fixtures/whole_applications/xarray/xarray/testing/assertions.py b/test/fixtures/whole_applications/xarray/xarray/testing/assertions.py new file mode 100644 index 0000000..6988586 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/testing/assertions.py @@ -0,0 +1,524 @@ +"""Testing functions exposed to the user API""" + +import functools +import warnings +from collections.abc import Hashable +from typing import Union, overload + +import numpy as np +import pandas as pd + +from xarray.core import duck_array_ops, formatting, utils +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree +from xarray.core.formatting import diff_datatree_repr +from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes +from xarray.core.variable import IndexVariable, Variable + + +def ensure_warnings(func): + # sometimes tests elevate warnings to errors + # -> make sure that does not happen in the assert_* functions + @functools.wraps(func) + def wrapper(*args, **kwargs): + __tracebackhide__ = True + + with warnings.catch_warnings(): + # only remove filters that would "error" + warnings.filters = [f for f in warnings.filters if f[0] != "error"] + + return func(*args, **kwargs) + + return wrapper + + +def _decode_string_data(data): + if data.dtype.kind == "S": + return np.core.defchararray.decode(data, "utf-8", "replace") + return data + + +def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=True): + if any(arr.dtype.kind == "S" for arr in [arr1, arr2]) and decode_bytes: + arr1 = _decode_string_data(arr1) + arr2 = _decode_string_data(arr2) + exact_dtypes = ["M", "m", "O", "S", "U"] + if any(arr.dtype.kind in exact_dtypes for arr in [arr1, arr2]): + return duck_array_ops.array_equiv(arr1, arr2) + else: + return duck_array_ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) + + +@ensure_warnings +def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): + """ + Two DataTrees are considered isomorphic if every node has the same number of children. + + Nothing about the data or attrs in each node is checked. + + Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, + such as tree1 + tree2. + + By default this function does not check any part of the tree above the given node. + Therefore this function can be used as default to check that two subtrees are isomorphic. + + Parameters + ---------- + a : DataTree + The first object to compare. + b : DataTree + The second object to compare. + from_root : bool, optional, default is False + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + See Also + -------- + DataTree.isomorphic + assert_equal + assert_identical + """ + __tracebackhide__ = True + assert isinstance(a, type(b)) + + if isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.isomorphic(b, from_root=from_root), diff_datatree_repr( + a, b, "isomorphic" + ) + else: + raise TypeError(f"{type(a)} not of type DataTree") + + +def maybe_transpose_dims(a, b, check_dim_order: bool): + """Helper for assert_equal/allclose/identical""" + __tracebackhide__ = True + if not isinstance(a, (Variable, DataArray, Dataset)): + return b + if not check_dim_order and set(a.dims) == set(b.dims): + # Ensure transpose won't fail if a dimension is missing + # If this is the case, the difference will be caught by the caller + return b.transpose(*a.dims) + return b + + +@overload +def assert_equal(a, b): ... + + +@overload +def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... + + +@ensure_warnings +def assert_equal(a, b, from_root=True, check_dim_order: bool = True): + """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray + objects. + + Raises an AssertionError if two objects are not equal. This will match + data values, dimensions and coordinates, but not names or attributes + (except for Dataset objects for which the variable names must match). + Arrays with NaN in the same location are considered equal. + + For DataTree objects, assert_equal is mapped over all Datasets on each node, + with the DataTrees being equal if both are isomorphic and the corresponding + Datasets at each node are themselves equal. + + Parameters + ---------- + a : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates + or xarray.core.datatree.DataTree. The first object to compare. + b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates + or xarray.core.datatree.DataTree. The second object to compare. + from_root : bool, optional, default is True + Only used when comparing DataTree objects. Indicates whether or not to + first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. + + See Also + -------- + assert_identical, assert_allclose, Dataset.equals, DataArray.equals + numpy.testing.assert_array_equal + """ + __tracebackhide__ = True + assert ( + type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + ) + b = maybe_transpose_dims(a, b, check_dim_order) + if isinstance(a, (Variable, DataArray)): + assert a.equals(b), formatting.diff_array_repr(a, b, "equals") + elif isinstance(a, Dataset): + assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") + elif isinstance(a, Coordinates): + assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") + elif isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals") + else: + raise TypeError(f"{type(a)} not supported by assertion comparison") + + +@overload +def assert_identical(a, b): ... + + +@overload +def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ... + + +@ensure_warnings +def assert_identical(a, b, from_root=True): + """Like :py:func:`xarray.testing.assert_equal`, but also matches the + objects' names and attributes. + + Raises an AssertionError if two objects are not identical. + + For DataTree objects, assert_identical is mapped over all Datasets on each + node, with the DataTrees being identical if both are isomorphic and the + corresponding Datasets at each node are themselves identical. + + Parameters + ---------- + a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates + The first object to compare. + b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates + The second object to compare. + from_root : bool, optional, default is True + Only used when comparing DataTree objects. Indicates whether or not to + first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. + + See Also + -------- + assert_equal, assert_allclose, Dataset.equals, DataArray.equals + """ + __tracebackhide__ = True + assert ( + type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + ) + if isinstance(a, Variable): + assert a.identical(b), formatting.diff_array_repr(a, b, "identical") + elif isinstance(a, DataArray): + assert a.name == b.name + assert a.identical(b), formatting.diff_array_repr(a, b, "identical") + elif isinstance(a, (Dataset, Variable)): + assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") + elif isinstance(a, Coordinates): + assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") + elif isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.identical(b, from_root=from_root), diff_datatree_repr( + a, b, "identical" + ) + else: + raise TypeError(f"{type(a)} not supported by assertion comparison") + + +@ensure_warnings +def assert_allclose( + a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True +): + """Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects. + + Raises an AssertionError if two objects are not equal up to desired + tolerance. + + Parameters + ---------- + a : xarray.Dataset, xarray.DataArray or xarray.Variable + The first object to compare. + b : xarray.Dataset, xarray.DataArray or xarray.Variable + The second object to compare. + rtol : float, optional + Relative tolerance. + atol : float, optional + Absolute tolerance. + decode_bytes : bool, optional + Whether byte dtypes should be decoded to strings as UTF-8 or not. + This is useful for testing serialization methods on Python 3 that + return saved strings as bytes. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. + + See Also + -------- + assert_identical, assert_equal, numpy.testing.assert_allclose + """ + __tracebackhide__ = True + assert type(a) == type(b) + b = maybe_transpose_dims(a, b, check_dim_order) + + equiv = functools.partial( + _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes + ) + equiv.__name__ = "allclose" # type: ignore[attr-defined] + + def compat_variable(a, b): + a = getattr(a, "variable", a) + b = getattr(b, "variable", b) + return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) + + if isinstance(a, Variable): + allclose = compat_variable(a, b) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) + elif isinstance(a, DataArray): + allclose = utils.dict_equiv( + a.coords, b.coords, compat=compat_variable + ) and compat_variable(a.variable, b.variable) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) + elif isinstance(a, Dataset): + allclose = a._coord_names == b._coord_names and utils.dict_equiv( + a.variables, b.variables, compat=compat_variable + ) + assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv) + else: + raise TypeError(f"{type(a)} not supported by assertion comparison") + + +def _format_message(x, y, err_msg, verbose): + diff = x - y + abs_diff = max(abs(diff)) + rel_diff = "not implemented" + + n_diff = np.count_nonzero(diff) + n_total = diff.size + + fraction = f"{n_diff} / {n_total}" + percentage = float(n_diff / n_total * 100) + + parts = [ + "Arrays are not equal", + err_msg, + f"Mismatched elements: {fraction} ({percentage:.0f}%)", + f"Max absolute difference: {abs_diff}", + f"Max relative difference: {rel_diff}", + ] + if verbose: + parts += [ + f" x: {x!r}", + f" y: {y!r}", + ] + + return "\n".join(parts) + + +@ensure_warnings +def assert_duckarray_allclose( + actual, desired, rtol=1e-07, atol=0, err_msg="", verbose=True +): + """Like `np.testing.assert_allclose`, but for duckarrays.""" + __tracebackhide__ = True + + allclose = duck_array_ops.allclose_or_equiv(actual, desired, rtol=rtol, atol=atol) + assert allclose, _format_message(actual, desired, err_msg=err_msg, verbose=verbose) + + +@ensure_warnings +def assert_duckarray_equal(x, y, err_msg="", verbose=True): + """Like `np.testing.assert_array_equal`, but for duckarrays""" + __tracebackhide__ = True + + if not utils.is_duck_array(x) and not utils.is_scalar(x): + x = np.asarray(x) + + if not utils.is_duck_array(y) and not utils.is_scalar(y): + y = np.asarray(y) + + if (utils.is_duck_array(x) and utils.is_scalar(y)) or ( + utils.is_scalar(x) and utils.is_duck_array(y) + ): + equiv = (x == y).all() + else: + equiv = duck_array_ops.array_equiv(x, y) + assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose) + + +def assert_chunks_equal(a, b): + """ + Assert that chunksizes along chunked dimensions are equal. + + Parameters + ---------- + a : xarray.Dataset or xarray.DataArray + The first object to compare. + b : xarray.Dataset or xarray.DataArray + The second object to compare. + """ + + if isinstance(a, DataArray) != isinstance(b, DataArray): + raise TypeError("a and b have mismatched types") + + left = a.unify_chunks() + right = b.unify_chunks() + assert left.chunks == right.chunks + + +def _assert_indexes_invariants_checks( + indexes, possible_coord_variables, dims, check_default=True +): + assert isinstance(indexes, dict), indexes + assert all(isinstance(v, Index) for v in indexes.values()), { + k: type(v) for k, v in indexes.items() + } + + index_vars = { + k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable) + } + assert indexes.keys() <= index_vars, (set(indexes), index_vars) + + # check pandas index wrappers vs. coordinate data adapters + for k, index in indexes.items(): + if isinstance(index, PandasIndex): + pd_index = index.index + var = possible_coord_variables[k] + assert (index.dim,) == var.dims, (pd_index, var) + if k == index.dim: + # skip multi-index levels here (checked below) + assert index.coord_dtype == var.dtype, (index.coord_dtype, var.dtype) + assert isinstance(var._data.array, pd.Index), var._data.array + # TODO: check identity instead of equality? + assert pd_index.equals(var._data.array), (pd_index, var) + if isinstance(index, PandasMultiIndex): + pd_index = index.index + for name in index.index.names: + assert name in possible_coord_variables, (pd_index, index_vars) + var = possible_coord_variables[name] + assert (index.dim,) == var.dims, (pd_index, var) + assert index.level_coords_dtype[name] == var.dtype, ( + index.level_coords_dtype[name], + var.dtype, + ) + assert isinstance(var._data.array, pd.MultiIndex), var._data.array + assert pd_index.equals(var._data.array), (pd_index, var) + # check all all levels are in `indexes` + assert name in indexes, (name, set(indexes)) + # index identity is used to find unique indexes in `indexes` + assert index is indexes[name], (pd_index, indexes[name].index) + + if check_default: + defaults = default_indexes(possible_coord_variables, dims) + assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) + assert all(v.equals(defaults[k]) for k, v in indexes.items()), ( + indexes, + defaults, + ) + + +def _assert_variable_invariants(var: Variable, name: Hashable = None): + if name is None: + name_or_empty: tuple = () + else: + name_or_empty = (name,) + assert isinstance(var._dims, tuple), name_or_empty + (var._dims,) + assert len(var._dims) == len(var._data.shape), name_or_empty + ( + var._dims, + var._data.shape, + ) + assert isinstance(var._encoding, (type(None), dict)), name_or_empty + ( + var._encoding, + ) + assert isinstance(var._attrs, (type(None), dict)), name_or_empty + (var._attrs,) + + +def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): + assert isinstance(da._variable, Variable), da._variable + _assert_variable_invariants(da._variable) + + assert isinstance(da._coords, dict), da._coords + assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords + assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), ( + da.dims, + {k: v.dims for k, v in da._coords.items()}, + ) + assert all( + isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,) + ), {k: type(v) for k, v in da._coords.items()} + for k, v in da._coords.items(): + _assert_variable_invariants(v, k) + + if da._indexes is not None: + _assert_indexes_invariants_checks( + da._indexes, da._coords, da.dims, check_default=check_default_indexes + ) + + +def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): + assert isinstance(ds._variables, dict), type(ds._variables) + assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables + for k, v in ds._variables.items(): + _assert_variable_invariants(v, k) + + assert isinstance(ds._coord_names, set), ds._coord_names + assert ds._coord_names <= ds._variables.keys(), ( + ds._coord_names, + set(ds._variables), + ) + + assert type(ds._dims) is dict, ds._dims # noqa: E721 + assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims + var_dims: set[Hashable] = set() + for v in ds._variables.values(): + var_dims.update(v.dims) + assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims) + assert all( + ds._dims[k] == v.sizes[k] for v in ds._variables.values() for k in v.sizes + ), (ds._dims, {k: v.sizes for k, v in ds._variables.items()}) + + if check_default_indexes: + assert all( + isinstance(v, IndexVariable) + for (k, v) in ds._variables.items() + if v.dims == (k,) + ), {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)} + + if ds._indexes is not None: + _assert_indexes_invariants_checks( + ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes + ) + + assert isinstance(ds._encoding, (type(None), dict)) + assert isinstance(ds._attrs, (type(None), dict)) + + +def _assert_internal_invariants( + xarray_obj: Union[DataArray, Dataset, Variable], check_default_indexes: bool +): + """Validate that an xarray object satisfies its own internal invariants. + + This exists for the benefit of xarray's own test suite, but may be useful + in external projects if they (ill-advisedly) create objects using xarray's + private APIs. + """ + if isinstance(xarray_obj, Variable): + _assert_variable_invariants(xarray_obj) + elif isinstance(xarray_obj, DataArray): + _assert_dataarray_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) + elif isinstance(xarray_obj, Dataset): + _assert_dataset_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) + elif isinstance(xarray_obj, Coordinates): + _assert_dataset_invariants( + xarray_obj.to_dataset(), check_default_indexes=check_default_indexes + ) + else: + raise TypeError( + f"{type(xarray_obj)} is not a supported type for xarray invariant checks" + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/testing/strategies.py b/test/fixtures/whole_applications/xarray/xarray/testing/strategies.py new file mode 100644 index 0000000..449d0c7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/testing/strategies.py @@ -0,0 +1,468 @@ +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Protocol, Union, overload + +try: + import hypothesis.strategies as st +except ImportError as e: + raise ImportError( + "`xarray.testing.strategies` requires `hypothesis` to be installed." + ) from e + +import hypothesis.extra.numpy as npst +import numpy as np +from hypothesis.errors import InvalidArgument + +import xarray as xr +from xarray.core.types import T_DuckArray + +if TYPE_CHECKING: + from xarray.core.types import _DTypeLikeNested, _ShapeLike + + +__all__ = [ + "supported_dtypes", + "pandas_index_dtypes", + "names", + "dimension_names", + "dimension_sizes", + "attrs", + "variables", + "unique_subset_of", +] + + +class ArrayStrategyFn(Protocol[T_DuckArray]): + def __call__( + self, + *, + shape: "_ShapeLike", + dtype: "_DTypeLikeNested", + ) -> st.SearchStrategy[T_DuckArray]: ... + + +def supported_dtypes() -> st.SearchStrategy[np.dtype]: + """ + Generates only those numpy dtypes which xarray can handle. + + Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes. + Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness. + + Requires the hypothesis package to be installed. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + # TODO should this be exposed publicly? + # We should at least decide what the set of numpy dtypes that xarray officially supports is. + return ( + npst.integer_dtypes(endianness="=") + | npst.unsigned_integer_dtypes(endianness="=") + | npst.floating_dtypes(endianness="=") + | npst.complex_number_dtypes(endianness="=") + # | npst.datetime64_dtypes() + # | npst.timedelta64_dtypes() + # | npst.unicode_string_dtypes() + ) + + +def pandas_index_dtypes() -> st.SearchStrategy[np.dtype]: + """ + Dtypes supported by pandas indexes. + Restrict datetime64 and timedelta64 to ns frequency till Xarray relaxes that. + """ + return ( + npst.integer_dtypes(endianness="=", sizes=(32, 64)) + | npst.unsigned_integer_dtypes(endianness="=", sizes=(32, 64)) + | npst.floating_dtypes(endianness="=", sizes=(32, 64)) + # TODO: unset max_period + | npst.datetime64_dtypes(endianness="=", max_period="ns") + # TODO: set max_period="D" + | npst.timedelta64_dtypes(endianness="=", max_period="ns") + | npst.unicode_string_dtypes(endianness="=") + ) + + +# TODO Generalize to all valid unicode characters once formatting bugs in xarray's reprs are fixed + docs can handle it. +_readable_characters = st.characters( + categories=["L", "N"], max_codepoint=0x017F +) # only use characters within the "Latin Extended-A" subset of unicode + + +def names() -> st.SearchStrategy[str]: + """ + Generates arbitrary string names for dimensions / variables. + + Requires the hypothesis package to be installed. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + return st.text( + _readable_characters, + min_size=1, + max_size=5, + ) + + +def dimension_names( + *, + name_strategy=names(), + min_dims: int = 0, + max_dims: int = 3, +) -> st.SearchStrategy[list[Hashable]]: + """ + Generates an arbitrary list of valid dimension names. + + Requires the hypothesis package to be installed. + + Parameters + ---------- + name_strategy + Strategy for making names. Useful if we need to share this. + min_dims + Minimum number of dimensions in generated list. + max_dims + Maximum number of dimensions in generated list. + """ + + return st.lists( + elements=name_strategy, + min_size=min_dims, + max_size=max_dims, + unique=True, + ) + + +def dimension_sizes( + *, + dim_names: st.SearchStrategy[Hashable] = names(), + min_dims: int = 0, + max_dims: int = 3, + min_side: int = 1, + max_side: Union[int, None] = None, +) -> st.SearchStrategy[Mapping[Hashable, int]]: + """ + Generates an arbitrary mapping from dimension names to lengths. + + Requires the hypothesis package to be installed. + + Parameters + ---------- + dim_names: strategy generating strings, optional + Strategy for generating dimension names. + Defaults to the `names` strategy. + min_dims: int, optional + Minimum number of dimensions in generated list. + Default is 1. + max_dims: int, optional + Maximum number of dimensions in generated list. + Default is 3. + min_side: int, optional + Minimum size of a dimension. + Default is 1. + max_side: int, optional + Minimum size of a dimension. + Default is `min_length` + 5. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + + if max_side is None: + max_side = min_side + 3 + + return st.dictionaries( + keys=dim_names, + values=st.integers(min_value=min_side, max_value=max_side), + min_size=min_dims, + max_size=max_dims, + ) + + +_readable_strings = st.text( + _readable_characters, + max_size=5, +) +_attr_keys = _readable_strings +_small_arrays = npst.arrays( + shape=npst.array_shapes( + max_side=2, + max_dims=2, + ), + dtype=npst.scalar_dtypes(), +) +_attr_values = st.none() | st.booleans() | _readable_strings | _small_arrays + + +def attrs() -> st.SearchStrategy[Mapping[Hashable, Any]]: + """ + Generates arbitrary valid attributes dictionaries for xarray objects. + + The generated dictionaries can potentially be recursive. + + Requires the hypothesis package to be installed. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + return st.recursive( + st.dictionaries(_attr_keys, _attr_values), + lambda children: st.dictionaries(_attr_keys, children), + max_leaves=3, + ) + + +@st.composite +def variables( + draw: st.DrawFn, + *, + array_strategy_fn: Union[ArrayStrategyFn, None] = None, + dims: Union[ + st.SearchStrategy[Union[Sequence[Hashable], Mapping[Hashable, int]]], + None, + ] = None, + dtype: st.SearchStrategy[np.dtype] = supported_dtypes(), + attrs: st.SearchStrategy[Mapping] = attrs(), +) -> xr.Variable: + """ + Generates arbitrary xarray.Variable objects. + + Follows the basic signature of the xarray.Variable constructor, but allows passing alternative strategies to + generate either numpy-like array data or dimensions. Also allows specifying the shape or dtype of the wrapped array + up front. + + Passing nothing will generate a completely arbitrary Variable (containing a numpy array). + + Requires the hypothesis package to be installed. + + Parameters + ---------- + array_strategy_fn: Callable which returns a strategy generating array-likes, optional + Callable must only accept shape and dtype kwargs, and must generate results consistent with its input. + If not passed the default is to generate a small numpy array with one of the supported_dtypes. + dims: Strategy for generating the dimensions, optional + Can either be a strategy for generating a sequence of string dimension names, + or a strategy for generating a mapping of string dimension names to integer lengths along each dimension. + If provided as a mapping the array shape will be passed to array_strategy_fn. + Default is to generate arbitrary dimension names for each axis in data. + dtype: Strategy which generates np.dtype objects, optional + Will be passed in to array_strategy_fn. + Default is to generate any scalar dtype using supported_dtypes. + Be aware that this default set of dtypes includes some not strictly allowed by the array API standard. + attrs: Strategy which generates dicts, optional + Default is to generate a nested attributes dictionary containing arbitrary strings, booleans, integers, Nones, + and numpy arrays. + + Returns + ------- + variable_strategy + Strategy for generating xarray.Variable objects. + + Raises + ------ + ValueError + If a custom array_strategy_fn returns a strategy which generates an example array inconsistent with the shape + & dtype input passed to it. + + Examples + -------- + Generate completely arbitrary Variable objects backed by a numpy array: + + >>> variables().example() # doctest: +SKIP + + array([43506, -16, -151], dtype=int32) + >>> variables().example() # doctest: +SKIP + + array([[[-10000000., -10000000.], + [-10000000., -10000000.]], + [[-10000000., -10000000.], + [ 0., -10000000.]], + [[ 0., -10000000.], + [-10000000., inf]], + [[ -0., -10000000.], + [-10000000., -0.]]], dtype=float32) + Attributes: + śřĴ: {'ĉ': {'iĥf': array([-30117, -1740], dtype=int16)}} + + Generate only Variable objects with certain dimension names: + + >>> variables(dims=st.just(["a", "b"])).example() # doctest: +SKIP + + array([[ 248, 4294967295, 4294967295], + [2412855555, 3514117556, 4294967295], + [ 111, 4294967295, 4294967295], + [4294967295, 1084434988, 51688], + [ 47714, 252, 11207]], dtype=uint32) + + Generate only Variable objects with certain dimension names and lengths: + + >>> variables(dims=st.just({"a": 2, "b": 1})).example() # doctest: +SKIP + + array([[-1.00000000e+007+3.40282347e+038j], + [-2.75034266e-225+2.22507386e-311j]]) + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + + if not isinstance(dims, st.SearchStrategy) and dims is not None: + raise InvalidArgument( + f"dims must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dims)}. " + "To specify fixed contents, use hypothesis.strategies.just()." + ) + if not isinstance(dtype, st.SearchStrategy) and dtype is not None: + raise InvalidArgument( + f"dtype must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dtype)}. " + "To specify fixed contents, use hypothesis.strategies.just()." + ) + if not isinstance(attrs, st.SearchStrategy) and attrs is not None: + raise InvalidArgument( + f"attrs must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(attrs)}. " + "To specify fixed contents, use hypothesis.strategies.just()." + ) + + _array_strategy_fn: ArrayStrategyFn + if array_strategy_fn is None: + # For some reason if I move the default value to the function signature definition mypy incorrectly says the ignore is no longer necessary, making it impossible to satisfy mypy + _array_strategy_fn = npst.arrays # type: ignore[assignment] # npst.arrays has extra kwargs that we aren't using later + elif not callable(array_strategy_fn): + raise InvalidArgument( + "array_strategy_fn must be a Callable that accepts the kwargs dtype and shape and returns a hypothesis " + "strategy which generates corresponding array-like objects." + ) + else: + _array_strategy_fn = ( + array_strategy_fn # satisfy mypy that this new variable cannot be None + ) + + _dtype = draw(dtype) + + if dims is not None: + # generate dims first then draw data to match + _dims = draw(dims) + if isinstance(_dims, Sequence): + dim_names = list(_dims) + valid_shapes = npst.array_shapes(min_dims=len(_dims), max_dims=len(_dims)) + _shape = draw(valid_shapes) + array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype) + elif isinstance(_dims, (Mapping, dict)): + # should be a mapping of form {dim_names: lengths} + dim_names, _shape = list(_dims.keys()), tuple(_dims.values()) + array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype) + else: + raise InvalidArgument( + f"Invalid type returned by dims strategy - drew an object of type {type(dims)}" + ) + else: + # nothing provided, so generate everything consistently + # We still generate the shape first here just so that we always pass shape to array_strategy_fn + _shape = draw(npst.array_shapes()) + array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype) + dim_names = draw(dimension_names(min_dims=len(_shape), max_dims=len(_shape))) + + _data = draw(array_strategy) + + if _data.shape != _shape: + raise ValueError( + "array_strategy_fn returned an array object with a different shape than it was passed." + f"Passed {_shape}, but returned {_data.shape}." + "Please either specify a consistent shape via the dims kwarg or ensure the array_strategy_fn callable " + "obeys the shape argument passed to it." + ) + if _data.dtype != _dtype: + raise ValueError( + "array_strategy_fn returned an array object with a different dtype than it was passed." + f"Passed {_dtype}, but returned {_data.dtype}" + "Please either specify a consistent dtype via the dtype kwarg or ensure the array_strategy_fn callable " + "obeys the dtype argument passed to it." + ) + + return xr.Variable(dims=dim_names, data=_data, attrs=draw(attrs)) + + +@overload +def unique_subset_of( + objs: Sequence[Hashable], + *, + min_size: int = 0, + max_size: Union[int, None] = None, +) -> st.SearchStrategy[Sequence[Hashable]]: ... + + +@overload +def unique_subset_of( + objs: Mapping[Hashable, Any], + *, + min_size: int = 0, + max_size: Union[int, None] = None, +) -> st.SearchStrategy[Mapping[Hashable, Any]]: ... + + +@st.composite +def unique_subset_of( + draw: st.DrawFn, + objs: Union[Sequence[Hashable], Mapping[Hashable, Any]], + *, + min_size: int = 0, + max_size: Union[int, None] = None, +) -> Union[Sequence[Hashable], Mapping[Hashable, Any]]: + """ + Return a strategy which generates a unique subset of the given objects. + + Each entry in the output subset will be unique (if input was a sequence) or have a unique key (if it was a mapping). + + Requires the hypothesis package to be installed. + + Parameters + ---------- + objs: Union[Sequence[Hashable], Mapping[Hashable, Any]] + Objects from which to sample to produce the subset. + min_size: int, optional + Minimum size of the returned subset. Default is 0. + max_size: int, optional + Maximum size of the returned subset. Default is the full length of the input. + If set to 0 the result will be an empty mapping. + + Returns + ------- + unique_subset_strategy + Strategy generating subset of the input. + + Examples + -------- + >>> unique_subset_of({"x": 2, "y": 3}).example() # doctest: +SKIP + {'y': 3} + >>> unique_subset_of(["x", "y"]).example() # doctest: +SKIP + ['x'] + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + if not isinstance(objs, Iterable): + raise TypeError( + f"Object to sample from must be an Iterable or a Mapping, but received type {type(objs)}" + ) + + if len(objs) == 0: + raise ValueError("Can't sample from a length-zero object.") + + keys = list(objs.keys()) if isinstance(objs, Mapping) else objs + + subset_keys = draw( + st.lists( + st.sampled_from(keys), + unique=True, + min_size=min_size, + max_size=max_size, + ) + ) + + return ( + {k: objs[k] for k in subset_keys} if isinstance(objs, Mapping) else subset_keys + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/__init__.py b/test/fixtures/whole_applications/xarray/xarray/tests/__init__.py new file mode 100644 index 0000000..0caab6e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/__init__.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import importlib +import platform +import string +import warnings +from contextlib import contextmanager, nullcontext +from unittest import mock # noqa: F401 + +import numpy as np +import pandas as pd +import pytest +from numpy.testing import assert_array_equal # noqa: F401 +from packaging.version import Version +from pandas.testing import assert_frame_equal # noqa: F401 + +import xarray.testing +from xarray import Dataset +from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 +from xarray.core.extension_array import PandasExtensionArray +from xarray.core.options import set_options +from xarray.core.variable import IndexVariable +from xarray.testing import ( # noqa: F401 + assert_chunks_equal, + assert_duckarray_allclose, + assert_duckarray_equal, +) +from xarray.tests.arrays import ( # noqa: F401 + ConcatenatableArray, + DuckArrayWrapper, + FirstElementAccessibleArray, + InaccessibleArray, + UnexpectedDataAccess, +) + +# import mpl and change the backend before other mpl imports +try: + import matplotlib as mpl + + # Order of imports is important here. + # Using a different backend makes Travis CI work + mpl.use("Agg") +except ImportError: + pass + +# https://github.com/pydata/xarray/issues/7322 +warnings.filterwarnings("ignore", "'urllib3.contrib.pyopenssl' module is deprecated") +warnings.filterwarnings("ignore", "Deprecated call to `pkg_resources.declare_namespace") +warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API") + +arm_xfail = pytest.mark.xfail( + platform.machine() == "aarch64" or "arm" in platform.machine(), + reason="expected failure on ARM", +) + + +def assert_writeable(ds): + readonly = [ + name + for name, var in ds.variables.items() + if not isinstance(var, IndexVariable) + and not isinstance(var.data, PandasExtensionArray) + and not var.data.flags.writeable + ] + assert not readonly, readonly + + +def _importorskip( + modname: str, minversion: str | None = None +) -> tuple[bool, pytest.MarkDecorator]: + try: + mod = importlib.import_module(modname) + has = True + if minversion is not None: + v = getattr(mod, "__version__", "999") + if Version(v) < Version(minversion): + raise ImportError("Minimum version not satisfied") + except ImportError: + has = False + + reason = f"requires {modname}" + if minversion is not None: + reason += f">={minversion}" + func = pytest.mark.skipif(not has, reason=reason) + return has, func + + +has_matplotlib, requires_matplotlib = _importorskip("matplotlib") +has_scipy, requires_scipy = _importorskip("scipy") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="'cgi' is deprecated and slated for removal in Python 3.13", + category=DeprecationWarning, + ) + has_pydap, requires_pydap = _importorskip("pydap.client") +has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") +with warnings.catch_warnings(): + # see https://github.com/pydata/xarray/issues/8537 + warnings.filterwarnings( + "ignore", + message="h5py is running against HDF5 1.14.3", + category=UserWarning, + ) + + has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") +has_cftime, requires_cftime = _importorskip("cftime") +has_dask, requires_dask = _importorskip("dask") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="The current Dask DataFrame implementation is deprecated.", + category=DeprecationWarning, + ) + has_dask_expr, requires_dask_expr = _importorskip("dask_expr") +has_bottleneck, requires_bottleneck = _importorskip("bottleneck") +has_rasterio, requires_rasterio = _importorskip("rasterio") +has_zarr, requires_zarr = _importorskip("zarr") +has_fsspec, requires_fsspec = _importorskip("fsspec") +has_iris, requires_iris = _importorskip("iris") +has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") +has_pyarrow, requires_pyarrow = _importorskip("pyarrow") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="is_categorical_dtype is deprecated and will be removed in a future version.", + category=DeprecationWarning, + ) + # seaborn uses the deprecated `pandas.is_categorical_dtype` + has_seaborn, requires_seaborn = _importorskip("seaborn") +has_sparse, requires_sparse = _importorskip("sparse") +has_cupy, requires_cupy = _importorskip("cupy") +has_cartopy, requires_cartopy = _importorskip("cartopy") +has_pint, requires_pint = _importorskip("pint") +has_numexpr, requires_numexpr = _importorskip("numexpr") +has_flox, requires_flox = _importorskip("flox") +has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") +has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0") + + +# some special cases +has_scipy_or_netCDF4 = has_scipy or has_netCDF4 +requires_scipy_or_netCDF4 = pytest.mark.skipif( + not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" +) +has_numbagg_or_bottleneck = has_numbagg or has_bottleneck +requires_numbagg_or_bottleneck = pytest.mark.skipif( + not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck" +) +has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0") + +has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict") + + +def _importorskip_h5netcdf_ros3(): + try: + import h5netcdf + + has_h5netcdf = True + except ImportError: + has_h5netcdf = False + + if not has_h5netcdf: + return has_h5netcdf, pytest.mark.skipif( + not has_h5netcdf, reason="requires h5netcdf" + ) + + h5netcdf_with_ros3 = Version(h5netcdf.__version__) >= Version("1.3.0") + + import h5py + + h5py_with_ros3 = h5py.get_config().ros3 + + has_h5netcdf_ros3 = h5netcdf_with_ros3 and h5py_with_ros3 + + return has_h5netcdf_ros3, pytest.mark.skipif( + not has_h5netcdf_ros3, + reason="requires h5netcdf>=1.3.0 and h5py with ros3 support", + ) + + +has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip_h5netcdf_ros3() +has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip( + "netCDF4", "1.6.2" +) + +# change some global options for tests +set_options(warn_for_unclosed_files=True) + +if has_dask: + import dask + + +class CountingScheduler: + """Simple dask scheduler counting the number of computes. + + Reference: https://stackoverflow.com/questions/53289286/""" + + def __init__(self, max_computes=0): + self.total_computes = 0 + self.max_computes = max_computes + + def __call__(self, dsk, keys, **kwargs): + self.total_computes += 1 + if self.total_computes > self.max_computes: + raise RuntimeError( + "Too many computes. Total: %d > max: %d." + % (self.total_computes, self.max_computes) + ) + return dask.get(dsk, keys, **kwargs) + + +def raise_if_dask_computes(max_computes=0): + # return a dummy context manager so that this can be used for non-dask objects + if not has_dask: + return nullcontext() + scheduler = CountingScheduler(max_computes) + return dask.config.set(scheduler=scheduler) + + +flaky = pytest.mark.flaky +network = pytest.mark.network + + +class ReturnItem: + def __getitem__(self, key): + return key + + +class IndexerMaker: + def __init__(self, indexer_cls): + self._indexer_cls = indexer_cls + + def __getitem__(self, key): + if not isinstance(key, tuple): + key = (key,) + return self._indexer_cls(key) + + +def source_ndarray(array): + """Given an ndarray, return the base object which holds its memory, or the + object itself. + """ + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "DatetimeIndex.base") + warnings.filterwarnings("ignore", "TimedeltaIndex.base") + base = getattr(array, "base", np.asarray(array).base) + if base is None: + base = array + return base + + +def format_record(record) -> str: + """Format warning record like `FutureWarning('Function will be deprecated...')`""" + return f"{str(record.category)[8:-2]}('{record.message}'))" + + +@contextmanager +def assert_no_warnings(): + with warnings.catch_warnings(record=True) as record: + yield record + assert ( + len(record) == 0 + ), f"Got {len(record)} unexpected warning(s): {[format_record(r) for r in record]}" + + +# Internal versions of xarray's test functions that validate additional +# invariants + + +def assert_equal(a, b, check_default_indexes=True): + __tracebackhide__ = True + xarray.testing.assert_equal(a, b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) + + +def assert_identical(a, b, check_default_indexes=True): + __tracebackhide__ = True + xarray.testing.assert_identical(a, b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) + + +def assert_allclose(a, b, check_default_indexes=True, **kwargs): + __tracebackhide__ = True + xarray.testing.assert_allclose(a, b, **kwargs) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) + + +_DEFAULT_TEST_DIM_SIZES = (8, 9, 10) + + +def create_test_data( + seed: int | None = None, + add_attrs: bool = True, + dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES, + use_extension_array: bool = False, +) -> Dataset: + rs = np.random.RandomState(seed) + _vars = { + "var1": ["dim1", "dim2"], + "var2": ["dim1", "dim2"], + "var3": ["dim3", "dim1"], + } + _dims = {"dim1": dim_sizes[0], "dim2": dim_sizes[1], "dim3": dim_sizes[2]} + + obj = Dataset() + obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"])) + if _dims["dim3"] > 26: + raise RuntimeError( + f'Not enough letters for filling this dimension size ({_dims["dim3"]})' + ) + obj["dim3"] = ("dim3", list(string.ascii_lowercase[0 : _dims["dim3"]])) + obj["time"] = ("time", pd.date_range("2000-01-01", periods=20)) + for v, dims in sorted(_vars.items()): + data = rs.normal(size=tuple(_dims[d] for d in dims)) + obj[v] = (dims, data) + if add_attrs: + obj[v].attrs = {"foo": "variable"} + if use_extension_array: + obj["var4"] = ( + "dim1", + pd.Categorical( + np.random.choice( + list(string.ascii_lowercase[: np.random.randint(5)]), + size=dim_sizes[0], + ) + ), + ) + if dim_sizes == _DEFAULT_TEST_DIM_SIZES: + numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") + else: + numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64") + obj.coords["numbers"] = ("dim3", numbers_values) + obj.encoding = {"foo": "bar"} + assert_writeable(obj) + return obj + + +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", + "standard", +] diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/arrays.py b/test/fixtures/whole_applications/xarray/xarray/tests/arrays.py new file mode 100644 index 0000000..983e620 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/arrays.py @@ -0,0 +1,179 @@ +from collections.abc import Iterable +from typing import Any, Callable + +import numpy as np + +from xarray.core import utils +from xarray.core.indexing import ExplicitlyIndexed + +""" +This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access. +""" + + +class UnexpectedDataAccess(Exception): + pass + + +class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed): + """Disallows any loading.""" + + def __init__(self, array): + self.array = array + + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __getitem__(self, key): + raise UnexpectedDataAccess("Tried accessing data.") + + +class FirstElementAccessibleArray(InaccessibleArray): + def __getitem__(self, key): + tuple_idxr = key.tuple + if len(tuple_idxr) > 1: + raise UnexpectedDataAccess("Tried accessing more than one element.") + return self.array[tuple_idxr] + + +class DuckArrayWrapper(utils.NDArrayMixin): + """Array-like that prevents casting to array. + Modeled after cupy.""" + + def __init__(self, array: np.ndarray): + self.array = array + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __array_namespace__(self): + """Present to satisfy is_duck_array test.""" + + +CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[str, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for ConcatenatableArray objects.""" + + def decorator(func): + CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.concatenate) +def concatenate( + arrays: Iterable["ConcatenatableArray"], /, *, axis=0 +) -> "ConcatenatableArray": + if any(not isinstance(arr, ConcatenatableArray) for arr in arrays): + raise TypeError + + result = np.concatenate([arr._array for arr in arrays], axis=axis) + return ConcatenatableArray(result) + + +@implements(np.stack) +def stack( + arrays: Iterable["ConcatenatableArray"], /, *, axis=0 +) -> "ConcatenatableArray": + if any(not isinstance(arr, ConcatenatableArray) for arr in arrays): + raise TypeError + + result = np.stack([arr._array for arr in arrays], axis=axis) + return ConcatenatableArray(result) + + +@implements(np.result_type) +def result_type(*arrays_and_dtypes) -> np.dtype: + """Called by xarray to ensure all arguments to concat have the same dtype.""" + first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes) + for other_dtype in other_dtypes: + if other_dtype != first_dtype: + raise ValueError("dtypes not all consistent") + return first_dtype + + +@implements(np.broadcast_to) +def broadcast_to( + x: "ConcatenatableArray", /, shape: tuple[int, ...] +) -> "ConcatenatableArray": + """ + Broadcasts an array to a specified shape, by either manipulating chunk keys or copying chunk manifest entries. + """ + if not isinstance(x, ConcatenatableArray): + raise TypeError + + result = np.broadcast_to(x._array, shape=shape) + return ConcatenatableArray(result) + + +class ConcatenatableArray: + """Disallows loading or coercing to an index but does support concatenation / stacking.""" + + def __init__(self, array): + # use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. create_default_index_implicit) + self._array = array + + @property + def dtype(self: Any) -> np.dtype: + return self._array.dtype + + @property + def shape(self: Any) -> tuple[int, ...]: + return self._array.shape + + @property + def ndim(self: Any) -> int: + return self._array.ndim + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(array={self._array!r})" + + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __getitem__(self, key) -> "ConcatenatableArray": + """Some cases of concat require supporting expanding dims by dimensions of size 1""" + # see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing + arr = self._array + for axis, indexer_1d in enumerate(key): + if indexer_1d is None: + arr = np.expand_dims(arr, axis) + elif indexer_1d is Ellipsis: + pass + else: + raise UnexpectedDataAccess("Tried accessing data.") + return ConcatenatableArray(arr) + + def __array_function__(self, func, types, args, kwargs) -> Any: + if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: + return NotImplemented + + # Note: this allows subclasses that don't override + # __array_function__ to handle ManifestArray objects + if not all(issubclass(t, ConcatenatableArray) for t in types): + return NotImplemented + + return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any: + """We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs.""" + return NotImplemented + + def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> "ConcatenatableArray": + """Needed because xarray will call this even when it's a no-op""" + if dtype != self.dtype: + raise NotImplementedError() + else: + return self diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/conftest.py b/test/fixtures/whole_applications/xarray/xarray/tests/conftest.py new file mode 100644 index 0000000..a32b0e0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/conftest.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Dataset +from xarray.core.datatree import DataTree +from xarray.tests import create_test_data, requires_dask + + +@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)]) +def backend(request): + return request.param + + +@pytest.fixture(params=["numbagg", "bottleneck", None]) +def compute_backend(request): + if request.param is None: + options = dict(use_bottleneck=False, use_numbagg=False) + elif request.param == "bottleneck": + options = dict(use_bottleneck=True, use_numbagg=False) + elif request.param == "numbagg": + options = dict(use_bottleneck=False, use_numbagg=True) + else: + raise ValueError + + with xr.set_options(**options): + yield request.param + + +@pytest.fixture(params=[1]) +def ds(request, backend): + if request.param == 1: + ds = Dataset( + dict( + z1=(["y", "x"], np.random.randn(2, 8)), + z2=(["time", "y"], np.random.randn(10, 2)), + ), + dict( + x=("x", np.linspace(0, 1.0, 8)), + time=("time", np.linspace(0, 1.0, 10)), + c=("y", ["a", "b"]), + y=range(2), + ), + ) + elif request.param == 2: + ds = Dataset( + dict( + z1=(["time", "y"], np.random.randn(10, 2)), + z2=(["time"], np.random.randn(10)), + z3=(["x", "time"], np.random.randn(8, 10)), + ), + dict( + x=("x", np.linspace(0, 1.0, 8)), + time=("time", np.linspace(0, 1.0, 10)), + c=("y", ["a", "b"]), + y=range(2), + ), + ) + elif request.param == 3: + ds = create_test_data() + else: + raise ValueError + + if backend == "dask": + return ds.chunk() + + return ds + + +@pytest.fixture(params=[1]) +def da(request, backend): + if request.param == 1: + times = pd.date_range("2000-01-01", freq="1D", periods=21) + da = DataArray( + np.random.random((3, 21, 4)), + dims=("a", "time", "x"), + coords=dict(time=times), + ) + + if request.param == 2: + da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") + + if request.param == "repeating_ints": + da = DataArray( + np.tile(np.arange(12), 5).reshape(5, 4, 3), + coords={"x": list("abc"), "y": list("defg")}, + dims=list("zyx"), + ) + + if backend == "dask": + return da.chunk() + elif backend == "numpy": + return da + else: + raise ValueError + + +@pytest.fixture(params=[Dataset, DataArray]) +def type(request): + return request.param + + +@pytest.fixture(params=[1]) +def d(request, backend, type) -> DataArray | Dataset: + """ + For tests which can test either a DataArray or a Dataset. + """ + result: DataArray | Dataset + if request.param == 1: + ds = Dataset( + dict( + a=(["x", "z"], np.arange(24).reshape(2, 12)), + b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)), + ), + dict( + x=("x", np.linspace(0, 1.0, 2)), + y=range(3), + z=("z", pd.date_range("2000-01-01", periods=12)), + w=("x", ["a", "b"]), + ), + ) + if type == DataArray: + result = ds["a"].assign_coords(w=ds.coords["w"]) + elif type == Dataset: + result = ds + else: + raise ValueError + else: + raise ValueError + + if backend == "dask": + return result.chunk() + elif backend == "numpy": + return result + else: + raise ValueError + + +@pytest.fixture(scope="module") +def create_test_datatree(): + """ + Create a test datatree with this structure: + + + |-- set1 + | |-- + | | Dimensions: () + | | Data variables: + | | a int64 0 + | | b int64 1 + | |-- set1 + | |-- set2 + |-- set2 + | |-- + | | Dimensions: (x: 2) + | | Data variables: + | | a (x) int64 2, 3 + | | b (x) int64 0.1, 0.2 + | |-- set1 + |-- set3 + |-- + | Dimensions: (x: 2, y: 3) + | Data variables: + | a (y) int64 6, 7, 8 + | set0 (x) int64 9, 10 + + The structure has deliberately repeated names of tags, variables, and + dimensions in order to better check for bugs caused by name conflicts. + """ + + def _create_test_datatree(modify=lambda ds: ds): + set1_data = modify(xr.Dataset({"a": 0, "b": 1})) + set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) + root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) + + # Avoid using __init__ so we can independently test it + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + return root + + return _create_test_datatree + + +@pytest.fixture(scope="module") +def simple_datatree(create_test_datatree): + """ + Invoke create_test_datatree fixture (callback). + + Returns a DataTree. + """ + return create_test_datatree() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/data/bears.nc b/test/fixtures/whole_applications/xarray/xarray/tests/data/bears.nc new file mode 100644 index 0000000..2f1a063 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/xarray/tests/data/bears.nc differ diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/data/example.grib b/test/fixtures/whole_applications/xarray/xarray/tests/data/example.grib new file mode 100644 index 0000000..596a54d Binary files /dev/null and b/test/fixtures/whole_applications/xarray/xarray/tests/data/example.grib differ diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/data/example.ict b/test/fixtures/whole_applications/xarray/xarray/tests/data/example.ict new file mode 100644 index 0000000..a33e71a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/data/example.ict @@ -0,0 +1,33 @@ +29, 1001 +Henderson, Barron +U.S. EPA +Example file with artificial data +JUST_A_TEST +1, 1 +2018, 04, 27 2018, 04, 27 +0 +Start_UTC +5 +1, 1, 1, 1, 1 +-9999, -9999, -9999, -9999, -9999 +lat, degrees_north +lon, degrees_east +elev, meters +TEST_ppbv, ppbv +TESTM_ppbv, ppbv +0 +9 +INDEPENDENT_VARIABLE_DEFINITION: Start_UTC +INDEPENDENT_VARIABLE_UNITS: Start_UTC +ULOD_FLAG: -7777 +ULOD_VALUE: N/A +LLOD_FLAG: -8888 +LLOD_VALUE: N/A, N/A, N/A, N/A, 0.025 +OTHER_COMMENTS: www-air.larc.nasa.gov/missions/etc/IcarttDataFormat.htm +REVISION: R0 +R0: No comments for this revision. +Start_UTC, lat, lon, elev, TEST_ppbv, TESTM_ppbv +43200, 41.00000, -71.00000, 5, 1.2345, 2.220 +46800, 42.00000, -72.00000, 15, 2.3456, -9999 +50400, 42.00000, -73.00000, 20, 3.4567, -7777 +50400, 42.00000, -74.00000, 25, 4.5678, -8888 diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/data/example.uamiv b/test/fixtures/whole_applications/xarray/xarray/tests/data/example.uamiv new file mode 100644 index 0000000..fcedcd5 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/xarray/tests/data/example.uamiv differ diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/data/example_1.nc b/test/fixtures/whole_applications/xarray/xarray/tests/data/example_1.nc new file mode 100644 index 0000000..5775622 Binary files /dev/null and b/test/fixtures/whole_applications/xarray/xarray/tests/data/example_1.nc differ diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/data/example_1.nc.gz b/test/fixtures/whole_applications/xarray/xarray/tests/data/example_1.nc.gz new file mode 100644 index 0000000..f8922ed Binary files /dev/null and b/test/fixtures/whole_applications/xarray/xarray/tests/data/example_1.nc.gz differ diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_accessor_dt.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_accessor_dt.py new file mode 100644 index 0000000..686bce9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_accessor_dt.py @@ -0,0 +1,699 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.tests import ( + assert_allclose, + assert_array_equal, + assert_chunks_equal, + assert_equal, + assert_identical, + raise_if_dask_computes, + requires_cftime, + requires_dask, +) + + +class TestDatetimeAccessor: + @pytest.fixture(autouse=True) + def setup(self): + nt = 100 + data = np.random.rand(10, 10, nt) + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + self.times = pd.date_range(start="2000/01/01", freq="h", periods=nt) + + self.data = xr.DataArray( + data, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) + + self.times_arr = np.random.choice(self.times, size=(10, 10, nt)) + self.times_data = xr.DataArray( + self.times_arr, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) + + @pytest.mark.parametrize( + "field", + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "week", + "weekofyear", + "dayofweek", + "weekday", + "dayofyear", + "quarter", + "date", + "time", + "daysinmonth", + "days_in_month", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", + "is_leap_year", + ], + ) + def test_field_access(self, field) -> None: + if field in ["week", "weekofyear"]: + data = self.times.isocalendar()["week"] + else: + data = getattr(self.times, field) + + if data.dtype.kind != "b" and field not in ("date", "time"): + # pandas 2.0 returns int32 for integer fields now + data = data.astype("int64") + + translations = { + "weekday": "dayofweek", + "daysinmonth": "days_in_month", + "weekofyear": "week", + } + name = translations.get(field, field) + + expected = xr.DataArray(data, name=name, coords=[self.times], dims=["time"]) + + if field in ["week", "weekofyear"]: + with pytest.warns( + FutureWarning, match="dt.weekofyear and dt.week have been deprecated" + ): + actual = getattr(self.data.time.dt, field) + else: + actual = getattr(self.data.time.dt, field) + + assert expected.dtype == actual.dtype + assert_identical(expected, actual) + + def test_total_seconds(self) -> None: + # Subtract a value in the middle of the range to ensure that some values + # are negative + delta = self.data.time - np.datetime64("2000-01-03") + actual = delta.dt.total_seconds() + expected = xr.DataArray( + np.arange(-48, 52, dtype=np.float64) * 3600, + name="total_seconds", + coords=[self.data.time], + ) + # This works with assert_identical when pandas is >=1.5.0. + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "field, pandas_field", + [ + ("year", "year"), + ("week", "week"), + ("weekday", "day"), + ], + ) + def test_isocalendar(self, field, pandas_field) -> None: + # pandas isocalendar has dtypy UInt32Dtype, convert to Int64 + expected = pd.Index(getattr(self.times.isocalendar(), pandas_field).astype(int)) + expected = xr.DataArray( + expected, name=field, coords=[self.times], dims=["time"] + ) + + actual = self.data.time.dt.isocalendar()[field] + assert_equal(expected, actual) + + def test_calendar(self) -> None: + cal = self.data.time.dt.calendar + assert cal == "proleptic_gregorian" + + def test_strftime(self) -> None: + assert ( + "2000-01-01 01:00:00" == self.data.time.dt.strftime("%Y-%m-%d %H:%M:%S")[1] + ) + + def test_not_datetime_type(self) -> None: + nontime_data = self.data.copy() + int_data = np.arange(len(self.data.time)).astype("int8") + nontime_data = nontime_data.assign_coords(time=int_data) + with pytest.raises(AttributeError, match=r"dt"): + nontime_data.time.dt + + @pytest.mark.filterwarnings("ignore:dt.weekofyear and dt.week have been deprecated") + @requires_dask + @pytest.mark.parametrize( + "field", + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "week", + "weekofyear", + "dayofweek", + "weekday", + "dayofyear", + "quarter", + "date", + "time", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", + "is_leap_year", + ], + ) + def test_dask_field_access(self, field) -> None: + import dask.array as da + + expected = getattr(self.times_data.dt, field) + + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, field) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) + + @requires_dask + @pytest.mark.parametrize( + "field", + [ + "year", + "week", + "weekday", + ], + ) + def test_isocalendar_dask(self, field) -> None: + import dask.array as da + + expected = getattr(self.times_data.dt.isocalendar(), field) + + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = dask_times_2d.dt.isocalendar()[field] + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) + + @requires_dask + @pytest.mark.parametrize( + "method, parameters", + [ + ("floor", "D"), + ("ceil", "D"), + ("round", "D"), + ("strftime", "%Y-%m-%d %H:%M:%S"), + ], + ) + def test_dask_accessor_method(self, method, parameters) -> None: + import dask.array as da + + expected = getattr(self.times_data.dt, method)(parameters) + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, method)(parameters) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) + + def test_seasons(self) -> None: + dates = xr.date_range( + start="2000/01/01", freq="ME", periods=12, use_cftime=False + ) + dates = dates.append(pd.Index([np.datetime64("NaT")])) + dates = xr.DataArray(dates) + seasons = xr.DataArray( + [ + "DJF", + "DJF", + "MAM", + "MAM", + "MAM", + "JJA", + "JJA", + "JJA", + "SON", + "SON", + "SON", + "DJF", + "nan", + ] + ) + + assert_array_equal(seasons.values, dates.dt.season.values) + + @pytest.mark.parametrize( + "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] + ) + def test_accessor_method(self, method, parameters) -> None: + dates = pd.date_range("2014-01-01", "2014-05-01", freq="h") + xdates = xr.DataArray(dates, dims=["time"]) + expected = getattr(dates, method)(parameters) + actual = getattr(xdates.dt, method)(parameters) + assert_array_equal(expected, actual) + + +class TestTimedeltaAccessor: + @pytest.fixture(autouse=True) + def setup(self): + nt = 100 + data = np.random.rand(10, 10, nt) + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + self.times = pd.timedelta_range(start="1 day", freq="6h", periods=nt) + + self.data = xr.DataArray( + data, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) + + self.times_arr = np.random.choice(self.times, size=(10, 10, nt)) + self.times_data = xr.DataArray( + self.times_arr, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) + + def test_not_datetime_type(self) -> None: + nontime_data = self.data.copy() + int_data = np.arange(len(self.data.time)).astype("int8") + nontime_data = nontime_data.assign_coords(time=int_data) + with pytest.raises(AttributeError, match=r"dt"): + nontime_data.time.dt + + @pytest.mark.parametrize( + "field", ["days", "seconds", "microseconds", "nanoseconds"] + ) + def test_field_access(self, field) -> None: + expected = xr.DataArray( + getattr(self.times, field), name=field, coords=[self.times], dims=["time"] + ) + actual = getattr(self.data.time.dt, field) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] + ) + def test_accessor_methods(self, method, parameters) -> None: + dates = pd.timedelta_range(start="1 day", end="30 days", freq="6h") + xdates = xr.DataArray(dates, dims=["time"]) + expected = getattr(dates, method)(parameters) + actual = getattr(xdates.dt, method)(parameters) + assert_array_equal(expected, actual) + + @requires_dask + @pytest.mark.parametrize( + "field", ["days", "seconds", "microseconds", "nanoseconds"] + ) + def test_dask_field_access(self, field) -> None: + import dask.array as da + + expected = getattr(self.times_data.dt, field) + + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, field) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual, expected) + + @requires_dask + @pytest.mark.parametrize( + "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] + ) + def test_dask_accessor_method(self, method, parameters) -> None: + import dask.array as da + + expected = getattr(self.times_data.dt, method)(parameters) + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, method)(parameters) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) + + +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", +] +_NT = 100 + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.fixture() +def times(calendar): + import cftime + + return cftime.num2date( + np.arange(_NT), + units="hours since 2000-01-01", + calendar=calendar, + only_use_cftime_datetimes=True, + ) + + +@pytest.fixture() +def data(times): + data = np.random.rand(10, 10, _NT) + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + return xr.DataArray( + data, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) + + +@pytest.fixture() +def times_3d(times): + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + times_arr = np.random.choice(times, size=(10, 10, _NT)) + return xr.DataArray( + times_arr, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) + + +@requires_cftime +@pytest.mark.parametrize( + "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] +) +def test_field_access(data, field) -> None: + result = getattr(data.time.dt, field) + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), + name=field, + coords=data.time.coords, + dims=data.time.dims, + ) + + assert_equal(result, expected) + + +@requires_cftime +def test_calendar_cftime(data) -> None: + expected = data.time.values[0].calendar + assert data.time.dt.calendar == expected + + +def test_calendar_datetime64_2d() -> None: + data = xr.DataArray(np.zeros((4, 5), dtype="datetime64[ns]"), dims=("x", "y")) + assert data.dt.calendar == "proleptic_gregorian" + + +@requires_dask +def test_calendar_datetime64_3d_dask() -> None: + import dask.array as da + + data = xr.DataArray( + da.zeros((4, 5, 6), dtype="datetime64[ns]"), dims=("x", "y", "z") + ) + with raise_if_dask_computes(): + assert data.dt.calendar == "proleptic_gregorian" + + +@requires_dask +@requires_cftime +def test_calendar_dask_cftime() -> None: + from cftime import num2date + + # 3D lazy dask + data = xr.DataArray( + num2date( + np.random.randint(1, 1000000, size=(4, 5, 6)), + "hours since 1970-01-01T00:00", + calendar="noleap", + ), + dims=("x", "y", "z"), + ).chunk() + with raise_if_dask_computes(max_computes=2): + assert data.dt.calendar == "noleap" + + +@requires_cftime +def test_isocalendar_cftime(data) -> None: + with pytest.raises( + AttributeError, match=r"'CFTimeIndex' object has no attribute 'isocalendar'" + ): + data.time.dt.isocalendar() + + +@requires_cftime +def test_date_cftime(data) -> None: + with pytest.raises( + AttributeError, + match=r"'CFTimeIndex' object has no attribute `date`. Consider using the floor method instead, for instance: `.time.dt.floor\('D'\)`.", + ): + data.time.dt.date() + + +@requires_cftime +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_cftime_strftime_access(data) -> None: + """compare cftime formatting against datetime formatting""" + date_format = "%Y%m%d%H" + result = data.time.dt.strftime(date_format) + datetime_array = xr.DataArray( + xr.coding.cftimeindex.CFTimeIndex(data.time.values).to_datetimeindex(), + name="stftime", + coords=data.time.coords, + dims=data.time.dims, + ) + expected = datetime_array.dt.strftime(date_format) + assert_equal(result, expected) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize( + "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] +) +def test_dask_field_access_1d(data, field) -> None: + import dask.array as da + + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), + name=field, + dims=["time"], + ) + times = xr.DataArray(data.time.values, dims=["time"]).chunk({"time": 50}) + result = getattr(times.dt, field) + assert isinstance(result.data, da.Array) + assert result.chunks == times.chunks + assert_equal(result.compute(), expected) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize( + "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] +) +def test_dask_field_access(times_3d, data, field) -> None: + import dask.array as da + + expected = xr.DataArray( + getattr( + xr.coding.cftimeindex.CFTimeIndex(times_3d.values.ravel()), field + ).reshape(times_3d.shape), + name=field, + coords=times_3d.coords, + dims=times_3d.dims, + ) + times_3d = times_3d.chunk({"lon": 5, "lat": 5, "time": 50}) + result = getattr(times_3d.dt, field) + assert isinstance(result.data, da.Array) + assert result.chunks == times_3d.chunks + assert_equal(result.compute(), expected) + + +@pytest.fixture() +def cftime_date_type(calendar): + from xarray.tests.test_coding_times import _all_cftime_date_types + + return _all_cftime_date_types()[calendar] + + +@requires_cftime +def test_seasons(cftime_date_type) -> None: + dates = xr.DataArray( + np.array([cftime_date_type(2000, month, 15) for month in range(1, 13)]) + ) + seasons = xr.DataArray( + [ + "DJF", + "DJF", + "MAM", + "MAM", + "MAM", + "JJA", + "JJA", + "JJA", + "SON", + "SON", + "SON", + "DJF", + ] + ) + + assert_array_equal(seasons.values, dates.dt.season.values) + + +@pytest.fixture +def cftime_rounding_dataarray(cftime_date_type): + return xr.DataArray( + [ + [cftime_date_type(1, 1, 1, 1), cftime_date_type(1, 1, 1, 15)], + [cftime_date_type(1, 1, 1, 23), cftime_date_type(1, 1, 2, 1)], + ] + ) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize("use_dask", [False, True]) +def test_cftime_floor_accessor( + cftime_rounding_dataarray, cftime_date_type, use_dask +) -> None: + import dask.array as da + + freq = "D" + expected = xr.DataArray( + [ + [cftime_date_type(1, 1, 1, 0), cftime_date_type(1, 1, 1, 0)], + [cftime_date_type(1, 1, 1, 0), cftime_date_type(1, 1, 2, 0)], + ], + name="floor", + ) + + if use_dask: + chunks = {"dim_0": 1} + # Currently a compute is done to inspect a single value of the array + # if it is of object dtype to check if it is a cftime.datetime (if not + # we raise an error when using the dt accessor). + with raise_if_dask_computes(max_computes=1): + result = cftime_rounding_dataarray.chunk(chunks).dt.floor(freq) + expected = expected.chunk(chunks) + assert isinstance(result.data, da.Array) + assert result.chunks == expected.chunks + else: + result = cftime_rounding_dataarray.dt.floor(freq) + + assert_identical(result, expected) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize("use_dask", [False, True]) +def test_cftime_ceil_accessor( + cftime_rounding_dataarray, cftime_date_type, use_dask +) -> None: + import dask.array as da + + freq = "D" + expected = xr.DataArray( + [ + [cftime_date_type(1, 1, 2, 0), cftime_date_type(1, 1, 2, 0)], + [cftime_date_type(1, 1, 2, 0), cftime_date_type(1, 1, 3, 0)], + ], + name="ceil", + ) + + if use_dask: + chunks = {"dim_0": 1} + # Currently a compute is done to inspect a single value of the array + # if it is of object dtype to check if it is a cftime.datetime (if not + # we raise an error when using the dt accessor). + with raise_if_dask_computes(max_computes=1): + result = cftime_rounding_dataarray.chunk(chunks).dt.ceil(freq) + expected = expected.chunk(chunks) + assert isinstance(result.data, da.Array) + assert result.chunks == expected.chunks + else: + result = cftime_rounding_dataarray.dt.ceil(freq) + + assert_identical(result, expected) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize("use_dask", [False, True]) +def test_cftime_round_accessor( + cftime_rounding_dataarray, cftime_date_type, use_dask +) -> None: + import dask.array as da + + freq = "D" + expected = xr.DataArray( + [ + [cftime_date_type(1, 1, 1, 0), cftime_date_type(1, 1, 2, 0)], + [cftime_date_type(1, 1, 2, 0), cftime_date_type(1, 1, 2, 0)], + ], + name="round", + ) + + if use_dask: + chunks = {"dim_0": 1} + # Currently a compute is done to inspect a single value of the array + # if it is of object dtype to check if it is a cftime.datetime (if not + # we raise an error when using the dt accessor). + with raise_if_dask_computes(max_computes=1): + result = cftime_rounding_dataarray.chunk(chunks).dt.round(freq) + expected = expected.chunk(chunks) + assert isinstance(result.data, da.Array) + assert result.chunks == expected.chunks + else: + result = cftime_rounding_dataarray.dt.round(freq) + + assert_identical(result, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_accessor_str.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_accessor_str.py new file mode 100644 index 0000000..e0c9619 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_accessor_str.py @@ -0,0 +1,3705 @@ +# Tests for the `str` accessor are derived from the original +# pandas string accessor tests. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import re +from typing import Callable + +import numpy as np +import pytest + +import xarray as xr +from xarray.tests import assert_equal, assert_identical, requires_dask + + +@pytest.fixture( + params=[pytest.param(np.str_, id="str"), pytest.param(np.bytes_, id="bytes")] +) +def dtype(request): + return request.param + + +@requires_dask +def test_dask() -> None: + import dask.array as da + + arr = da.from_array(["a", "b", "c"], chunks=-1) + xarr = xr.DataArray(arr) + + result = xarr.str.len().compute() + expected = xr.DataArray([1, 1, 1]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_count(dtype) -> None: + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + pat_str = dtype(r"f[o]+") + pat_re = re.compile(pat_str) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + + expected = xr.DataArray([1, 2, 4]) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) + + +def test_count_broadcast(dtype) -> None: + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + pat_str = np.array([r"f[o]+", r"o", r"m"]).astype(dtype) + pat_re = np.array([re.compile(x) for x in pat_str]) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + + expected = xr.DataArray([1, 4, 3]) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) + + +def test_contains(dtype) -> None: + values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"]).astype(dtype) + + # case insensitive using regex + pat = values.dtype.type("FOO|mmm") + result = values.str.contains(pat, case=False) + expected = xr.DataArray([True, False, True, True]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + pat = values.dtype.type("Foo|mMm") + result = values.str.contains(pat) + expected = xr.DataArray([True, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive without regex + result = values.str.contains("foo", regex=False, case=False) + expected = xr.DataArray([True, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive without regex + result = values.str.contains("fO", regex=False, case=True) + expected = xr.DataArray([False, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # regex regex=False + pat_re = re.compile("(/w+)") + with pytest.raises( + ValueError, + match="Must use regular expression matching for regular expression object.", + ): + values.str.contains(pat_re, regex=False) + + +def test_contains_broadcast(dtype) -> None: + values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dims="X").astype( + dtype + ) + pat_str = xr.DataArray(["FOO|mmm", "Foo", "MMM"], dims="Y").astype(dtype) + pat_re = xr.DataArray([re.compile(x) for x in pat_str.data], dims="Y") + + # case insensitive using regex + result = values.str.contains(pat_str, case=False) + expected = xr.DataArray( + [ + [True, True, False], + [False, False, False], + [True, True, True], + [True, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + result = values.str.contains(pat_str) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(pat_re) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive without regex + result = values.str.contains(pat_str, regex=False, case=False) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, True, True], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive with regex + result = values.str.contains(pat_str, regex=False, case=True) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_starts_ends_with(dtype) -> None: + values = xr.DataArray(["om", "foo_nom", "nom", "bar_foo", "foo"]).astype(dtype) + + result = values.str.startswith("foo") + expected = xr.DataArray([False, True, False, False, True]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.endswith("foo") + expected = xr.DataArray([False, False, False, True, True]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_starts_ends_with_broadcast(dtype) -> None: + values = xr.DataArray( + ["om", "foo_nom", "nom", "bar_foo", "foo_bar"], dims="X" + ).astype(dtype) + pat = xr.DataArray(["foo", "bar"], dims="Y").astype(dtype) + + result = values.str.startswith(pat) + expected = xr.DataArray( + [[False, False], [True, False], [False, False], [False, True], [True, False]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.endswith(pat) + expected = xr.DataArray( + [[False, False], [False, False], [False, False], [True, False], [False, True]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_case_bytes() -> None: + value = xr.DataArray(["SOme wOrd"]).astype(np.bytes_) + + exp_capitalized = xr.DataArray(["Some word"]).astype(np.bytes_) + exp_lowered = xr.DataArray(["some word"]).astype(np.bytes_) + exp_swapped = xr.DataArray(["soME WoRD"]).astype(np.bytes_) + exp_titled = xr.DataArray(["Some Word"]).astype(np.bytes_) + exp_uppered = xr.DataArray(["SOME WORD"]).astype(np.bytes_) + + res_capitalized = value.str.capitalize() + res_lowered = value.str.lower() + res_swapped = value.str.swapcase() + res_titled = value.str.title() + res_uppered = value.str.upper() + + assert res_capitalized.dtype == exp_capitalized.dtype + assert res_lowered.dtype == exp_lowered.dtype + assert res_swapped.dtype == exp_swapped.dtype + assert res_titled.dtype == exp_titled.dtype + assert res_uppered.dtype == exp_uppered.dtype + + assert_equal(res_capitalized, exp_capitalized) + assert_equal(res_lowered, exp_lowered) + assert_equal(res_swapped, exp_swapped) + assert_equal(res_titled, exp_titled) + assert_equal(res_uppered, exp_uppered) + + +def test_case_str() -> None: + # This string includes some unicode characters + # that are common case management corner cases + value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + + exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.str_) + exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.str_) + exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.str_) + exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.str_) + exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype(np.str_) + + exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.str_) + exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.str_) + + res_capitalized = value.str.capitalize() + res_casefolded = value.str.casefold() + res_lowered = value.str.lower() + res_swapped = value.str.swapcase() + res_titled = value.str.title() + res_uppered = value.str.upper() + + res_norm_nfc = value.str.normalize("NFC") + res_norm_nfd = value.str.normalize("NFD") + res_norm_nfkc = value.str.normalize("NFKC") + res_norm_nfkd = value.str.normalize("NFKD") + + assert res_capitalized.dtype == exp_capitalized.dtype + assert res_casefolded.dtype == exp_casefolded.dtype + assert res_lowered.dtype == exp_lowered.dtype + assert res_swapped.dtype == exp_swapped.dtype + assert res_titled.dtype == exp_titled.dtype + assert res_uppered.dtype == exp_uppered.dtype + + assert res_norm_nfc.dtype == exp_norm_nfc.dtype + assert res_norm_nfd.dtype == exp_norm_nfd.dtype + assert res_norm_nfkc.dtype == exp_norm_nfkc.dtype + assert res_norm_nfkd.dtype == exp_norm_nfkd.dtype + + assert_equal(res_capitalized, exp_capitalized) + assert_equal(res_casefolded, exp_casefolded) + assert_equal(res_lowered, exp_lowered) + assert_equal(res_swapped, exp_swapped) + assert_equal(res_titled, exp_titled) + assert_equal(res_uppered, exp_uppered) + + assert_equal(res_norm_nfc, exp_norm_nfc) + assert_equal(res_norm_nfd, exp_norm_nfd) + assert_equal(res_norm_nfkc, exp_norm_nfkc) + assert_equal(res_norm_nfkd, exp_norm_nfkd) + + +def test_replace(dtype) -> None: + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) + result = values.str.replace("BAD[_]*", "") + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.replace("BAD[_]*", "", n=1) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + pat = xr.DataArray(["BAD[_]*", "AD[_]*"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + values = xr.DataArray( + ["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"] + ).astype(dtype) + expected = xr.DataArray( + ["YYY", "B", "C", "YYYaba", "Baca", "", "CYYYBYYY", "dog", "cat"] + ).astype(dtype) + result = values.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.replace("A", "YYY", regex=False) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.replace("A", "YYY", case=False) + expected = xr.DataArray( + ["YYY", "B", "C", "YYYYYYbYYY", "BYYYcYYY", "", "CYYYBYYY", "dog", "cYYYt"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.replace("^.a|dog", "XX-XX ", case=False) + expected = xr.DataArray( + ["A", "B", "C", "XX-XX ba", "XX-XX ca", "", "XX-XX BA", "XX-XX ", "XX-XX t"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_replace_callable() -> None: + values = xr.DataArray(["fooBAD__barBAD"]) + + # test with callable + repl = lambda m: m.group(0).swapcase() + result = values.str.replace("[a-z][A-Z]{2}", repl, n=2) + exp = xr.DataArray(["foObaD__baRbaD"]) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + # test regex named groups + values = xr.DataArray(["Foo Bar Baz"]) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = lambda m: m.group("middle").swapcase() + result = values.str.replace(pat, repl) + exp = xr.DataArray(["bAR"]) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + # test broadcast + values = xr.DataArray(["Foo Bar Baz"], dims=["x"]) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl2 = xr.DataArray( + [ + lambda m: m.group("first").swapcase(), + lambda m: m.group("middle").swapcase(), + lambda m: m.group("last").swapcase(), + ], + dims=["Y"], + ) + result = values.str.replace(pat, repl2) + exp = xr.DataArray([["fOO", "bAR", "bAZ"]], dims=["x", "Y"]) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + +def test_replace_unicode() -> None: + # flags + unicode + values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")]) + expected = xr.DataArray([b"abcd, \xc3\xa0".decode("utf-8")]) + pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) + result = values.str.replace(pat, ", ") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # broadcast version + values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")], dims=["X"]) + expected = xr.DataArray( + [[b"abcd, \xc3\xa0".decode("utf-8"), b"BAcd,\xc3\xa0".decode("utf-8")]], + dims=["X", "Y"], + ) + pat2 = xr.DataArray( + [re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE), r"ab"], dims=["Y"] + ) + repl = xr.DataArray([", ", "BA"], dims=["Y"]) + result = values.str.replace(pat2, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_replace_compiled_regex(dtype) -> None: + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) + + # test with compiled regex + pat = re.compile(dtype("BAD[_]*")) + result = values.str.replace(pat, "") + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.replace(pat, "", n=1) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # broadcast + pat2 = xr.DataArray( + [re.compile(dtype("BAD[_]*")), re.compile(dtype("AD[_]*"))], dims=["y"] + ) + result = values.str.replace(pat2, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat2, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case and flags provided to str.replace will have no effect + # and will produce warnings + values = xr.DataArray(["fooBAD__barBAD__bad"]).astype(dtype) + pat3 = re.compile(dtype("BAD[_]*")) + + with pytest.raises( + ValueError, match="Flags cannot be set when pat is a compiled regex." + ): + result = values.str.replace(pat3, "", flags=re.IGNORECASE) + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + result = values.str.replace(pat3, "", case=False) + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + result = values.str.replace(pat3, "", case=True) + + # test with callable + values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + repl2 = lambda m: m.group(0).swapcase() + pat4 = re.compile(dtype("[a-z][A-Z]{2}")) + result = values.str.replace(pat4, repl2, n=2) + expected = xr.DataArray(["foObaD__baRbaD"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_replace_literal(dtype) -> None: + # GH16808 literal replace (regex=False vs regex=True) + values = xr.DataArray(["f.o", "foo"], dims=["X"]).astype(dtype) + expected = xr.DataArray(["bao", "bao"], dims=["X"]).astype(dtype) + result = values.str.replace("f.", "ba") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = xr.DataArray(["bao", "foo"], dims=["X"]).astype(dtype) + result = values.str.replace("f.", "ba", regex=False) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Broadcast + pat = xr.DataArray(["f.", ".o"], dims=["yy"]).astype(dtype) + expected = xr.DataArray([["bao", "fba"], ["bao", "bao"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = xr.DataArray([["bao", "fba"], ["foo", "foo"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba", regex=False) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Cannot do a literal replace if given a callable repl or compiled + # pattern + callable_repl = lambda m: m.group(0).swapcase() + compiled_pat = re.compile("[a-z][A-Z]{2}") + + msg = "Cannot use a callable replacement when regex=False" + with pytest.raises(ValueError, match=msg): + values.str.replace("abc", callable_repl, regex=False) + + msg = "Cannot use a compiled regex as replacement pattern with regex=False" + with pytest.raises(ValueError, match=msg): + values.str.replace(compiled_pat, "", regex=False) + + +def test_extract_extractall_findall_empty_raises(dtype) -> None: + pat_str = dtype(r".*") + pat_re = re.compile(pat_str) + + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) + + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extract(pat=pat_str, dim="ZZ") + + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extract(pat=pat_re, dim="ZZ") + + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.findall(pat=pat_str) + + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.findall(pat=pat_re) + + +def test_extract_multi_None_raises(dtype) -> None: + pat_str = r"(\w+)_(\d+)" + pat_re = re.compile(pat_str) + + value = xr.DataArray([["a_b"]], dims=["X", "Y"]).astype(dtype) + + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): + value.str.extract(pat=pat_str, dim=None) + + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): + value.str.extract(pat=pat_re, dim=None) + + +def test_extract_extractall_findall_case_re_raises(dtype) -> None: + pat_str = r".*" + pat_re = re.compile(pat_str) + + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extract(pat=pat_re, case=True, dim="ZZ") + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extract(pat=pat_re, case=False, dim="ZZ") + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extractall(pat=pat_re, case=True, group_dim="XX", match_dim="YY") + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extractall(pat=pat_re, case=False, group_dim="XX", match_dim="YY") + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.findall(pat=pat_re, case=True) + + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.findall(pat=pat_re, case=False) + + +def test_extract_extractall_name_collision_raises(dtype) -> None: + pat_str = r"(\w+)" + pat_re = re.compile(pat_str) + + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) + + with pytest.raises(KeyError, match="Dimension 'X' already present in DataArray."): + value.str.extract(pat=pat_str, dim="X") + + with pytest.raises(KeyError, match="Dimension 'X' already present in DataArray."): + value.str.extract(pat=pat_re, dim="X") + + with pytest.raises( + KeyError, match="Group dimension 'X' already present in DataArray." + ): + value.str.extractall(pat=pat_str, group_dim="X", match_dim="ZZ") + + with pytest.raises( + KeyError, match="Group dimension 'X' already present in DataArray." + ): + value.str.extractall(pat=pat_re, group_dim="X", match_dim="YY") + + with pytest.raises( + KeyError, match="Match dimension 'Y' already present in DataArray." + ): + value.str.extractall(pat=pat_str, group_dim="XX", match_dim="Y") + + with pytest.raises( + KeyError, match="Match dimension 'Y' already present in DataArray." + ): + value.str.extractall(pat=pat_re, group_dim="XX", match_dim="Y") + + with pytest.raises( + KeyError, match="Group dimension 'ZZ' is the same as match dimension 'ZZ'." + ): + value.str.extractall(pat=pat_str, group_dim="ZZ", match_dim="ZZ") + + with pytest.raises( + KeyError, match="Group dimension 'ZZ' is the same as match dimension 'ZZ'." + ): + value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") + + +def test_extract_single_case(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ_none = xr.DataArray( + [["a", "bab", "abc"], ["abcd", "", "abcdef"]], dims=["X", "Y"] + ).astype(dtype) + targ_dim = xr.DataArray( + [[["a"], ["bab"], ["abc"]], [["abcd"], [""], ["abcdef"]]], dims=["X", "Y", "XX"] + ).astype(dtype) + + res_str_none = value.str.extract(pat=pat_str, dim=None) + res_str_dim = value.str.extract(pat=pat_str, dim="XX") + res_str_none_case = value.str.extract(pat=pat_str, dim=None, case=True) + res_str_dim_case = value.str.extract(pat=pat_str, dim="XX", case=True) + res_re_none = value.str.extract(pat=pat_compiled, dim=None) + res_re_dim = value.str.extract(pat=pat_compiled, dim="XX") + + assert res_str_none.dtype == targ_none.dtype + assert res_str_dim.dtype == targ_dim.dtype + assert res_str_none_case.dtype == targ_none.dtype + assert res_str_dim_case.dtype == targ_dim.dtype + assert res_re_none.dtype == targ_none.dtype + assert res_re_dim.dtype == targ_dim.dtype + + assert_equal(res_str_none, targ_none) + assert_equal(res_str_dim, targ_dim) + assert_equal(res_str_none_case, targ_none) + assert_equal(res_str_dim_case, targ_dim) + assert_equal(res_re_none, targ_none) + assert_equal(res_re_dim, targ_dim) + + +def test_extract_single_nocase(dtype) -> None: + pat_str = r"(\w+)?_Xy_\d*" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re, flags=re.IGNORECASE) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "_Xy_1", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + targ_none = xr.DataArray( + [["a", "ab", "abc"], ["abcd", "", "abcdef"]], dims=["X", "Y"] + ).astype(dtype) + targ_dim = xr.DataArray( + [[["a"], ["ab"], ["abc"]], [["abcd"], [""], ["abcdef"]]], dims=["X", "Y", "XX"] + ).astype(dtype) + + res_str_none = value.str.extract(pat=pat_str, dim=None, case=False) + res_str_dim = value.str.extract(pat=pat_str, dim="XX", case=False) + res_re_none = value.str.extract(pat=pat_compiled, dim=None) + res_re_dim = value.str.extract(pat=pat_compiled, dim="XX") + + assert res_re_dim.dtype == targ_none.dtype + assert res_str_dim.dtype == targ_dim.dtype + assert res_re_none.dtype == targ_none.dtype + assert res_re_dim.dtype == targ_dim.dtype + + assert_equal(res_str_none, targ_none) + assert_equal(res_str_dim, targ_dim) + assert_equal(res_re_none, targ_none) + assert_equal(res_re_dim, targ_dim) + + +def test_extract_multi_case(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [["a", "0"], ["bab", "110"], ["abc", "01"]], + [["abcd", ""], ["", ""], ["abcdef", "101"]], + ], + dims=["X", "Y", "XX"], + ).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="XX") + res_re = value.str.extract(pat=pat_compiled, dim="XX") + res_str_case = value.str.extract(pat=pat_str, dim="XX", case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_extract_multi_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re, flags=re.IGNORECASE) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [["a", "0"], ["ab", "10"], ["abc", "01"]], + [["abcd", ""], ["", ""], ["abcdef", "101"]], + ], + dims=["X", "Y", "XX"], + ).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="XX", case=False) + res_re = value.str.extract(pat=pat_compiled, dim="XX") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extract_broadcast(dtype) -> None: + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_compiled = value.str._re_compile(pat=pat_str) + + expected_list = [ + [["a", "0"], ["", ""]], + [["", ""], ["ab", "10"]], + [["abc", "01"], ["", ""]], + ] + expected = xr.DataArray(expected_list, dims=["X", "Y", "Zz"]).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="Zz") + res_re = value.str.extract(pat=pat_compiled, dim="Zz") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_single_single_case(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [[[["a"]], [[""]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_extractall_single_single_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re, flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [[[["a"]], [["ab"]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_single_multi_case(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [[["a"], [""], [""]], [["bab"], ["baab"], [""]], [["abc"], ["cbc"], [""]]], + [ + [["abcd"], ["dcd"], ["dccd"]], + [[""], [""], [""]], + [["abcdef"], ["fef"], [""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_extractall_single_multi_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re, flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [ + [["a"], [""], [""]], + [["ab"], ["bab"], ["baab"]], + [["abc"], ["cbc"], [""]], + ], + [ + [["abcd"], ["dcd"], ["dccd"]], + [[""], [""], [""]], + [["abcdef"], ["fef"], [""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_multi_single_case(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [[["a", "0"]], [["", ""]], [["abc", "01"]]], + [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_extractall_multi_single_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re, flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], + [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_multi_multi_case(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [ + [["a", "0"], ["", ""], ["", ""]], + [["bab", "110"], ["baab", "1100"], ["", ""]], + [["abc", "01"], ["cbc", "2210"], ["", ""]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [["", ""], ["", ""], ["", ""]], + [["abcdef", "101"], ["fef", "5543210"], ["", ""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_extractall_multi_multi_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re: str | bytes = ( + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") + ) + pat_compiled = re.compile(pat_re, flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [ + [["a", "0"], ["", ""], ["", ""]], + [["ab", "10"], ["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"], ["", ""]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [["", ""], ["", ""], ["", ""]], + [["abcdef", "101"], ["fef", "5543210"], ["", ""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_compiled, group_dim="XX", match_dim="YY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_broadcast(dtype) -> None: + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected_list = [ + [[["a", "0"]], [["", ""]]], + [[["", ""]], [["ab", "10"]]], + [[["abc", "01"]], [["", ""]]], + ] + expected = xr.DataArray(expected_list, dims=["X", "Y", "ZX", "ZY"]).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="ZX", match_dim="ZY") + res_re = value.str.extractall(pat=pat_re, group_dim="ZX", match_dim="ZY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_single_single_case(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list]] = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected_list] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_single_single_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list]] = [ + [["a"], ["ab"], ["abc"]], + [["abcd"], [], ["abcdef"]], + ] + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected_list] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_single_multi_case(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list]] = [ + [["a"], ["bab", "baab"], ["abc", "cbc"]], + [ + ["abcd", "dcd", "dccd"], + [], + ["abcdef", "fef"], + ], + ] + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected_list] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_single_multi_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list]] = [ + [ + ["a"], + ["ab", "bab", "baab"], + ["abc", "cbc"], + ], + [ + ["abcd", "dcd", "dccd"], + [], + ["abcdef", "fef"], + ], + ] + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected_list] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_multi_single_case(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list[list]]] = [ + [[["a", "0"]], [], [["abc", "01"]]], + [[["abcd", ""]], [], [["abcdef", "101"]]], + ] + expected_dtype = [ + [[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected_list + ] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_multi_single_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list[list]]] = [ + [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], + [[["abcd", ""]], [], [["abcdef", "101"]]], + ] + expected_dtype = [ + [[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected_list + ] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_multi_multi_case(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list[list]]] = [ + [ + [["a", "0"]], + [["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [], + [["abcdef", "101"], ["fef", "5543210"]], + ], + ] + expected_dtype = [ + [[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected_list + ] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_multi_multi_nocase(dtype) -> None: + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_list: list[list[list[list]]] = [ + [ + [["a", "0"]], + [["ab", "10"], ["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [], + [["abcdef", "101"], ["fef", "5543210"]], + ], + ] + expected_dtype = [ + [[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected_list + ] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_broadcast(dtype) -> None: + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_\d*", r"\w+_Xy_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected_list: list[list[list]] = [[["a"], ["0"]], [[], []], [["abc"], ["01"]]] + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected_list] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected = xr.DataArray(expected_np, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_repeat(dtype) -> None: + values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) + + result = values.str.repeat(3) + result_mul = values.str * 3 + + expected = xr.DataArray(["aaa", "bbb", "ccc", "ddd"]).astype(dtype) + + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) + assert_equal(result, expected) + + +def test_repeat_broadcast(dtype) -> None: + values = xr.DataArray(["a", "b", "c", "d"], dims=["X"]).astype(dtype) + reps = xr.DataArray([3, 4], dims=["Y"]) + + result = values.str.repeat(reps) + result_mul = values.str * reps + + expected = xr.DataArray( + [["aaa", "aaaa"], ["bbb", "bbbb"], ["ccc", "cccc"], ["ddd", "dddd"]], + dims=["X", "Y"], + ).astype(dtype) + + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) + assert_equal(result, expected) + + +def test_match(dtype) -> None: + values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) + + # New match behavior introduced in 0.13 + pat = values.dtype.type(".*(BAD[_]+).*(BAD)") + result = values.str.match(pat) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Case-sensitive + pat = values.dtype.type(".*BAD[_]+.*BAD") + result = values.str.match(pat) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Case-insensitive + pat = values.dtype.type(".*bAd[_]+.*bad") + result = values.str.match(pat, case=False) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_empty_str_methods() -> None: + empty = xr.DataArray(np.empty(shape=(0,), dtype="U")) + empty_str = empty + empty_int = xr.DataArray(np.empty(shape=(0,), dtype=int)) + empty_bool = xr.DataArray(np.empty(shape=(0,), dtype=bool)) + empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype="S")) + + # TODO: Determine why U and S dtype sizes don't match and figure + # out a reliable way to predict what they should be + + assert empty_bool.dtype == empty.str.contains("a").dtype + assert empty_bool.dtype == empty.str.endswith("a").dtype + assert empty_bool.dtype == empty.str.match("^a").dtype + assert empty_bool.dtype == empty.str.startswith("a").dtype + assert empty_bool.dtype == empty.str.isalnum().dtype + assert empty_bool.dtype == empty.str.isalpha().dtype + assert empty_bool.dtype == empty.str.isdecimal().dtype + assert empty_bool.dtype == empty.str.isdigit().dtype + assert empty_bool.dtype == empty.str.islower().dtype + assert empty_bool.dtype == empty.str.isnumeric().dtype + assert empty_bool.dtype == empty.str.isspace().dtype + assert empty_bool.dtype == empty.str.istitle().dtype + assert empty_bool.dtype == empty.str.isupper().dtype + assert empty_bytes.dtype.kind == empty.str.encode("ascii").dtype.kind + assert empty_int.dtype.kind == empty.str.count("a").dtype.kind + assert empty_int.dtype.kind == empty.str.find("a").dtype.kind + assert empty_int.dtype.kind == empty.str.len().dtype.kind + assert empty_int.dtype.kind == empty.str.rfind("a").dtype.kind + assert empty_str.dtype.kind == empty.str.capitalize().dtype.kind + assert empty_str.dtype.kind == empty.str.center(42).dtype.kind + assert empty_str.dtype.kind == empty.str.get(0).dtype.kind + assert empty_str.dtype.kind == empty.str.lower().dtype.kind + assert empty_str.dtype.kind == empty.str.lstrip().dtype.kind + assert empty_str.dtype.kind == empty.str.pad(42).dtype.kind + assert empty_str.dtype.kind == empty.str.repeat(3).dtype.kind + assert empty_str.dtype.kind == empty.str.rstrip().dtype.kind + assert empty_str.dtype.kind == empty.str.slice(step=1).dtype.kind + assert empty_str.dtype.kind == empty.str.slice(stop=1).dtype.kind + assert empty_str.dtype.kind == empty.str.strip().dtype.kind + assert empty_str.dtype.kind == empty.str.swapcase().dtype.kind + assert empty_str.dtype.kind == empty.str.title().dtype.kind + assert empty_str.dtype.kind == empty.str.upper().dtype.kind + assert empty_str.dtype.kind == empty.str.wrap(42).dtype.kind + assert empty_str.dtype.kind == empty_bytes.str.decode("ascii").dtype.kind + + assert_equal(empty_bool, empty.str.contains("a")) + assert_equal(empty_bool, empty.str.endswith("a")) + assert_equal(empty_bool, empty.str.match("^a")) + assert_equal(empty_bool, empty.str.startswith("a")) + assert_equal(empty_bool, empty.str.isalnum()) + assert_equal(empty_bool, empty.str.isalpha()) + assert_equal(empty_bool, empty.str.isdecimal()) + assert_equal(empty_bool, empty.str.isdigit()) + assert_equal(empty_bool, empty.str.islower()) + assert_equal(empty_bool, empty.str.isnumeric()) + assert_equal(empty_bool, empty.str.isspace()) + assert_equal(empty_bool, empty.str.istitle()) + assert_equal(empty_bool, empty.str.isupper()) + assert_equal(empty_bytes, empty.str.encode("ascii")) + assert_equal(empty_int, empty.str.count("a")) + assert_equal(empty_int, empty.str.find("a")) + assert_equal(empty_int, empty.str.len()) + assert_equal(empty_int, empty.str.rfind("a")) + assert_equal(empty_str, empty.str.capitalize()) + assert_equal(empty_str, empty.str.center(42)) + assert_equal(empty_str, empty.str.get(0)) + assert_equal(empty_str, empty.str.lower()) + assert_equal(empty_str, empty.str.lstrip()) + assert_equal(empty_str, empty.str.pad(42)) + assert_equal(empty_str, empty.str.repeat(3)) + assert_equal(empty_str, empty.str.replace("a", "b")) + assert_equal(empty_str, empty.str.rstrip()) + assert_equal(empty_str, empty.str.slice(step=1)) + assert_equal(empty_str, empty.str.slice(stop=1)) + assert_equal(empty_str, empty.str.strip()) + assert_equal(empty_str, empty.str.swapcase()) + assert_equal(empty_str, empty.str.title()) + assert_equal(empty_str, empty.str.upper()) + assert_equal(empty_str, empty.str.wrap(42)) + assert_equal(empty_str, empty_bytes.str.decode("ascii")) + + table = str.maketrans("a", "b") + assert empty_str.dtype.kind == empty.str.translate(table).dtype.kind + assert_equal(empty_str, empty.str.translate(table)) + + +@pytest.mark.parametrize( + ["func", "expected"], + [ + pytest.param( + lambda x: x.str.isalnum(), + [True, True, True, True, True, False, True, True, False, False], + id="isalnum", + ), + pytest.param( + lambda x: x.str.isalpha(), + [True, True, True, False, False, False, True, False, False, False], + id="isalpha", + ), + pytest.param( + lambda x: x.str.isdigit(), + [False, False, False, True, False, False, False, True, False, False], + id="isdigit", + ), + pytest.param( + lambda x: x.str.islower(), + [False, True, False, False, False, False, False, False, False, False], + id="islower", + ), + pytest.param( + lambda x: x.str.isspace(), + [False, False, False, False, False, False, False, False, False, True], + id="isspace", + ), + pytest.param( + lambda x: x.str.istitle(), + [True, False, True, False, True, False, False, False, False, False], + id="istitle", + ), + pytest.param( + lambda x: x.str.isupper(), + [True, False, False, False, True, False, True, False, False, False], + id="isupper", + ), + ], +) +def test_ismethods( + dtype, func: Callable[[xr.DataArray], xr.DataArray], expected: list[bool] +) -> None: + values = xr.DataArray( + ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "] + ).astype(dtype) + + expected_da = xr.DataArray(expected) + actual = func(values) + + assert actual.dtype == expected_da.dtype + assert_equal(actual, expected_da) + + +def test_isnumeric() -> None: + # 0x00bc: ¼ VULGAR FRACTION ONE QUARTER + # 0x2605: ★ not number + # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY + # 0xFF13: 3 Em 3 + values = xr.DataArray(["A", "3", "¼", "★", "፸", "3", "four"]) + exp_numeric = xr.DataArray([False, True, True, False, True, True, False]) + exp_decimal = xr.DataArray([False, True, False, False, False, True, False]) + + res_numeric = values.str.isnumeric() + res_decimal = values.str.isdecimal() + + assert res_numeric.dtype == exp_numeric.dtype + assert res_decimal.dtype == exp_decimal.dtype + + assert_equal(res_numeric, exp_numeric) + assert_equal(res_decimal, exp_decimal) + + +def test_len(dtype) -> None: + values = ["foo", "fooo", "fooooo", "fooooooo"] + result = xr.DataArray(values).astype(dtype).str.len() + expected = xr.DataArray([len(x) for x in values]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_find(dtype) -> None: + values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) + values = values.astype(dtype) + + result_0 = values.str.find("EF") + result_1 = values.str.find("EF", side="left") + expected_0 = xr.DataArray([4, 3, 1, 0, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF") + result_1 = values.str.find("EF", side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3) + result_1 = values.str.find("EF", 3, side="left") + expected_0 = xr.DataArray([4, 3, 7, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3) + result_1 = values.str.find("EF", 3, side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="left") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="right") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + +def test_find_broadcast(dtype) -> None: + values = xr.DataArray( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"], dims=["X"] + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC", "XX"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 7], dims=["Z"]) + end = xr.DataArray([6, 9], dims=["Z"]) + + result_0 = values.str.find(sub, start, end) + result_1 = values.str.find(sub, start, end, side="left") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [0, -1]], + ], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = values.str.rfind(sub, start, end) + result_1 = values.str.find(sub, start, end, side="right") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[4, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [1, -1]], + ], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + +def test_index(dtype) -> None: + s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) + + result_0 = s.str.index("EF") + result_1 = s.str.index("EF", side="left") + expected = xr.DataArray([4, 3, 1, 0]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.rindex("EF") + result_1 = s.str.index("EF", side="right") + expected = xr.DataArray([4, 5, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.index("EF", 3) + result_1 = s.str.index("EF", 3, side="left") + expected = xr.DataArray([4, 3, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.rindex("EF", 3) + result_1 = s.str.index("EF", 3, side="right") + expected = xr.DataArray([4, 5, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.index("E", 4, 8) + result_1 = s.str.index("E", 4, 8, side="left") + expected = xr.DataArray([4, 5, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.rindex("E", 0, 5) + result_1 = s.str.index("E", 0, 5, side="right") + expected = xr.DataArray([4, 3, 1, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + matchtype = "subsection" if dtype == np.bytes_ else "substring" + with pytest.raises(ValueError, match=f"{matchtype} not found"): + s.str.index("DE") + + +def test_index_broadcast(dtype) -> None: + values = xr.DataArray( + ["ABCDEFGEFDBCA", "BCDEFEFEFDBC", "DEFBCGHIEFBC", "EFGHBCEFBCBCBCEF"], + dims=["X"], + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 6], dims=["Z"]) + end = xr.DataArray([6, 12], dims=["Z"]) + + result_0 = values.str.index(sub, start, end) + result_1 = values.str.index(sub, start, end, side="left") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 8]]], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = values.str.rindex(sub, start, end) + result_1 = values.str.index(sub, start, end, side="right") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 10]]], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + +def test_translate() -> None: + values = xr.DataArray(["abcdefg", "abcc", "cdddfg", "cdefggg"]) + table = str.maketrans("abc", "cde") + result = values.str.translate(table) + expected = xr.DataArray(["cdedefg", "cdee", "edddfg", "edefggg"]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_pad_center_ljust_rjust(dtype) -> None: + values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) + + result = values.str.center(5) + expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="both") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(5) + expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="right") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(5) + expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="left") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_pad_center_ljust_rjust_fillchar(dtype) -> None: + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) + + result = values.str.center(5, fillchar="X") + expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="both", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(5, fillchar="X") + expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="right", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(5, fillchar="X") + expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="left", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # If fillchar is not a charatter, normal str raises TypeError + # 'aaa'.ljust(5, 'XY') + # TypeError: must be char, not str + template = "fillchar must be a character, not {dtype}" + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.center(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.ljust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.rjust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.pad(5, fillchar="XY") + + +def test_pad_center_ljust_rjust_broadcast(dtype) -> None: + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"], dims="X").astype( + dtype + ) + width = xr.DataArray([5, 4], dims="Y") + fillchar = xr.DataArray(["X", "#"], dims="Y").astype(dtype) + + result = values.str.center(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXaXX", "#a##"], + ["XXbbX", "#bb#"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(width, side="both", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["aXXXX", "a###"], + ["bbXXX", "bb##"], + ["ccccX", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="right", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXXXa", "###a"], + ["XXXbb", "##bb"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="left", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_zfill(dtype) -> None: + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) + + result = values.str.zfill(5) + expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.zfill(3) + expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_zfill_broadcast(dtype) -> None: + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) + width = np.array([4, 5, 0, 3, 8]) + + result = values.str.zfill(width) + expected = xr.DataArray(["0001", "00022", "aaa", "333", "00045678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_slice(dtype) -> None: + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + + result = arr.str.slice(2, 5) + exp = xr.DataArray(["foo", "bar", "baz"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + for start, stop, step in [(0, 3, -1), (None, None, -1), (3, 10, 2), (3, 0, -1)]: + try: + result = arr.str[start:stop:step] + expected = xr.DataArray([s[start:stop:step] for s in arr.values]) + assert_equal(result, expected.astype(dtype)) + except IndexError: + print(f"failed on {start}:{stop}:{step}") + raise + + +def test_slice_broadcast(dtype) -> None: + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + start = xr.DataArray([1, 2, 3]) + stop = 5 + + result = arr.str.slice(start=start, stop=stop) + exp = xr.DataArray(["afoo", "bar", "az"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + +def test_slice_replace(dtype) -> None: + da = lambda x: xr.DataArray(x).astype(dtype) + values = da(["short", "a bit longer", "evenlongerthanthat", ""]) + + expected = da(["shrt", "a it longer", "evnlongerthanthat", ""]) + result = values.str.slice_replace(2, 3) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shzrt", "a zit longer", "evznlongerthanthat", "z"]) + result = values.str.slice_replace(2, 3, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) + result = values.str.slice_replace(2, 2, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) + result = values.str.slice_replace(2, 1, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shorz", "a bit longez", "evenlongerthanthaz", "z"]) + result = values.str.slice_replace(-1, None, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["zrt", "zer", "zat", "z"]) + result = values.str.slice_replace(None, -2, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shortz", "a bit znger", "evenlozerthanthat", "z"]) + result = values.str.slice_replace(6, 8, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["zrt", "a zit longer", "evenlongzerthanthat", "z"]) + result = values.str.slice_replace(-10, 3, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_slice_replace_broadcast(dtype) -> None: + values = xr.DataArray(["short", "a bit longer", "evenlongerthanthat", ""]).astype( + dtype + ) + start = 2 + stop = np.array([4, 5, None, 7]) + repl = "test" + + expected = xr.DataArray(["shtestt", "a test longer", "evtest", "test"]).astype( + dtype + ) + result = values.str.slice_replace(start, stop, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip(dtype) -> None: + values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) + + result = values.str.strip() + expected = xr.DataArray(["aa", "bb", "cc"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip() + expected = xr.DataArray(["aa ", "bb \n", "cc "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip() + expected = xr.DataArray([" aa", " bb", "cc"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_args(dtype) -> None: + values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) + + result = values.str.strip("x") + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip("x") + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip("x") + expected = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_broadcast(dtype) -> None: + values = xr.DataArray(["xxABCxx", "yy BNSD", "LDFJH zz"]).astype(dtype) + to_strip = xr.DataArray(["x", "y", "z"]).astype(dtype) + + result = values.str.strip(to_strip) + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip(to_strip) + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH zz"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip(to_strip) + expected = xr.DataArray(["xxABC", "yy BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_wrap() -> None: + # test values are: two words less than width, two words equal to width, + # two words greater than width, one word less than width, one word + # equal to width, one word greater than width, multiple tokens with + # trailing whitespace equal to width + values = xr.DataArray( + [ + "hello world", + "hello world!", + "hello world!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdefa", + "ab ab ab ab ", + "ab ab ab ab a", + "\t", + ] + ) + + # expected values + expected = xr.DataArray( + [ + "hello world", + "hello world!", + "hello\nworld!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdef\na", + "ab ab ab ab", + "ab ab ab ab\na", + "", + ] + ) + + result = values.str.wrap(12, break_long_words=True) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # test with pre and post whitespace (non-unicode), NaN, and non-ascii + # Unicode + values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"]) + expected = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) + result = values.str.wrap(6) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_wrap_kwargs_passed() -> None: + # GH4334 + + values = xr.DataArray(" hello world ") + + result = values.str.wrap(7) + expected = xr.DataArray(" hello\nworld") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.wrap(7, drop_whitespace=False) + expected = xr.DataArray(" hello\n world\n ") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_get(dtype) -> None: + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"]).astype(dtype) + + result = values.str[2] + expected = xr.DataArray(["b", "d", "g"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # bounds testing + values = xr.DataArray(["1_2_3_4_5", "6_7_8_9_10", "11_12"]).astype(dtype) + + # positive index + result = values.str[5] + expected = xr.DataArray(["_", "_", ""]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # negative index + result = values.str[-6] + expected = xr.DataArray(["_", "8", ""]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_get_default(dtype) -> None: + # GH4334 + values = xr.DataArray(["a_b", "c", ""]).astype(dtype) + + result = values.str.get(2, "default") + expected = xr.DataArray(["b", "default", "default"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_get_broadcast(dtype) -> None: + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"], dims=["X"]).astype(dtype) + inds = xr.DataArray([0, 2], dims=["Y"]) + + result = values.str.get(inds) + expected = xr.DataArray( + [["a", "b"], ["c", "d"], ["f", "g"]], dims=["X", "Y"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_encode_decode() -> None: + data = xr.DataArray(["a", "b", "a\xe4"]) + encoded = data.str.encode("utf-8") + decoded = encoded.str.decode("utf-8") + assert data.dtype == decoded.dtype + assert_equal(data, decoded) + + +def test_encode_decode_errors() -> None: + encodeBase = xr.DataArray(["a", "b", "a\x9d"]) + + msg = ( + r"'charmap' codec can't encode character '\\x9d' in position 1:" + " character maps to " + ) + with pytest.raises(UnicodeEncodeError, match=msg): + encodeBase.str.encode("cp1252") + + f = lambda x: x.encode("cp1252", "ignore") + result = encodeBase.str.encode("cp1252", "ignore") + expected = xr.DataArray([f(x) for x in encodeBase.values.tolist()]) + + assert result.dtype == expected.dtype + assert_equal(result, expected) + + decodeBase = xr.DataArray([b"a", b"b", b"a\x9d"]) + + msg = ( + "'charmap' codec can't decode byte 0x9d in position 1:" + " character maps to " + ) + with pytest.raises(UnicodeDecodeError, match=msg): + decodeBase.str.decode("cp1252") + + f = lambda x: x.decode("cp1252", "ignore") + result = decodeBase.str.decode("cp1252", "ignore") + expected = xr.DataArray([f(x) for x in decodeBase.values.tolist()]) + + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_partition_whitespace(dtype) -> None: + values = xr.DataArray( + [ + ["abc def", "spam eggs swallow", "red_blue"], + ["test0 test1 test2 test3", "", "abra ka da bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_part_dim_list = [ + [ + ["abc", " ", "def"], + ["spam", " ", "eggs swallow"], + ["red_blue", "", ""], + ], + [ + ["test0", " ", "test1 test2 test3"], + ["", "", ""], + ["abra", " ", "ka da bra"], + ], + ] + + exp_rpart_dim_list = [ + [ + ["abc", " ", "def"], + ["spam eggs", " ", "swallow"], + ["", "", "red_blue"], + ], + [ + ["test0 test1 test2", " ", "test3"], + ["", "", ""], + ["abra ka da", " ", "bra"], + ], + ] + + exp_part_dim = xr.DataArray(exp_part_dim_list, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rpart_dim = xr.DataArray(exp_rpart_dim_list, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + + res_part_dim = values.str.partition(dim="ZZ") + res_rpart_dim = values.str.rpartition(dim="ZZ") + + assert res_part_dim.dtype == exp_part_dim.dtype + assert res_rpart_dim.dtype == exp_rpart_dim.dtype + + assert_equal(res_part_dim, exp_part_dim) + assert_equal(res_rpart_dim, exp_rpart_dim) + + +def test_partition_comma(dtype) -> None: + values = xr.DataArray( + [ + ["abc, def", "spam, eggs, swallow", "red_blue"], + ["test0, test1, test2, test3", "", "abra, ka, da, bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_part_dim_list = [ + [ + ["abc", ", ", "def"], + ["spam", ", ", "eggs, swallow"], + ["red_blue", "", ""], + ], + [ + ["test0", ", ", "test1, test2, test3"], + ["", "", ""], + ["abra", ", ", "ka, da, bra"], + ], + ] + + exp_rpart_dim_list = [ + [ + ["abc", ", ", "def"], + ["spam, eggs", ", ", "swallow"], + ["", "", "red_blue"], + ], + [ + ["test0, test1, test2", ", ", "test3"], + ["", "", ""], + ["abra, ka, da", ", ", "bra"], + ], + ] + + exp_part_dim = xr.DataArray(exp_part_dim_list, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rpart_dim = xr.DataArray(exp_rpart_dim_list, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + + res_part_dim = values.str.partition(sep=", ", dim="ZZ") + res_rpart_dim = values.str.rpartition(sep=", ", dim="ZZ") + + assert res_part_dim.dtype == exp_part_dim.dtype + assert res_rpart_dim.dtype == exp_rpart_dim.dtype + + assert_equal(res_part_dim, exp_part_dim) + assert_equal(res_rpart_dim, exp_rpart_dim) + + +def test_partition_empty(dtype) -> None: + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.partition(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +@pytest.mark.parametrize( + ["func", "expected"], + [ + pytest.param( + lambda x: x.str.split(dim=None), + [ + [["abc", "def"], ["spam", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [], ["abra", "ka", "da", "bra"]], + ], + id="split_full", + ), + pytest.param( + lambda x: x.str.rsplit(dim=None), + [ + [["abc", "def"], ["spam", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [], ["abra", "ka", "da", "bra"]], + ], + id="rsplit_full", + ), + pytest.param( + lambda x: x.str.split(dim=None, maxsplit=1), + [ + [["abc", "def"], ["spam", "eggs\tswallow"], ["red_blue"]], + [["test0", "test1\ntest2\n\ntest3"], [], ["abra", "ka\nda\tbra"]], + ], + id="split_1", + ), + pytest.param( + lambda x: x.str.rsplit(dim=None, maxsplit=1), + [ + [["abc", "def"], ["spam\t\teggs", "swallow"], ["red_blue"]], + [["test0\ntest1\ntest2", "test3"], [], ["abra ka\nda", "bra"]], + ], + id="rsplit_1", + ), + ], +) +def test_split_whitespace_nodim( + dtype, func: Callable[[xr.DataArray], xr.DataArray], expected: xr.DataArray +) -> None: + values = xr.DataArray( + [ + ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected_da = xr.DataArray(expected_np, dims=["X", "Y"]) + + actual = func(values) + + assert actual.dtype == expected_da.dtype + assert_equal(actual, expected_da) + + +@pytest.mark.parametrize( + ["func", "expected"], + [ + pytest.param( + lambda x: x.str.split(dim="ZZ"), + [ + [ + ["abc", "def", "", ""], + ["spam", "eggs", "swallow", ""], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ], + id="split_full", + ), + pytest.param( + lambda x: x.str.rsplit(dim="ZZ"), + [ + [ + ["", "", "abc", "def"], + ["", "spam", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ], + id="rsplit_full", + ), + pytest.param( + lambda x: x.str.split(dim="ZZ", maxsplit=1), + [ + [["abc", "def"], ["spam", "eggs\tswallow"], ["red_blue", ""]], + [["test0", "test1\ntest2\n\ntest3"], ["", ""], ["abra", "ka\nda\tbra"]], + ], + id="split_1", + ), + pytest.param( + lambda x: x.str.rsplit(dim="ZZ", maxsplit=1), + [ + [["abc", "def"], ["spam\t\teggs", "swallow"], ["", "red_blue"]], + [["test0\ntest1\ntest2", "test3"], ["", ""], ["abra ka\nda", "bra"]], + ], + id="rsplit_1", + ), + ], +) +def test_split_whitespace_dim( + dtype, func: Callable[[xr.DataArray], xr.DataArray], expected: xr.DataArray +) -> None: + values = xr.DataArray( + [ + ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected_da = xr.DataArray(expected_np, dims=["X", "Y", "ZZ"]).astype(dtype) + + actual = func(values) + + assert actual.dtype == expected_da.dtype + assert_equal(actual, expected_da) + + +@pytest.mark.parametrize( + ["func", "expected"], + [ + pytest.param( + lambda x: x.str.split(sep=",", dim=None), + [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [ + ["test0", "test1", "test2", "test3"], + [""], + ["abra", "ka", "da", "bra"], + ], + ], + id="split_full", + ), + pytest.param( + lambda x: x.str.rsplit(sep=",", dim=None), + [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [ + ["test0", "test1", "test2", "test3"], + [""], + ["abra", "ka", "da", "bra"], + ], + ], + id="rsplit_full", + ), + pytest.param( + lambda x: x.str.split(sep=",", dim=None, maxsplit=1), + [ + [["abc", "def"], ["spam", ",eggs,swallow"], ["red_blue"]], + [["test0", "test1,test2,test3"], [""], ["abra", "ka,da,bra"]], + ], + id="split_1", + ), + pytest.param( + lambda x: x.str.rsplit(sep=",", dim=None, maxsplit=1), + [ + [["abc", "def"], ["spam,,eggs", "swallow"], ["red_blue"]], + [["test0,test1,test2", "test3"], [""], ["abra,ka,da", "bra"]], + ], + id="rsplit_1", + ), + pytest.param( + lambda x: x.str.split(sep=",", dim=None, maxsplit=10), + [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [ + ["test0", "test1", "test2", "test3"], + [""], + ["abra", "ka", "da", "bra"], + ], + ], + id="split_10", + ), + pytest.param( + lambda x: x.str.rsplit(sep=",", dim=None, maxsplit=10), + [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [ + ["test0", "test1", "test2", "test3"], + [""], + ["abra", "ka", "da", "bra"], + ], + ], + id="rsplit_10", + ), + ], +) +def test_split_comma_nodim( + dtype, func: Callable[[xr.DataArray], xr.DataArray], expected: xr.DataArray +) -> None: + values = xr.DataArray( + [ + ["abc,def", "spam,,eggs,swallow", "red_blue"], + ["test0,test1,test2,test3", "", "abra,ka,da,bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected_da = xr.DataArray(expected_np, dims=["X", "Y"]) + + actual = func(values) + + assert actual.dtype == expected_da.dtype + assert_equal(actual, expected_da) + + +@pytest.mark.parametrize( + ["func", "expected"], + [ + pytest.param( + lambda x: x.str.split(sep=",", dim="ZZ"), + [ + [ + ["abc", "def", "", ""], + ["spam", "", "eggs", "swallow"], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ], + id="split_full", + ), + pytest.param( + lambda x: x.str.rsplit(sep=",", dim="ZZ"), + [ + [ + ["", "", "abc", "def"], + ["spam", "", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ], + id="rsplit_full", + ), + pytest.param( + lambda x: x.str.split(sep=",", dim="ZZ", maxsplit=1), + [ + [["abc", "def"], ["spam", ",eggs,swallow"], ["red_blue", ""]], + [["test0", "test1,test2,test3"], ["", ""], ["abra", "ka,da,bra"]], + ], + id="split_1", + ), + pytest.param( + lambda x: x.str.rsplit(sep=",", dim="ZZ", maxsplit=1), + [ + [["abc", "def"], ["spam,,eggs", "swallow"], ["", "red_blue"]], + [["test0,test1,test2", "test3"], ["", ""], ["abra,ka,da", "bra"]], + ], + id="rsplit_1", + ), + pytest.param( + lambda x: x.str.split(sep=",", dim="ZZ", maxsplit=10), + [ + [ + ["abc", "def", "", ""], + ["spam", "", "eggs", "swallow"], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ], + id="split_10", + ), + pytest.param( + lambda x: x.str.rsplit(sep=",", dim="ZZ", maxsplit=10), + [ + [ + ["", "", "abc", "def"], + ["spam", "", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ], + id="rsplit_10", + ), + ], +) +def test_split_comma_dim( + dtype, func: Callable[[xr.DataArray], xr.DataArray], expected: xr.DataArray +) -> None: + values = xr.DataArray( + [ + ["abc,def", "spam,,eggs,swallow", "red_blue"], + ["test0,test1,test2,test3", "", "abra,ka,da,bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected_dtype = [[[dtype(x) for x in y] for y in z] for z in expected] + expected_np = np.array(expected_dtype, dtype=np.object_) + expected_da = xr.DataArray(expected_np, dims=["X", "Y", "ZZ"]).astype(dtype) + + actual = func(values) + + assert actual.dtype == expected_da.dtype + assert_equal(actual, expected_da) + + +def test_splitters_broadcast(dtype) -> None: + values = xr.DataArray( + ["ab cd,de fg", "spam, ,eggs swallow", "red_blue"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ","], + dims=["Y"], + ).astype(dtype) + + expected_left = xr.DataArray( + [ + [["ab", "cd,de fg"], ["ab cd", "de fg"]], + [["spam,", ",eggs swallow"], ["spam", " ,eggs swallow"]], + [["red_blue", ""], ["red_blue", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab cd,de", "fg"], ["ab cd", "de fg"]], + [["spam, ,eggs", "swallow"], ["spam, ", "eggs swallow"]], + [["", "red_blue"], ["", "red_blue"]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.split(dim="ZZ", sep=sep, maxsplit=1) + res_right = values.str.rsplit(dim="ZZ", sep=sep, maxsplit=1) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + expected_left = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.partition(dim="ZZ", sep=sep) + res_right = values.str.partition(dim="ZZ", sep=sep) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + +def test_split_empty(dtype) -> None: + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.split(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_get_dummies(dtype) -> None: + values_line = xr.DataArray( + [["a|ab~abc|abc", "ab", "a||abc|abcd"], ["abcd|ab|a", "abc|ab~abc", "|a"]], + dims=["X", "Y"], + ).astype(dtype) + values_comma = xr.DataArray( + [["a~ab|abc~~abc", "ab", "a~abc~abcd"], ["abcd~ab~a", "abc~ab|abc", "~a"]], + dims=["X", "Y"], + ).astype(dtype) + + vals_line = np.array(["a", "ab", "abc", "abcd", "ab~abc"]).astype(dtype) + vals_comma = np.array(["a", "ab", "abc", "abcd", "ab|abc"]).astype(dtype) + expected_list = [ + [ + [True, False, True, False, True], + [False, True, False, False, False], + [True, False, True, True, False], + ], + [ + [True, True, False, True, False], + [False, False, True, False, True], + [True, False, False, False, False], + ], + ] + expected_np = np.array(expected_list) + expected = xr.DataArray(expected_np, dims=["X", "Y", "ZZ"]) + targ_line = expected.copy() + targ_comma = expected.copy() + targ_line.coords["ZZ"] = vals_line + targ_comma.coords["ZZ"] = vals_comma + + res_default = values_line.str.get_dummies(dim="ZZ") + res_line = values_line.str.get_dummies(dim="ZZ", sep="|") + res_comma = values_comma.str.get_dummies(dim="ZZ", sep="~") + + assert res_default.dtype == targ_line.dtype + assert res_line.dtype == targ_line.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_default, targ_line) + assert_equal(res_line, targ_line) + assert_equal(res_comma, targ_comma) + + +def test_get_dummies_broadcast(dtype) -> None: + values = xr.DataArray( + ["x~x|x~x", "x", "x|x~x", "x~x"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + ["|", "~"], + dims=["Y"], + ).astype(dtype) + + expected_list = [ + [[False, False, True], [True, True, False]], + [[True, False, False], [True, False, False]], + [[True, False, True], [True, True, False]], + [[False, False, True], [True, False, False]], + ] + expected_np = np.array(expected_list) + expected = xr.DataArray(expected_np, dims=["X", "Y", "ZZ"]) + expected.coords["ZZ"] = np.array(["x", "x|x", "x~x"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ", sep=sep) + + assert res.dtype == expected.dtype + + assert_equal(res, expected) + + +def test_get_dummies_empty(dtype) -> None: + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_splitters_empty_str(dtype) -> None: + values = xr.DataArray( + [["", "", ""], ["", "", ""]], + dims=["X", "Y"], + ).astype(dtype) + + targ_partition_dim = xr.DataArray( + [ + [["", "", ""], ["", "", ""], ["", "", ""]], + [["", "", ""], ["", "", ""], ["", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + targ_partition_none_list = [ + [["", "", ""], ["", "", ""], ["", "", ""]], + [["", "", ""], ["", "", ""], ["", "", "", ""]], + ] + targ_partition_none_list = [ + [[dtype(x) for x in y] for y in z] for z in targ_partition_none_list + ] + targ_partition_none_np = np.array(targ_partition_none_list, dtype=np.object_) + del targ_partition_none_np[-1, -1][-1] + targ_partition_none = xr.DataArray( + targ_partition_none_np, + dims=["X", "Y"], + ) + + targ_split_dim = xr.DataArray( + [[[""], [""], [""]], [[""], [""], [""]]], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + targ_split_none = xr.DataArray( + np.array([[[], [], []], [[], [], [""]]], dtype=np.object_), + dims=["X", "Y"], + ) + del targ_split_none.data[-1, -1][-1] + + res_partition_dim = values.str.partition(dim="ZZ") + res_rpartition_dim = values.str.rpartition(dim="ZZ") + res_partition_none = values.str.partition(dim=None) + res_rpartition_none = values.str.rpartition(dim=None) + + res_split_dim = values.str.split(dim="ZZ") + res_rsplit_dim = values.str.rsplit(dim="ZZ") + res_split_none = values.str.split(dim=None) + res_rsplit_none = values.str.rsplit(dim=None) + + res_dummies = values.str.rsplit(dim="ZZ") + + assert res_partition_dim.dtype == targ_partition_dim.dtype + assert res_rpartition_dim.dtype == targ_partition_dim.dtype + assert res_partition_none.dtype == targ_partition_none.dtype + assert res_rpartition_none.dtype == targ_partition_none.dtype + + assert res_split_dim.dtype == targ_split_dim.dtype + assert res_rsplit_dim.dtype == targ_split_dim.dtype + assert res_split_none.dtype == targ_split_none.dtype + assert res_rsplit_none.dtype == targ_split_none.dtype + + assert res_dummies.dtype == targ_split_dim.dtype + + assert_equal(res_partition_dim, targ_partition_dim) + assert_equal(res_rpartition_dim, targ_partition_dim) + assert_equal(res_partition_none, targ_partition_none) + assert_equal(res_rpartition_none, targ_partition_none) + + assert_equal(res_split_dim, targ_split_dim) + assert_equal(res_rsplit_dim, targ_split_dim) + assert_equal(res_split_none, targ_split_none) + assert_equal(res_rsplit_none, targ_split_none) + + assert_equal(res_dummies, targ_split_dim) + + +def test_cat_str(dtype) -> None: + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = "111" + + targ_blank = xr.DataArray( + [["a111", "bb111", "cccc111"], ["ddddd111", "eeee111", "fff111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 111", "bb 111", "cccc 111"], ["ddddd 111", "eeee 111", "fff 111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||111", "bb||111", "cccc||111"], ["ddddd||111", "eeee||111", "fff||111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 111", "bb, 111", "cccc, 111"], ["ddddd, 111", "eeee, 111", "fff, 111"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_uniform(dtype) -> None: + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = xr.DataArray( + [["11111", "222", "33"], ["4", "5555", "66"]], + dims=["X", "Y"], + ) + + targ_blank = xr.DataArray( + [["a11111", "bb222", "cccc33"], ["ddddd4", "eeee5555", "fff66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["ddddd 4", "eeee 5555", "fff 66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["ddddd||4", "eeee||5555", "fff||66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["ddddd, 4", "eeee, 5555", "fff, 66"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_right(dtype) -> None: + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = xr.DataArray( + ["11111", "222", "33"], + dims=["Y"], + ) + + targ_blank = xr.DataArray( + [["a11111", "bb222", "cccc33"], ["ddddd11111", "eeee222", "fff33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["ddddd 11111", "eeee 222", "fff 33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["ddddd||11111", "eeee||222", "fff||33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["ddddd, 11111", "eeee, 222", "fff, 33"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_left(dtype) -> None: + values_1 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + values_2 = xr.DataArray( + [["11111", "222", "33"], ["4", "5555", "66"]], + dims=["X", "Y"], + ) + + targ_blank = ( + xr.DataArray( + [["a11111", "bb222", "cccc33"], ["a4", "bb5555", "cccc66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_space = ( + xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["a 4", "bb 5555", "cccc 66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_bars = ( + xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["a||4", "bb||5555", "cccc||66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_comma = ( + xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["a, 4", "bb, 5555", "cccc, 66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_both(dtype) -> None: + values_1 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + values_2 = xr.DataArray( + ["11111", "4"], + dims=["X"], + ) + + targ_blank = ( + xr.DataArray( + [["a11111", "bb11111", "cccc11111"], ["a4", "bb4", "cccc4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_space = ( + xr.DataArray( + [["a 11111", "bb 11111", "cccc 11111"], ["a 4", "bb 4", "cccc 4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_bars = ( + xr.DataArray( + [["a||11111", "bb||11111", "cccc||11111"], ["a||4", "bb||4", "cccc||4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_comma = ( + xr.DataArray( + [["a, 11111", "bb, 11111", "cccc, 11111"], ["a, 4", "bb, 4", "cccc, 4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_multi() -> None: + values_1 = xr.DataArray( + ["11111", "4"], + dims=["X"], + ) + + values_2 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(np.bytes_) + + values_3 = np.array(3.4) + + values_4 = "" + + values_5 = np.array("", dtype=np.str_) + + sep = xr.DataArray( + [" ", ", "], + dims=["ZZ"], + ).astype(np.str_) + + expected = xr.DataArray( + [ + [ + ["11111 a 3.4 ", "11111, a, 3.4, , "], + ["11111 bb 3.4 ", "11111, bb, 3.4, , "], + ["11111 cccc 3.4 ", "11111, cccc, 3.4, , "], + ], + [ + ["4 a 3.4 ", "4, a, 3.4, , "], + ["4 bb 3.4 ", "4, bb, 3.4, , "], + ["4 cccc 3.4 ", "4, cccc, 3.4, , "], + ], + ], + dims=["X", "Y", "ZZ"], + ).astype(np.str_) + + res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_join_scalar(dtype) -> None: + values = xr.DataArray("aaa").astype(dtype) + + targ = xr.DataArray("aaa").astype(dtype) + + res_blank = values.str.join() + res_space = values.str.join(sep=" ") + + assert res_blank.dtype == targ.dtype + assert res_space.dtype == targ.dtype + + assert_identical(res_blank, targ) + assert_identical(res_space, targ) + + +def test_join_vector(dtype) -> None: + values = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + + targ_blank = xr.DataArray("abbcccc").astype(dtype) + targ_space = xr.DataArray("a bb cccc").astype(dtype) + + res_blank_none = values.str.join() + res_blank_y = values.str.join(dim="Y") + + res_space_none = values.str.join(sep=" ") + res_space_y = values.str.join(dim="Y", sep=" ") + + assert res_blank_none.dtype == targ_blank.dtype + assert res_blank_y.dtype == targ_blank.dtype + assert res_space_none.dtype == targ_space.dtype + assert res_space_y.dtype == targ_space.dtype + + assert_identical(res_blank_none, targ_blank) + assert_identical(res_blank_y, targ_blank) + assert_identical(res_space_none, targ_space) + assert_identical(res_space_y, targ_space) + + +def test_join_2d(dtype) -> None: + values = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_blank_x = xr.DataArray( + ["addddd", "bbeeee", "ccccfff"], + dims=["Y"], + ).astype(dtype) + targ_space_x = xr.DataArray( + ["a ddddd", "bb eeee", "cccc fff"], + dims=["Y"], + ).astype(dtype) + + targ_blank_y = xr.DataArray( + ["abbcccc", "dddddeeeefff"], + dims=["X"], + ).astype(dtype) + targ_space_y = xr.DataArray( + ["a bb cccc", "ddddd eeee fff"], + dims=["X"], + ).astype(dtype) + + res_blank_x = values.str.join(dim="X") + res_blank_y = values.str.join(dim="Y") + + res_space_x = values.str.join(dim="X", sep=" ") + res_space_y = values.str.join(dim="Y", sep=" ") + + assert res_blank_x.dtype == targ_blank_x.dtype + assert res_blank_y.dtype == targ_blank_y.dtype + assert res_space_x.dtype == targ_space_x.dtype + assert res_space_y.dtype == targ_space_y.dtype + + assert_identical(res_blank_x, targ_blank_x) + assert_identical(res_blank_y, targ_blank_y) + assert_identical(res_space_x, targ_space_x) + assert_identical(res_space_y, targ_space_y) + + with pytest.raises( + ValueError, match="Dimension must be specified for multidimensional arrays." + ): + values.str.join() + + +def test_join_broadcast(dtype) -> None: + values = xr.DataArray( + ["a", "bb", "cccc"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ", "], + dims=["ZZ"], + ).astype(dtype) + + expected = xr.DataArray( + ["a bb cccc", "a, bb, cccc"], + dims=["ZZ"], + ).astype(dtype) + + res = values.str.join(sep=sep) + + assert res.dtype == expected.dtype + assert_identical(res, expected) + + +def test_format_scalar() -> None: + values = xr.DataArray( + ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], + dims=["X"], + ).astype(np.str_) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + X = "'test'" + Y = "X" + ZZ = None + W = "NO!" + + expected = xr.DataArray( + ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], + dims=["X"], + ).astype(np.str_) + + res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_format_broadcast() -> None: + values = xr.DataArray( + ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], + dims=["X"], + ).astype(np.str_) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + X = "'test'" + Y = "X" + ZZ = None + W = "NO!" + + expected = xr.DataArray( + [ + ["1.X.None", "1.X.None"], + ["1,1.2,'test','test'", "1,1.2,'test','test'"], + ["'test'-X-None", "'test'-X-None"], + ], + dims=["X", "YY"], + ).astype(np.str_) + + res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_scalar() -> None: + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.str_) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + + expected = xr.DataArray( + ["1.1.2.2.3", "1,1.2,2.3", "1-1.2-2.3"], + dims=["X"], + ).astype(np.str_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_dict() -> None: + values = xr.DataArray( + ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], + dims=["X"], + ).astype(np.str_) + + a = 1 + b = 1.2 + c = "2.3" + + expected = xr.DataArray( + ["1.1.1.2", "1.2,2.3,1.2", "2.3-1.2-1"], + dims=["X"], + ).astype(np.str_) + + res = values.str % {"a": a, "b": b, "c": c} + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_single() -> None: + values = xr.DataArray( + ["%s_1", "%s_2", "%s_3"], + dims=["X"], + ).astype(np.str_) + + pos = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [["2.3_1", "3.44444_1"], ["2.3_2", "3.44444_2"], ["2.3_3", "3.44444_3"]], + dims=["X", "YY"], + ).astype(np.str_) + + res = values.str % pos + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_multi() -> None: + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.str_) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [ + ["1.1.2.2.3", "1.1.2.3.44444"], + ["1,1.2,2.3", "1,1.2,3.44444"], + ["1-1.2-2.3", "1-1.2-3.44444"], + ], + dims=["X", "YY"], + ).astype(np.str_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_array_api.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_array_api.py new file mode 100644 index 0000000..03c77e2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_array_api.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import pytest + +import xarray as xr +from xarray.testing import assert_equal + +np = pytest.importorskip("numpy", minversion="1.22") +xp = pytest.importorskip("array_api_strict") + +from array_api_strict._array_object import Array # isort:skip # type: ignore[no-redef] + + +@pytest.fixture +def arrays() -> tuple[xr.DataArray, xr.DataArray]: + np_arr = xr.DataArray( + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]), + dims=("x", "y"), + coords={"x": [10, 20]}, + ) + xp_arr = xr.DataArray( + xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]), + dims=("x", "y"), + coords={"x": [10, 20]}, + ) + assert isinstance(xp_arr.data, Array) + return np_arr, xp_arr + + +def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr + 7 + actual = xp_arr + 7 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.sum() + actual = xp_arr.sum() + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_aggregation_skipna(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr.sum(skipna=False) + actual = xp_arr.sum(skipna=False) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_astype(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr.astype(np.int64) + actual = xp_arr.astype(xp.int64) + assert actual.dtype == xp.int64 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = xr.broadcast(np_arr, np_arr2) + actual = xr.broadcast(xp_arr, xp_arr2) + assert len(actual) == len(expected) + for a, e in zip(actual, expected): + assert isinstance(a.data, Array) + assert_equal(a, e) + + +def test_broadcast_during_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = np_arr * np_arr2 + actual = xp_arr * xp_arr2 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + expected = np_arr2 * np_arr + actual = xp_arr2 * xp_arr + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = xr.concat((np_arr, np_arr), dim="x") + actual = xr.concat((xp_arr, xp_arr), dim="x") + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr[:, 0] + actual = xp_arr[:, 0] + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_properties(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + + expected = np_arr.data.nbytes + assert np_arr.nbytes == expected + assert xp_arr.nbytes == expected + + +def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.transpose() + actual = xp_arr.transpose() + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.stack(z=("x", "y")) + actual = xp_arr.stack(z=("x", "y")) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.stack(z=("x", "y")).unstack() + actual = xp_arr.stack(z=("x", "y")).unstack() + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_where() -> None: + np_arr = xr.DataArray(np.array([1, 0]), dims="x") + xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x") + expected = xr.where(np_arr, 1, 0) + actual = xr.where(xp_arr, 1, 0) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_assertions.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_assertions.py new file mode 100644 index 0000000..aa0ea46 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_assertions.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import warnings + +import numpy as np +import pytest + +import xarray as xr +from xarray.tests import has_dask + +try: + from dask.array import from_array as dask_from_array +except ImportError: + dask_from_array = lambda x: x # type: ignore + +try: + import pint + + unit_registry = pint.UnitRegistry(force_ndarray_like=True) + + def quantity(x): + return unit_registry.Quantity(x, "m") + + has_pint = True +except ImportError: + + def quantity(x): + return x + + has_pint = False + + +def test_allclose_regression() -> None: + x = xr.DataArray(1.01) + y = xr.DataArray(1.02) + xr.testing.assert_allclose(x, y, atol=0.01) + + +@pytest.mark.parametrize( + "obj1,obj2", + ( + pytest.param( + xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable" + ), + pytest.param( + xr.DataArray([1e-17, 2], dims="x"), + xr.DataArray([0, 3], dims="x"), + id="DataArray", + ), + pytest.param( + xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}), + xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}), + id="Dataset", + ), + ), +) +def test_assert_allclose(obj1, obj2) -> None: + with pytest.raises(AssertionError): + xr.testing.assert_allclose(obj1, obj2) + with pytest.raises(AssertionError): + xr.testing.assert_allclose(obj1, obj2, check_dim_order=False) + + +@pytest.mark.parametrize("func", ["assert_equal", "assert_allclose"]) +def test_assert_allclose_equal_transpose(func) -> None: + """Transposed DataArray raises assertion unless check_dim_order=False.""" + obj1 = xr.DataArray([[0, 1, 2], [2, 3, 4]], dims=["a", "b"]) + obj2 = xr.DataArray([[0, 2], [1, 3], [2, 4]], dims=["b", "a"]) + with pytest.raises(AssertionError): + getattr(xr.testing, func)(obj1, obj2) + getattr(xr.testing, func)(obj1, obj2, check_dim_order=False) + ds1 = obj1.to_dataset(name="varname") + ds1["var2"] = obj1 + ds2 = obj1.to_dataset(name="varname") + ds2["var2"] = obj1.transpose() + with pytest.raises(AssertionError): + getattr(xr.testing, func)(ds1, ds2) + getattr(xr.testing, func)(ds1, ds2, check_dim_order=False) + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize( + "duckarray", + ( + pytest.param(np.array, id="numpy"), + pytest.param( + dask_from_array, + id="dask", + marks=pytest.mark.skipif(not has_dask, reason="requires dask"), + ), + pytest.param( + quantity, + id="pint", + marks=pytest.mark.skipif(not has_pint, reason="requires pint"), + ), + ), +) +@pytest.mark.parametrize( + ["obj1", "obj2"], + ( + pytest.param([1e-10, 2], [0.0, 2.0], id="both arrays"), + pytest.param([1e-17, 2], 0.0, id="second scalar"), + pytest.param(0.0, [1e-17, 2], id="first scalar"), + ), +) +def test_assert_duckarray_equal_failing(duckarray, obj1, obj2) -> None: + # TODO: actually check the repr + a = duckarray(obj1) + b = duckarray(obj2) + with pytest.raises(AssertionError): + xr.testing.assert_duckarray_equal(a, b) + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize( + "duckarray", + ( + pytest.param( + np.array, + id="numpy", + ), + pytest.param( + dask_from_array, + id="dask", + marks=pytest.mark.skipif(not has_dask, reason="requires dask"), + ), + pytest.param( + quantity, + id="pint", + marks=pytest.mark.skipif(not has_pint, reason="requires pint"), + ), + ), +) +@pytest.mark.parametrize( + ["obj1", "obj2"], + ( + pytest.param([0, 2], [0.0, 2.0], id="both arrays"), + pytest.param([0, 0], 0.0, id="second scalar"), + pytest.param(0.0, [0, 0], id="first scalar"), + ), +) +def test_assert_duckarray_equal(duckarray, obj1, obj2) -> None: + a = duckarray(obj1) + b = duckarray(obj2) + + xr.testing.assert_duckarray_equal(a, b) + + +@pytest.mark.parametrize( + "func", + [ + "assert_equal", + "assert_identical", + "assert_allclose", + "assert_duckarray_equal", + "assert_duckarray_allclose", + ], +) +def test_ensure_warnings_not_elevated(func) -> None: + # make sure warnings are not elevated to errors in the assertion functions + # e.g. by @pytest.mark.filterwarnings("error") + # see https://github.com/pydata/xarray/pull/4760#issuecomment-774101639 + + # define a custom Variable class that raises a warning in assert_* + class WarningVariable(xr.Variable): + @property # type: ignore[misc] + def dims(self): + warnings.warn("warning in test") + return super().dims + + def __array__(self, dtype=None, copy=None): + warnings.warn("warning in test") + return super().__array__() + + a = WarningVariable("x", [1]) + b = WarningVariable("x", [2]) + + with warnings.catch_warnings(record=True) as w: + # elevate warnings to errors + warnings.filterwarnings("error") + with pytest.raises(AssertionError): + getattr(xr.testing, func)(a, b) + + assert len(w) > 0 + + # ensure warnings still raise outside of assert_* + with pytest.raises(UserWarning): + warnings.warn("test") + + # ensure warnings stay ignored in assert_* + with warnings.catch_warnings(record=True) as w: + # ignore warnings + warnings.filterwarnings("ignore") + with pytest.raises(AssertionError): + getattr(xr.testing, func)(a, b) + + assert len(w) == 0 diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends.py new file mode 100644 index 0000000..177700a --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends.py @@ -0,0 +1,5950 @@ +from __future__ import annotations + +import contextlib +import gzip +import itertools +import math +import os.path +import pickle +import platform +import re +import shutil +import sys +import tempfile +import uuid +import warnings +from collections.abc import Generator, Iterator, Mapping +from contextlib import ExitStack +from io import BytesIO +from os import listdir +from pathlib import Path +from typing import TYPE_CHECKING, Any, Final, Literal, cast +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest +from packaging.version import Version +from pandas.errors import OutOfBoundsDatetime + +import xarray as xr +from xarray import ( + DataArray, + Dataset, + backends, + load_dataarray, + load_dataset, + open_dataarray, + open_dataset, + open_mfdataset, + save_mfdataset, +) +from xarray.backends.common import robust_getitem +from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint +from xarray.backends.netcdf3 import _nc3_dtype_coercions +from xarray.backends.netCDF4_ import ( + NetCDF4BackendEntrypoint, + _extract_nc4_variable_encoding, +) +from xarray.backends.pydap_ import PydapDataStore +from xarray.backends.scipy_ import ScipyBackendEntrypoint +from xarray.coding.cftime_offsets import cftime_range +from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype +from xarray.coding.variables import SerializationWarning +from xarray.conventions import encode_dataset_coordinates +from xarray.core import indexing +from xarray.core.options import set_options +from xarray.namedarray.pycompat import array_type +from xarray.tests import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + assert_no_warnings, + has_dask, + has_netCDF4, + has_numpy_2, + has_scipy, + mock, + network, + requires_cftime, + requires_dask, + requires_fsspec, + requires_h5netcdf, + requires_h5netcdf_ros3, + requires_iris, + requires_netCDF4, + requires_netCDF4_1_6_2_or_above, + requires_pydap, + requires_scipy, + requires_scipy_or_netCDF4, + requires_zarr, +) +from xarray.tests.test_coding_times import ( + _ALL_CALENDARS, + _NON_STANDARD_CALENDARS, + _STANDARD_CALENDARS, +) +from xarray.tests.test_dataset import ( + create_append_string_length_mismatch_test_data, + create_append_test_data, + create_test_data, +) + +try: + import netCDF4 as nc4 +except ImportError: + pass + +try: + import dask + import dask.array as da +except ImportError: + pass + +have_zarr_kvstore = False +try: + from zarr.storage import KVStore + + have_zarr_kvstore = True +except ImportError: + KVStore = None + +have_zarr_v3 = False +try: + # as of Zarr v2.13 these imports require environment variable + # ZARR_V3_EXPERIMENTAL_API=1 + from zarr import DirectoryStoreV3, KVStoreV3 + + have_zarr_v3 = True +except ImportError: + KVStoreV3 = None + +ON_WINDOWS = sys.platform == "win32" +default_value = object() +dask_array_type = array_type("dask") + +if TYPE_CHECKING: + from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes + + +def open_example_dataset(name, *args, **kwargs) -> Dataset: + return open_dataset( + os.path.join(os.path.dirname(__file__), "data", name), *args, **kwargs + ) + + +def open_example_mfdataset(names, *args, **kwargs) -> Dataset: + return open_mfdataset( + [os.path.join(os.path.dirname(__file__), "data", name) for name in names], + *args, + **kwargs, + ) + + +def create_masked_and_scaled_data(dtype: np.dtype) -> Dataset: + x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=dtype) + encoding = { + "_FillValue": -1, + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + "dtype": "i2", + } + return Dataset({"x": ("t", x, {}, encoding)}) + + +def create_encoded_masked_and_scaled_data(dtype: np.dtype) -> Dataset: + attributes = { + "_FillValue": -1, + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + return Dataset( + {"x": ("t", np.array([-1, -1, 0, 1, 2], dtype=np.int16), attributes)} + ) + + +def create_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: + encoding = { + "_FillValue": 255, + "_Unsigned": "true", + "dtype": "i1", + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=dtype) + return Dataset({"x": ("t", x, {}, encoding)}) + + +def create_encoded_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: + # These are values as written to the file: the _FillValue will + # be represented in the signed form. + attributes = { + "_FillValue": -1, + "_Unsigned": "true", + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned + sb = np.asarray([0, 1, 127, -128, -1], dtype="i1") + return Dataset({"x": ("t", sb, attributes)}) + + +def create_bad_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: + encoding = { + "_FillValue": 255, + "_Unsigned": True, + "dtype": "i1", + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=dtype) + return Dataset({"x": ("t", x, {}, encoding)}) + + +def create_bad_encoded_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: + # These are values as written to the file: the _FillValue will + # be represented in the signed form. + attributes = { + "_FillValue": -1, + "_Unsigned": True, + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned + sb = np.asarray([0, 1, 127, -128, -1], dtype="i1") + return Dataset({"x": ("t", sb, attributes)}) + + +def create_signed_masked_scaled_data(dtype: np.dtype) -> Dataset: + encoding = { + "_FillValue": -127, + "_Unsigned": "false", + "dtype": "i1", + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + x = np.array([-1.0, 10.1, 22.7, np.nan], dtype=dtype) + return Dataset({"x": ("t", x, {}, encoding)}) + + +def create_encoded_signed_masked_scaled_data(dtype: np.dtype) -> Dataset: + # These are values as written to the file: the _FillValue will + # be represented in the signed form. + attributes = { + "_FillValue": -127, + "_Unsigned": "false", + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } + # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned + sb = np.asarray([-110, 1, 127, -127], dtype="i1") + return Dataset({"x": ("t", sb, attributes)}) + + +def create_boolean_data() -> Dataset: + attributes = {"units": "-"} + return Dataset({"x": ("t", [True, False, False, True], attributes)}) + + +class TestCommon: + def test_robust_getitem(self) -> None: + class UnreliableArrayFailure(Exception): + pass + + class UnreliableArray: + def __init__(self, array, failures=1): + self.array = array + self.failures = failures + + def __getitem__(self, key): + if self.failures > 0: + self.failures -= 1 + raise UnreliableArrayFailure + return self.array[key] + + array = UnreliableArray([0]) + with pytest.raises(UnreliableArrayFailure): + array[0] + assert array[0] == 0 + + actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0) + assert actual == 0 + + +class NetCDF3Only: + netcdf3_formats: tuple[T_NetcdfTypes, ...] = ("NETCDF3_CLASSIC", "NETCDF3_64BIT") + + @requires_scipy + def test_dtype_coercion_error(self) -> None: + """Failing dtype coercion should lead to an error""" + for dtype, format in itertools.product( + _nc3_dtype_coercions, self.netcdf3_formats + ): + if dtype == "bool": + # coerced upcast (bool to int8) ==> can never fail + continue + + # Using the largest representable value, create some data that will + # no longer compare equal after the coerced downcast + maxval = np.iinfo(dtype).max + x = np.array([0, 1, 2, maxval], dtype=dtype) + ds = Dataset({"x": ("t", x, {})}) + + with create_tmp_file(allow_cleanup_failure=False) as path: + with pytest.raises(ValueError, match="could not safely cast"): + ds.to_netcdf(path, format=format) + + +class DatasetIOBase: + engine: T_NetcdfEngine | None = None + file_format: T_NetcdfTypes | None = None + + def create_store(self): + raise NotImplementedError() + + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path: + self.save(data, path, **save_kwargs) + with self.open(path, **open_kwargs) as ds: + yield ds + + @contextlib.contextmanager + def roundtrip_append( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path: + for i, key in enumerate(data.variables): + mode = "a" if i > 0 else "w" + self.save(data[[key]], path, mode=mode, **save_kwargs) + with self.open(path, **open_kwargs) as ds: + yield ds + + # The save/open methods may be overwritten below + def save(self, dataset, path, **kwargs): + return dataset.to_netcdf( + path, engine=self.engine, format=self.file_format, **kwargs + ) + + @contextlib.contextmanager + def open(self, path, **kwargs): + with open_dataset(path, engine=self.engine, **kwargs) as ds: + yield ds + + def test_zero_dimensional_variable(self) -> None: + expected = create_test_data() + expected["float_var"] = ([], 1.0e9, {"units": "units of awesome"}) + expected["bytes_var"] = ([], b"foobar") + expected["string_var"] = ([], "foobar") + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + + def test_write_store(self) -> None: + expected = create_test_data() + with self.create_store() as store: + expected.dump_to_store(store) + # we need to cf decode the store because it has time and + # non-dimension coordinates + with xr.decode_cf(store) as actual: + assert_allclose(expected, actual) + + def check_dtypes_roundtripped(self, expected, actual): + for k in expected.variables: + expected_dtype = expected.variables[k].dtype + + # For NetCDF3, the backend should perform dtype coercion + if ( + isinstance(self, NetCDF3Only) + and str(expected_dtype) in _nc3_dtype_coercions + ): + expected_dtype = np.dtype(_nc3_dtype_coercions[str(expected_dtype)]) + + actual_dtype = actual.variables[k].dtype + # TODO: check expected behavior for string dtypes more carefully + string_kinds = {"O", "S", "U"} + assert expected_dtype == actual_dtype or ( + expected_dtype.kind in string_kinds + and actual_dtype.kind in string_kinds + ) + + def test_roundtrip_test_data(self) -> None: + expected = create_test_data() + with self.roundtrip(expected) as actual: + self.check_dtypes_roundtripped(expected, actual) + assert_identical(expected, actual) + + def test_load(self) -> None: + expected = create_test_data() + + @contextlib.contextmanager + def assert_loads(vars=None): + if vars is None: + vars = expected + with self.roundtrip(expected) as actual: + for k, v in actual.variables.items(): + # IndexVariables are eagerly loaded into memory + assert v._in_memory == (k in actual.dims) + yield actual + for k, v in actual.variables.items(): + if k in vars: + assert v._in_memory + assert_identical(expected, actual) + + with pytest.raises(AssertionError): + # make sure the contextmanager works! + with assert_loads() as ds: + pass + + with assert_loads() as ds: + ds.load() + + with assert_loads(["var1", "dim1", "dim2"]) as ds: + ds["var1"].load() + + # verify we can read data even after closing the file + with self.roundtrip(expected) as ds: + actual = ds.load() + assert_identical(expected, actual) + + def test_dataset_compute(self) -> None: + expected = create_test_data() + + with self.roundtrip(expected) as actual: + # Test Dataset.compute() + for k, v in actual.variables.items(): + # IndexVariables are eagerly cached + assert v._in_memory == (k in actual.dims) + + computed = actual.compute() + + for k, v in actual.variables.items(): + assert v._in_memory == (k in actual.dims) + for v in computed.variables.values(): + assert v._in_memory + + assert_identical(expected, actual) + assert_identical(expected, computed) + + def test_pickle(self) -> None: + expected = Dataset({"foo": ("x", [42])}) + with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: + with roundtripped: + # Windows doesn't like reopening an already open file + raw_pickle = pickle.dumps(roundtripped) + with pickle.loads(raw_pickle) as unpickled_ds: + assert_identical(expected, unpickled_ds) + + @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") + def test_pickle_dataarray(self) -> None: + expected = Dataset({"foo": ("x", [42])}) + with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: + with roundtripped: + raw_pickle = pickle.dumps(roundtripped["foo"]) + # TODO: figure out how to explicitly close the file for the + # unpickled DataArray? + unpickled = pickle.loads(raw_pickle) + assert_identical(expected["foo"], unpickled) + + def test_dataset_caching(self) -> None: + expected = Dataset({"foo": ("x", [5, 6, 7])}) + with self.roundtrip(expected) as actual: + assert isinstance(actual.foo.variable._data, indexing.MemoryCachedArray) + assert not actual.foo.variable._in_memory + actual.foo.values # cache + assert actual.foo.variable._in_memory + + with self.roundtrip(expected, open_kwargs={"cache": False}) as actual: + assert isinstance(actual.foo.variable._data, indexing.CopyOnWriteArray) + assert not actual.foo.variable._in_memory + actual.foo.values # no caching + assert not actual.foo.variable._in_memory + + @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") + def test_roundtrip_None_variable(self) -> None: + expected = Dataset({None: (("x", "y"), [[0, 1], [2, 3]])}) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + + def test_roundtrip_object_dtype(self) -> None: + floats = np.array([0.0, 0.0, 1.0, 2.0, 3.0], dtype=object) + floats_nans = np.array([np.nan, np.nan, 1.0, 2.0, 3.0], dtype=object) + bytes_ = np.array([b"ab", b"cdef", b"g"], dtype=object) + bytes_nans = np.array([b"ab", b"cdef", np.nan], dtype=object) + strings = np.array(["ab", "cdef", "g"], dtype=object) + strings_nans = np.array(["ab", "cdef", np.nan], dtype=object) + all_nans = np.array([np.nan, np.nan], dtype=object) + original = Dataset( + { + "floats": ("a", floats), + "floats_nans": ("a", floats_nans), + "bytes": ("b", bytes_), + "bytes_nans": ("b", bytes_nans), + "strings": ("b", strings), + "strings_nans": ("b", strings_nans), + "all_nans": ("c", all_nans), + "nan": ([], np.nan), + } + ) + expected = original.copy(deep=True) + with self.roundtrip(original) as actual: + try: + assert_identical(expected, actual) + except AssertionError: + # Most stores use '' for nans in strings, but some don't. + # First try the ideal case (where the store returns exactly) + # the original Dataset), then try a more realistic case. + # This currently includes all netCDF files when encoding is not + # explicitly set. + # https://github.com/pydata/xarray/issues/1647 + expected["bytes_nans"][-1] = b"" + expected["strings_nans"][-1] = "" + assert_identical(expected, actual) + + def test_roundtrip_string_data(self) -> None: + expected = Dataset({"x": ("t", ["ab", "cdef"])}) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + + def test_roundtrip_string_encoded_characters(self) -> None: + expected = Dataset({"x": ("t", ["ab", "cdef"])}) + expected["x"].encoding["dtype"] = "S1" + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + assert actual["x"].encoding["_Encoding"] == "utf-8" + + expected["x"].encoding["_Encoding"] = "ascii" + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + assert actual["x"].encoding["_Encoding"] == "ascii" + + def test_roundtrip_numpy_datetime_data(self) -> None: + times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"], unit="ns") + expected = Dataset({"t": ("t", times), "t0": times[0]}) + kwargs = {"encoding": {"t0": {"units": "days since 1950-01-01"}}} + with self.roundtrip(expected, save_kwargs=kwargs) as actual: + assert_identical(expected, actual) + assert actual.t0.encoding["units"] == "days since 1950-01-01" + + @requires_cftime + def test_roundtrip_cftime_datetime_data(self) -> None: + from xarray.tests.test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({"t": ("t", times), "t0": times[0]}) + kwargs = {"encoding": {"t0": {"units": "days since 0001-01-01"}}} + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + expected_calendar = times[0].calendar + + with warnings.catch_warnings(): + if expected_calendar in {"proleptic_gregorian", "standard"}: + warnings.filterwarnings("ignore", "Unable to decode time axis") + + with self.roundtrip(expected, save_kwargs=kwargs) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, "s")).all() + assert ( + actual.t.encoding["units"] + == "days since 0001-01-01 00:00:00.000000" + ) + assert actual.t.encoding["calendar"] == expected_calendar + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, "s")).all() + assert actual.t0.encoding["units"] == "days since 0001-01-01" + assert actual.t.encoding["calendar"] == expected_calendar + + def test_roundtrip_timedelta_data(self) -> None: + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) + expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]}) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + + def test_roundtrip_float64_data(self) -> None: + expected = Dataset({"x": ("y", np.array([1.0, 2.0, np.pi], dtype="float64"))}) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + + def test_roundtrip_example_1_netcdf(self) -> None: + with open_example_dataset("example_1.nc") as expected: + with self.roundtrip(expected) as actual: + # we allow the attributes to differ since that + # will depend on the encoding used. For example, + # without CF encoding 'actual' will end up with + # a dtype attribute. + assert_equal(expected, actual) + + def test_roundtrip_coordinates(self) -> None: + original = Dataset( + {"foo": ("x", [0, 1])}, {"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])} + ) + + with self.roundtrip(original) as actual: + assert_identical(original, actual) + + original["foo"].encoding["coordinates"] = "y" + with self.roundtrip(original, open_kwargs={"decode_coords": False}) as expected: + # check roundtripping when decode_coords=False + with self.roundtrip( + expected, open_kwargs={"decode_coords": False} + ) as actual: + assert_identical(expected, actual) + + def test_roundtrip_global_coordinates(self) -> None: + original = Dataset( + {"foo": ("x", [0, 1])}, {"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])} + ) + with self.roundtrip(original) as actual: + assert_identical(original, actual) + + # test that global "coordinates" is as expected + _, attrs = encode_dataset_coordinates(original) + assert attrs["coordinates"] == "y" + + # test warning when global "coordinates" is already set + original.attrs["coordinates"] = "foo" + with pytest.warns(SerializationWarning): + _, attrs = encode_dataset_coordinates(original) + assert attrs["coordinates"] == "foo" + + def test_roundtrip_coordinates_with_space(self) -> None: + original = Dataset(coords={"x": 0, "y z": 1}) + expected = Dataset({"y z": 1}, {"x": 0}) + with pytest.warns(SerializationWarning): + with self.roundtrip(original) as actual: + assert_identical(expected, actual) + + def test_roundtrip_boolean_dtype(self) -> None: + original = create_boolean_data() + assert original["x"].dtype == "bool" + with self.roundtrip(original) as actual: + assert_identical(original, actual) + assert actual["x"].dtype == "bool" + # this checks for preserving dtype during second roundtrip + # see https://github.com/pydata/xarray/issues/7652#issuecomment-1476956975 + with self.roundtrip(actual) as actual2: + assert_identical(original, actual2) + assert actual2["x"].dtype == "bool" + + def test_orthogonal_indexing(self) -> None: + in_memory = create_test_data() + with self.roundtrip(in_memory) as on_disk: + indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)} + expected = in_memory.isel(indexers) + actual = on_disk.isel(**indexers) + # make sure the array is not yet loaded into memory + assert not actual["var1"].variable._in_memory + assert_identical(expected, actual) + # do it twice, to make sure we're switched from orthogonal -> numpy + # when we cached the values + actual = on_disk.isel(**indexers) + assert_identical(expected, actual) + + def test_vectorized_indexing(self) -> None: + in_memory = create_test_data() + with self.roundtrip(in_memory) as on_disk: + indexers = { + "dim1": DataArray([0, 2, 0], dims="a"), + "dim2": DataArray([0, 2, 3], dims="a"), + } + expected = in_memory.isel(indexers) + actual = on_disk.isel(**indexers) + # make sure the array is not yet loaded into memory + assert not actual["var1"].variable._in_memory + assert_identical(expected, actual.load()) + # do it twice, to make sure we're switched from + # vectorized -> numpy when we cached the values + actual = on_disk.isel(**indexers) + assert_identical(expected, actual) + + def multiple_indexing(indexers): + # make sure a sequence of lazy indexings certainly works. + with self.roundtrip(in_memory) as on_disk: + actual = on_disk["var3"] + expected = in_memory["var3"] + for ind in indexers: + actual = actual.isel(ind) + expected = expected.isel(ind) + # make sure the array is not yet loaded into memory + assert not actual.variable._in_memory + assert_identical(expected, actual.load()) + + # two-staged vectorized-indexing + indexers2 = [ + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": DataArray([[0, 4], [1, 3], [2, 2]], dims=["a", "b"]), + }, + {"a": DataArray([0, 1], dims=["c"]), "b": DataArray([0, 1], dims=["c"])}, + ] + multiple_indexing(indexers2) + + # vectorized-slice mixed + indexers3 = [ + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": slice(None, 10), + } + ] + multiple_indexing(indexers3) + + # vectorized-integer mixed + indexers4 = [ + {"dim3": 0}, + {"dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"])}, + {"a": slice(None, None, 2)}, + ] + multiple_indexing(indexers4) + + # vectorized-integer mixed + indexers5 = [ + {"dim3": 0}, + {"dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"])}, + {"a": 1, "b": 0}, + ] + multiple_indexing(indexers5) + + def test_vectorized_indexing_negative_step(self) -> None: + # use dask explicitly when present + open_kwargs: dict[str, Any] | None + if has_dask: + open_kwargs = {"chunks": {}} + else: + open_kwargs = None + in_memory = create_test_data() + + def multiple_indexing(indexers): + # make sure a sequence of lazy indexings certainly works. + with self.roundtrip(in_memory, open_kwargs=open_kwargs) as on_disk: + actual = on_disk["var3"] + expected = in_memory["var3"] + for ind in indexers: + actual = actual.isel(ind) + expected = expected.isel(ind) + # make sure the array is not yet loaded into memory + assert not actual.variable._in_memory + assert_identical(expected, actual.load()) + + # with negative step slice. + indexers = [ + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": slice(-1, 1, -1), + } + ] + multiple_indexing(indexers) + + # with negative step slice. + indexers = [ + { + "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]), + "dim3": slice(-1, 1, -2), + } + ] + multiple_indexing(indexers) + + def test_outer_indexing_reversed(self) -> None: + # regression test for GH6560 + ds = xr.Dataset( + {"z": (("t", "p", "y", "x"), np.ones((1, 1, 31, 40)))}, + ) + + with self.roundtrip(ds) as on_disk: + subset = on_disk.isel(t=[0], p=0).z[:, ::10, ::10][:, ::-1, :] + assert subset.sizes == subset.load().sizes + + def test_isel_dataarray(self) -> None: + # Make sure isel works lazily. GH:issue:1688 + in_memory = create_test_data() + with self.roundtrip(in_memory) as on_disk: + expected = in_memory.isel(dim2=in_memory["dim2"] < 3) + actual = on_disk.isel(dim2=on_disk["dim2"] < 3) + assert_identical(expected, actual) + + def validate_array_type(self, ds): + # Make sure that only NumpyIndexingAdapter stores a bare np.ndarray. + def find_and_validate_array(obj): + # recursively called function. obj: array or array wrapper. + if hasattr(obj, "array"): + if isinstance(obj.array, indexing.ExplicitlyIndexed): + find_and_validate_array(obj.array) + else: + if isinstance(obj.array, np.ndarray): + assert isinstance(obj, indexing.NumpyIndexingAdapter) + elif isinstance(obj.array, dask_array_type): + assert isinstance(obj, indexing.DaskIndexingAdapter) + elif isinstance(obj.array, pd.Index): + assert isinstance(obj, indexing.PandasIndexingAdapter) + else: + raise TypeError(f"{type(obj.array)} is wrapped by {type(obj)}") + + for k, v in ds.variables.items(): + find_and_validate_array(v._data) + + def test_array_type_after_indexing(self) -> None: + in_memory = create_test_data() + with self.roundtrip(in_memory) as on_disk: + self.validate_array_type(on_disk) + indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)} + expected = in_memory.isel(indexers) + actual = on_disk.isel(**indexers) + assert_identical(expected, actual) + self.validate_array_type(actual) + # do it twice, to make sure we're switched from orthogonal -> numpy + # when we cached the values + actual = on_disk.isel(**indexers) + assert_identical(expected, actual) + self.validate_array_type(actual) + + def test_dropna(self) -> None: + # regression test for GH:issue:1694 + a = np.random.randn(4, 3) + a[1, 1] = np.nan + in_memory = xr.Dataset( + {"a": (("y", "x"), a)}, coords={"y": np.arange(4), "x": np.arange(3)} + ) + + assert_identical( + in_memory.dropna(dim="x"), in_memory.isel(x=slice(None, None, 2)) + ) + + with self.roundtrip(in_memory) as on_disk: + self.validate_array_type(on_disk) + expected = in_memory.dropna(dim="x") + actual = on_disk.dropna(dim="x") + assert_identical(expected, actual) + + def test_ondisk_after_print(self) -> None: + """Make sure print does not load file into memory""" + in_memory = create_test_data() + with self.roundtrip(in_memory) as on_disk: + repr(on_disk) + assert not on_disk["var1"]._in_memory + + +class CFEncodedBase(DatasetIOBase): + def test_roundtrip_bytes_with_fill_value(self) -> None: + values = np.array([b"ab", b"cdef", np.nan], dtype=object) + encoding = {"_FillValue": b"X", "dtype": "S1"} + original = Dataset({"x": ("t", values, {}, encoding)}) + expected = original.copy(deep=True) + with self.roundtrip(original) as actual: + assert_identical(expected, actual) + + original = Dataset({"x": ("t", values, {}, {"_FillValue": b""})}) + with self.roundtrip(original) as actual: + assert_identical(expected, actual) + + def test_roundtrip_string_with_fill_value_nchar(self) -> None: + values = np.array(["ab", "cdef", np.nan], dtype=object) + expected = Dataset({"x": ("t", values)}) + + encoding = {"dtype": "S1", "_FillValue": b"X"} + original = Dataset({"x": ("t", values, {}, encoding)}) + # Not supported yet. + with pytest.raises(NotImplementedError): + with self.roundtrip(original) as actual: + assert_identical(expected, actual) + + def test_roundtrip_empty_vlen_string_array(self) -> None: + # checks preserving vlen dtype for empty arrays GH7862 + dtype = create_vlen_dtype(str) + original = Dataset({"a": np.array([], dtype=dtype)}) + assert check_vlen_dtype(original["a"].dtype) == str + with self.roundtrip(original) as actual: + assert_identical(original, actual) + if np.issubdtype(actual["a"].dtype, object): + # only check metadata for capable backends + # eg. NETCDF3 based backends do not roundtrip metadata + if actual["a"].dtype.metadata is not None: + assert check_vlen_dtype(actual["a"].dtype) == str + else: + assert actual["a"].dtype == np.dtype(" None: + if hasattr(self, "zarr_version") and dtype == np.float32: + pytest.skip("float32 will be treated as float64 in zarr") + decoded = decoded_fn(dtype) + encoded = encoded_fn(dtype) + with self.roundtrip(decoded) as actual: + for k in decoded.variables: + assert decoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(decoded, actual, decode_bytes=False) + + with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: + # TODO: this assumes that all roundtrips will first + # encode. Is that something we want to test for? + for k in encoded.variables: + assert encoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(encoded, actual, decode_bytes=False) + + with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: + for k in encoded.variables: + assert encoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(encoded, actual, decode_bytes=False) + + # make sure roundtrip encoding didn't change the + # original dataset. + assert_allclose(encoded, encoded_fn(dtype), decode_bytes=False) + + with self.roundtrip(encoded) as actual: + for k in decoded.variables: + assert decoded.variables[k].dtype == actual.variables[k].dtype + assert_allclose(decoded, actual, decode_bytes=False) + + @staticmethod + def _create_cf_dataset(): + original = Dataset( + dict( + variable=( + ("ln_p", "latitude", "longitude"), + np.arange(8, dtype="f4").reshape(2, 2, 2), + {"ancillary_variables": "std_devs det_lim"}, + ), + std_devs=( + ("ln_p", "latitude", "longitude"), + np.arange(0.1, 0.9, 0.1).reshape(2, 2, 2), + {"standard_name": "standard_error"}, + ), + det_lim=( + (), + 0.1, + {"standard_name": "detection_minimum"}, + ), + ), + dict( + latitude=("latitude", [0, 1], {"units": "degrees_north"}), + longitude=("longitude", [0, 1], {"units": "degrees_east"}), + latlon=((), -1, {"grid_mapping_name": "latitude_longitude"}), + latitude_bnds=(("latitude", "bnds2"), [[0, 1], [1, 2]]), + longitude_bnds=(("longitude", "bnds2"), [[0, 1], [1, 2]]), + areas=( + ("latitude", "longitude"), + [[1, 1], [1, 1]], + {"units": "degree^2"}, + ), + ln_p=( + "ln_p", + [1.0, 0.5], + { + "standard_name": "atmosphere_ln_pressure_coordinate", + "computed_standard_name": "air_pressure", + }, + ), + P0=((), 1013.25, {"units": "hPa"}), + ), + ) + original["variable"].encoding.update( + {"cell_measures": "area: areas", "grid_mapping": "latlon"}, + ) + original.coords["latitude"].encoding.update( + dict(grid_mapping="latlon", bounds="latitude_bnds") + ) + original.coords["longitude"].encoding.update( + dict(grid_mapping="latlon", bounds="longitude_bnds") + ) + original.coords["ln_p"].encoding.update({"formula_terms": "p0: P0 lev : ln_p"}) + return original + + def test_grid_mapping_and_bounds_are_not_coordinates_in_file(self) -> None: + original = self._create_cf_dataset() + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with open_dataset(tmp_file, decode_coords=False) as ds: + assert ds.coords["latitude"].attrs["bounds"] == "latitude_bnds" + assert ds.coords["longitude"].attrs["bounds"] == "longitude_bnds" + assert "coordinates" not in ds["variable"].attrs + assert "coordinates" not in ds.attrs + + def test_coordinate_variables_after_dataset_roundtrip(self) -> None: + original = self._create_cf_dataset() + with self.roundtrip(original, open_kwargs={"decode_coords": "all"}) as actual: + assert_identical(actual, original) + + with self.roundtrip(original) as actual: + expected = original.reset_coords( + ["latitude_bnds", "longitude_bnds", "areas", "P0", "latlon"] + ) + # equal checks that coords and data_vars are equal which + # should be enough + # identical would require resetting a number of attributes + # skip that. + assert_equal(actual, expected) + + def test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip( + self, + ) -> None: + original = self._create_cf_dataset() + # The DataArray roundtrip should have the same warnings as the + # Dataset, but we already tested for those, so just go for the + # new warnings. It would appear that there is no way to tell + # pytest "This warning and also this warning should both be + # present". + # xarray/tests/test_conventions.py::TestCFEncodedDataStore + # needs the to_dataset. The other backends should be fine + # without it. + with pytest.warns( + UserWarning, + match=( + r"Variable\(s\) referenced in bounds not in variables: " + r"\['l(at|ong)itude_bnds'\]" + ), + ): + with self.roundtrip( + original["variable"].to_dataset(), open_kwargs={"decode_coords": "all"} + ) as actual: + assert_identical(actual, original["variable"].to_dataset()) + + @requires_iris + def test_coordinate_variables_after_iris_roundtrip(self) -> None: + original = self._create_cf_dataset() + iris_cube = original["variable"].to_iris() + actual = DataArray.from_iris(iris_cube) + # Bounds will be missing (xfail) + del original.coords["latitude_bnds"], original.coords["longitude_bnds"] + # Ancillary vars will be missing + # Those are data_vars, and will be dropped when grabbing the variable + assert_identical(actual, original["variable"]) + + def test_coordinates_encoding(self) -> None: + def equals_latlon(obj): + return obj == "lat lon" or obj == "lon lat" + + original = Dataset( + {"temp": ("x", [0, 1]), "precip": ("x", [0, -1])}, + {"lat": ("x", [2, 3]), "lon": ("x", [4, 5])}, + ) + with self.roundtrip(original) as actual: + assert_identical(actual, original) + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with open_dataset(tmp_file, decode_coords=False) as ds: + assert equals_latlon(ds["temp"].attrs["coordinates"]) + assert equals_latlon(ds["precip"].attrs["coordinates"]) + assert "coordinates" not in ds.attrs + assert "coordinates" not in ds["lat"].attrs + assert "coordinates" not in ds["lon"].attrs + + modified = original.drop_vars(["temp", "precip"]) + with self.roundtrip(modified) as actual: + assert_identical(actual, modified) + with create_tmp_file() as tmp_file: + modified.to_netcdf(tmp_file) + with open_dataset(tmp_file, decode_coords=False) as ds: + assert equals_latlon(ds.attrs["coordinates"]) + assert "coordinates" not in ds["lat"].attrs + assert "coordinates" not in ds["lon"].attrs + + original["temp"].encoding["coordinates"] = "lat" + with self.roundtrip(original) as actual: + assert_identical(actual, original) + original["precip"].encoding["coordinates"] = "lat" + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with open_dataset(tmp_file, decode_coords=True) as ds: + assert "lon" not in ds["temp"].encoding["coordinates"] + assert "lon" not in ds["precip"].encoding["coordinates"] + assert "coordinates" not in ds["lat"].encoding + assert "coordinates" not in ds["lon"].encoding + + def test_roundtrip_endian(self) -> None: + ds = Dataset( + { + "x": np.arange(3, 10, dtype=">i2"), + "y": np.arange(3, 20, dtype=" None: + te = (TypeError, "string or None") + ve = (ValueError, "string must be length 1 or") + data = np.random.random((2, 2)) + da = xr.DataArray(data) + for name, (error, msg) in zip([0, (4, 5), True, ""], [te, te, te, ve]): + ds = Dataset({name: da}) + with pytest.raises(error) as excinfo: + with self.roundtrip(ds): + pass + excinfo.match(msg) + excinfo.match(repr(name)) + + def test_encoding_kwarg(self) -> None: + ds = Dataset({"x": ("y", np.arange(10.0))}) + + kwargs: dict[str, Any] = dict(encoding={"x": {"dtype": "f4"}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + encoded_dtype = actual.x.encoding["dtype"] + # On OS X, dtype sometimes switches endianness for unclear reasons + assert encoded_dtype.kind == "f" and encoded_dtype.itemsize == 4 + assert ds.x.encoding == {} + + kwargs = dict(encoding={"x": {"foo": "bar"}}) + with pytest.raises(ValueError, match=r"unexpected encoding"): + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + pass + + kwargs = dict(encoding={"x": "foo"}) + with pytest.raises(ValueError, match=r"must be castable"): + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + pass + + kwargs = dict(encoding={"invalid": {}}) + with pytest.raises(KeyError): + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + pass + + def test_encoding_kwarg_dates(self) -> None: + ds = Dataset({"t": pd.date_range("2000-01-01", periods=3)}) + units = "days since 1900-01-01" + kwargs = dict(encoding={"t": {"units": units}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual.t.encoding["units"] == units + assert_identical(actual, ds) + + def test_encoding_kwarg_fixed_width_string(self) -> None: + # regression test for GH2149 + for strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]: + ds = Dataset({"x": strings}) + kwargs = dict(encoding={"x": {"dtype": "S1"}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual["x"].encoding["dtype"] == "S1" + assert_identical(actual, ds) + + def test_default_fill_value(self) -> None: + # Test default encoding for float: + ds = Dataset({"x": ("y", np.arange(10.0))}) + kwargs = dict(encoding={"x": {"dtype": "f4"}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert math.isnan(actual.x.encoding["_FillValue"]) + assert ds.x.encoding == {} + + # Test default encoding for int: + ds = Dataset({"x": ("y", np.arange(10.0))}) + kwargs = dict(encoding={"x": {"dtype": "int16"}}) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*floating point data as an integer") + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert "_FillValue" not in actual.x.encoding + assert ds.x.encoding == {} + + # Test default encoding for implicit int: + ds = Dataset({"x": ("y", np.arange(10, dtype="int16"))}) + with self.roundtrip(ds) as actual: + assert "_FillValue" not in actual.x.encoding + assert ds.x.encoding == {} + + def test_explicitly_omit_fill_value(self) -> None: + ds = Dataset({"x": ("y", [np.pi, -np.pi])}) + ds.x.encoding["_FillValue"] = None + with self.roundtrip(ds) as actual: + assert "_FillValue" not in actual.x.encoding + + def test_explicitly_omit_fill_value_via_encoding_kwarg(self) -> None: + ds = Dataset({"x": ("y", [np.pi, -np.pi])}) + kwargs = dict(encoding={"x": {"_FillValue": None}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert "_FillValue" not in actual.x.encoding + assert ds.y.encoding == {} + + def test_explicitly_omit_fill_value_in_coord(self) -> None: + ds = Dataset({"x": ("y", [np.pi, -np.pi])}, coords={"y": [0.0, 1.0]}) + ds.y.encoding["_FillValue"] = None + with self.roundtrip(ds) as actual: + assert "_FillValue" not in actual.y.encoding + + def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self) -> None: + ds = Dataset({"x": ("y", [np.pi, -np.pi])}, coords={"y": [0.0, 1.0]}) + kwargs = dict(encoding={"y": {"_FillValue": None}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert "_FillValue" not in actual.y.encoding + assert ds.y.encoding == {} + + def test_encoding_same_dtype(self) -> None: + ds = Dataset({"x": ("y", np.arange(10.0, dtype="f4"))}) + kwargs = dict(encoding={"x": {"dtype": "f4"}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + encoded_dtype = actual.x.encoding["dtype"] + # On OS X, dtype sometimes switches endianness for unclear reasons + assert encoded_dtype.kind == "f" and encoded_dtype.itemsize == 4 + assert ds.x.encoding == {} + + def test_append_write(self) -> None: + # regression for GH1215 + data = create_test_data() + with self.roundtrip_append(data) as actual: + assert_identical(data, actual) + + def test_append_overwrite_values(self) -> None: + # regression for GH1215 + data = create_test_data() + with create_tmp_file(allow_cleanup_failure=False) as tmp_file: + self.save(data, tmp_file, mode="w") + data["var2"][:] = -999 + data["var9"] = data["var2"] * 3 + self.save(data[["var2", "var9"]], tmp_file, mode="a") + with self.open(tmp_file) as actual: + assert_identical(data, actual) + + def test_append_with_invalid_dim_raises(self) -> None: + data = create_test_data() + with create_tmp_file(allow_cleanup_failure=False) as tmp_file: + self.save(data, tmp_file, mode="w") + data["var9"] = data["var2"] * 3 + data = data.isel(dim1=slice(2, 6)) # modify one dimension + with pytest.raises( + ValueError, match=r"Unable to update size for existing dimension" + ): + self.save(data, tmp_file, mode="a") + + def test_multiindex_not_implemented(self) -> None: + ds = Dataset(coords={"y": ("x", [1, 2]), "z": ("x", ["a", "b"])}).set_index( + x=["y", "z"] + ) + with pytest.raises(NotImplementedError, match=r"MultiIndex"): + with self.roundtrip(ds): + pass + + # regression GH8628 (can serialize reset multi-index level coordinates) + ds_reset = ds.reset_index("x") + with self.roundtrip(ds_reset) as actual: + assert_identical(actual, ds_reset) + + +class NetCDFBase(CFEncodedBase): + """Tests for all netCDF3 and netCDF4 backends.""" + + @pytest.mark.skipif( + ON_WINDOWS, reason="Windows does not allow modifying open files" + ) + def test_refresh_from_disk(self) -> None: + # regression test for https://github.com/pydata/xarray/issues/4862 + + with create_tmp_file() as example_1_path: + with create_tmp_file() as example_1_modified_path: + with open_example_dataset("example_1.nc") as example_1: + self.save(example_1, example_1_path) + + example_1.rh.values += 100 + self.save(example_1, example_1_modified_path) + + a = open_dataset(example_1_path, engine=self.engine).load() + + # Simulate external process modifying example_1.nc while this script is running + shutil.copy(example_1_modified_path, example_1_path) + + # Reopen example_1.nc (modified) as `b`; note that `a` has NOT been closed + b = open_dataset(example_1_path, engine=self.engine).load() + + try: + assert not np.array_equal(a.rh.values, b.rh.values) + finally: + a.close() + b.close() + + +_counter = itertools.count() + + +@contextlib.contextmanager +def create_tmp_file( + suffix: str = ".nc", allow_cleanup_failure: bool = False +) -> Iterator[str]: + temp_dir = tempfile.mkdtemp() + path = os.path.join(temp_dir, f"temp-{next(_counter)}{suffix}") + try: + yield path + finally: + try: + shutil.rmtree(temp_dir) + except OSError: + if not allow_cleanup_failure: + raise + + +@contextlib.contextmanager +def create_tmp_files( + nfiles: int, suffix: str = ".nc", allow_cleanup_failure: bool = False +) -> Iterator[list[str]]: + with ExitStack() as stack: + files = [ + stack.enter_context(create_tmp_file(suffix, allow_cleanup_failure)) + for _ in range(nfiles) + ] + yield files + + +class NetCDF4Base(NetCDFBase): + """Tests for both netCDF4-python and h5netcdf.""" + + engine: T_NetcdfEngine = "netcdf4" + + def test_open_group(self) -> None: + # Create a netCDF file with a dataset stored within a group + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, "w") as rootgrp: + foogrp = rootgrp.createGroup("foo") + ds = foogrp + ds.createDimension("time", size=10) + x = np.arange(10) + ds.createVariable("x", np.int32, dimensions=("time",)) + ds.variables["x"][:] = x + + expected = Dataset() + expected["x"] = ("time", x) + + # check equivalent ways to specify group + for group in "foo", "/foo", "foo/", "/foo/": + with self.open(tmp_file, group=group) as actual: + assert_equal(actual["x"], expected["x"]) + + # check that missing group raises appropriate exception + with pytest.raises(OSError): + open_dataset(tmp_file, group="bar") + with pytest.raises(ValueError, match=r"must be a string"): + open_dataset(tmp_file, group=(1, 2, 3)) + + def test_open_subgroup(self) -> None: + # Create a netCDF file with a dataset stored within a group within a + # group + with create_tmp_file() as tmp_file: + rootgrp = nc4.Dataset(tmp_file, "w") + foogrp = rootgrp.createGroup("foo") + bargrp = foogrp.createGroup("bar") + ds = bargrp + ds.createDimension("time", size=10) + x = np.arange(10) + ds.createVariable("x", np.int32, dimensions=("time",)) + ds.variables["x"][:] = x + rootgrp.close() + + expected = Dataset() + expected["x"] = ("time", x) + + # check equivalent ways to specify group + for group in "foo/bar", "/foo/bar", "foo/bar/", "/foo/bar/": + with self.open(tmp_file, group=group) as actual: + assert_equal(actual["x"], expected["x"]) + + def test_write_groups(self) -> None: + data1 = create_test_data() + data2 = data1 * 2 + with create_tmp_file() as tmp_file: + self.save(data1, tmp_file, group="data/1") + self.save(data2, tmp_file, group="data/2", mode="a") + with self.open(tmp_file, group="data/1") as actual1: + assert_identical(data1, actual1) + with self.open(tmp_file, group="data/2") as actual2: + assert_identical(data2, actual2) + + @pytest.mark.parametrize( + "input_strings, is_bytes", + [ + ([b"foo", b"bar", b"baz"], True), + (["foo", "bar", "baz"], False), + (["foó", "bár", "baź"], False), + ], + ) + def test_encoding_kwarg_vlen_string( + self, input_strings: list[str], is_bytes: bool + ) -> None: + original = Dataset({"x": input_strings}) + + expected_string = ["foo", "bar", "baz"] if is_bytes else input_strings + expected = Dataset({"x": expected_string}) + kwargs = dict(encoding={"x": {"dtype": str}}) + with self.roundtrip(original, save_kwargs=kwargs) as actual: + assert actual["x"].encoding["dtype"] == " None: + values = np.array(["ab", "cdef", np.nan], dtype=object) + expected = Dataset({"x": ("t", values)}) + + original = Dataset({"x": ("t", values, {}, {"_FillValue": fill_value})}) + with self.roundtrip(original) as actual: + assert_identical(expected, actual) + + original = Dataset({"x": ("t", values, {}, {"_FillValue": ""})}) + with self.roundtrip(original) as actual: + assert_identical(expected, actual) + + def test_roundtrip_character_array(self) -> None: + with create_tmp_file() as tmp_file: + values = np.array([["a", "b", "c"], ["d", "e", "f"]], dtype="S") + + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("x", 2) + nc.createDimension("string3", 3) + v = nc.createVariable("x", np.dtype("S1"), ("x", "string3")) + v[:] = values + + values = np.array(["abc", "def"], dtype="S") + expected = Dataset({"x": ("x", values)}) + with open_dataset(tmp_file) as actual: + assert_identical(expected, actual) + # regression test for #157 + with self.roundtrip(actual) as roundtripped: + assert_identical(expected, roundtripped) + + def test_default_to_char_arrays(self) -> None: + data = Dataset({"x": np.array(["foo", "zzzz"], dtype="S")}) + with self.roundtrip(data) as actual: + assert_identical(data, actual) + assert actual["x"].dtype == np.dtype("S4") + + def test_open_encodings(self) -> None: + # Create a netCDF file with explicit time units + # and make sure it makes it into the encodings + # and survives a round trip + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, "w") as ds: + ds.createDimension("time", size=10) + ds.createVariable("time", np.int32, dimensions=("time",)) + units = "days since 1999-01-01" + ds.variables["time"].setncattr("units", units) + ds.variables["time"][:] = np.arange(10) + 4 + + expected = Dataset() + + time = pd.date_range("1999-01-05", periods=10) + encoding = {"units": units, "dtype": np.dtype("int32")} + expected["time"] = ("time", time, {}, encoding) + + with open_dataset(tmp_file) as actual: + assert_equal(actual["time"], expected["time"]) + actual_encoding = { + k: v + for k, v in actual["time"].encoding.items() + if k in expected["time"].encoding + } + assert actual_encoding == expected["time"].encoding + + def test_dump_encodings(self) -> None: + # regression test for #709 + ds = Dataset({"x": ("y", np.arange(10.0))}) + kwargs = dict(encoding={"x": {"zlib": True}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual.x.encoding["zlib"] + + def test_dump_and_open_encodings(self) -> None: + # Create a netCDF file with explicit time units + # and make sure it makes it into the encodings + # and survives a round trip + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, "w") as ds: + ds.createDimension("time", size=10) + ds.createVariable("time", np.int32, dimensions=("time",)) + units = "days since 1999-01-01" + ds.variables["time"].setncattr("units", units) + ds.variables["time"][:] = np.arange(10) + 4 + + with open_dataset(tmp_file) as xarray_dataset: + with create_tmp_file() as tmp_file2: + xarray_dataset.to_netcdf(tmp_file2) + with nc4.Dataset(tmp_file2, "r") as ds: + assert ds.variables["time"].getncattr("units") == units + assert_array_equal(ds.variables["time"], np.arange(10) + 4) + + def test_compression_encoding_legacy(self) -> None: + data = create_test_data() + data["var2"].encoding.update( + { + "zlib": True, + "chunksizes": (5, 5), + "fletcher32": True, + "shuffle": True, + "original_shape": data.var2.shape, + } + ) + with self.roundtrip(data) as actual: + for k, v in data["var2"].encoding.items(): + assert v == actual["var2"].encoding[k] + + # regression test for #156 + expected = data.isel(dim1=0) + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) + + def test_encoding_kwarg_compression(self) -> None: + ds = Dataset({"x": np.arange(10.0)}) + encoding = dict( + dtype="f4", + zlib=True, + complevel=9, + fletcher32=True, + chunksizes=(5,), + shuffle=True, + ) + kwargs = dict(encoding=dict(x=encoding)) + + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert_equal(actual, ds) + assert actual.x.encoding["dtype"] == "f4" + assert actual.x.encoding["zlib"] + assert actual.x.encoding["complevel"] == 9 + assert actual.x.encoding["fletcher32"] + assert actual.x.encoding["chunksizes"] == (5,) + assert actual.x.encoding["shuffle"] + + assert ds.x.encoding == {} + + def test_keep_chunksizes_if_no_original_shape(self) -> None: + ds = Dataset({"x": [1, 2, 3]}) + chunksizes = (2,) + ds.variables["x"].encoding = {"chunksizes": chunksizes} + + with self.roundtrip(ds) as actual: + assert_identical(ds, actual) + assert_array_equal( + ds["x"].encoding["chunksizes"], actual["x"].encoding["chunksizes"] + ) + + def test_preferred_chunks_is_present(self) -> None: + ds = Dataset({"x": [1, 2, 3]}) + chunksizes = (2,) + ds.variables["x"].encoding = {"chunksizes": chunksizes} + + with self.roundtrip(ds) as actual: + assert actual["x"].encoding["preferred_chunks"] == {"x": 2} + + @requires_dask + def test_auto_chunking_is_based_on_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with dask.config.set({"array.chunk-size": "100KiB"}): + with self.chunked_roundtrip( + (1, y_size, x_size), + (1, y_chunksize, x_chunksize), + open_kwargs={"chunks": "auto"}, + ) as ds: + t_chunks, y_chunks, x_chunks = ds["image"].data.chunks + assert all(np.asanyarray(y_chunks) == y_chunksize) + # Check that the chunk size is a multiple of the file chunk size + assert all(np.asanyarray(x_chunks) % x_chunksize == 0) + + @requires_dask + def test_base_chunking_uses_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with self.chunked_roundtrip( + (1, y_size, x_size), + (1, y_chunksize, x_chunksize), + open_kwargs={"chunks": {}}, + ) as ds: + for chunksizes, expected in zip( + ds["image"].data.chunks, (1, y_chunksize, x_chunksize) + ): + assert all(np.asanyarray(chunksizes) == expected) + + @contextlib.contextmanager + def chunked_roundtrip( + self, + array_shape: tuple[int, int, int], + chunk_sizes: tuple[int, int, int], + open_kwargs: dict[str, Any] | None = None, + ) -> Generator[Dataset, None, None]: + t_size, y_size, x_size = array_shape + t_chunksize, y_chunksize, x_chunksize = chunk_sizes + + image = xr.DataArray( + np.arange(t_size * x_size * y_size, dtype=np.int16).reshape( + (t_size, y_size, x_size) + ), + dims=["t", "y", "x"], + ) + image.encoding = {"chunksizes": (t_chunksize, y_chunksize, x_chunksize)} + dataset = xr.Dataset(dict(image=image)) + + with self.roundtrip(dataset, open_kwargs=open_kwargs) as ds: + yield ds + + def test_preferred_chunks_are_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with self.chunked_roundtrip( + (1, y_size, x_size), (1, y_chunksize, x_chunksize) + ) as ds: + assert ds["image"].encoding["preferred_chunks"] == { + "t": 1, + "y": y_chunksize, + "x": x_chunksize, + } + + def test_encoding_chunksizes_unlimited(self) -> None: + # regression test for GH1225 + ds = Dataset({"x": [1, 2, 3], "y": ("x", [2, 3, 4])}) + ds.variables["x"].encoding = { + "zlib": False, + "shuffle": False, + "complevel": 0, + "fletcher32": False, + "contiguous": False, + "chunksizes": (2**20,), + "original_shape": (3,), + } + with self.roundtrip(ds) as actual: + assert_equal(ds, actual) + + def test_mask_and_scale(self) -> None: + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("t", 5) + nc.createVariable("x", "int16", ("t",), fill_value=-1) + v = nc.variables["x"] + v.set_auto_maskandscale(False) + v.add_offset = 10 + v.scale_factor = 0.1 + v[:] = np.array([-1, -1, 0, 1, 2]) + dtype = type(v.scale_factor) + + # first make sure netCDF4 reads the masked and scaled data + # correctly + with nc4.Dataset(tmp_file, mode="r") as nc: + expected = np.ma.array( + [-1, -1, 10, 10.1, 10.2], mask=[True, True, False, False, False] + ) + actual = nc.variables["x"][:] + assert_array_equal(expected, actual) + + # now check xarray + with open_dataset(tmp_file) as ds: + expected = create_masked_and_scaled_data(np.dtype(dtype)) + assert_identical(expected, ds) + + def test_0dimensional_variable(self) -> None: + # This fix verifies our work-around to this netCDF4-python bug: + # https://github.com/Unidata/netcdf4-python/pull/220 + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + v = nc.createVariable("x", "int16") + v[...] = 123 + + with open_dataset(tmp_file) as ds: + expected = Dataset({"x": ((), 123)}) + assert_identical(expected, ds) + + def test_read_variable_len_strings(self) -> None: + with create_tmp_file() as tmp_file: + values = np.array(["foo", "bar", "baz"], dtype=object) + + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("x", 3) + v = nc.createVariable("x", str, ("x",)) + v[:] = values + + expected = Dataset({"x": ("x", values)}) + for kwargs in [{}, {"decode_cf": True}]: + with open_dataset(tmp_file, **cast(dict, kwargs)) as actual: + assert_identical(expected, actual) + + def test_encoding_unlimited_dims(self) -> None: + ds = Dataset({"x": ("y", np.arange(10.0))}) + with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=["y"])) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + ds.encoding = {"unlimited_dims": ["y"]} + with self.roundtrip(ds) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + + def test_raise_on_forward_slashes_in_names(self) -> None: + # test for forward slash in variable names and dimensions + # see GH 7943 + data_vars: list[dict[str, Any]] = [ + {"PASS/FAIL": (["PASSFAIL"], np.array([0]))}, + {"PASS/FAIL": np.array([0])}, + {"PASSFAIL": (["PASS/FAIL"], np.array([0]))}, + ] + for dv in data_vars: + ds = Dataset(data_vars=dv) + with pytest.raises(ValueError, match="Forward slashes '/' are not allowed"): + with self.roundtrip(ds): + pass + + @requires_netCDF4 + def test_encoding_enum__no_fill_value(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + v = nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=None, + ) + v[:] = 1 + with open_dataset(tmp_file) as original: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) + + @requires_netCDF4 + def test_encoding_enum__multiple_variable_with_enum(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as original: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"] + ) + assert ( + actual.clouds.encoding["dtype"].metadata + == actual.tifa.encoding["dtype"].metadata + ) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) + + @requires_netCDF4 + def test_encoding_enum__error_multiple_variable_with_changing_enum(self): + """ + Given 2 variables, if they share the same enum type, + the 2 enum definition should be identical. + """ + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as original: + assert ( + original.clouds.encoding["dtype"].metadata + == original.tifa.encoding["dtype"].metadata + ) + modified_enum = original.clouds.encoding["dtype"].metadata["enum"] + modified_enum.update({"neblig": 2}) + original.clouds.encoding["dtype"] = np.dtype( + "u1", + metadata={"enum": modified_enum, "enum_name": "cloud_type"}, + ) + if self.engine != "h5netcdf": + # not implemented yet in h5netcdf + with pytest.raises( + ValueError, + match=( + "Cannot save variable .*" + " because an enum `cloud_type` already exists in the Dataset .*" + ), + ): + with self.roundtrip(original): + pass + + +@requires_netCDF4 +class TestNetCDF4Data(NetCDF4Base): + @contextlib.contextmanager + def create_store(self): + with create_tmp_file() as tmp_file: + with backends.NetCDF4DataStore.open(tmp_file, mode="w") as store: + yield store + + def test_variable_order(self) -> None: + # doesn't work with scipy or h5py :( + ds = Dataset() + ds["a"] = 1 + ds["z"] = 2 + ds["b"] = 3 + ds.coords["c"] = 4 + + with self.roundtrip(ds) as actual: + assert list(ds.variables) == list(actual.variables) + + def test_unsorted_index_raises(self) -> None: + # should be fixed in netcdf4 v1.2.1 + random_data = np.random.random(size=(4, 6)) + dim0 = [0, 1, 2, 3] + dim1 = [0, 2, 1, 3, 5, 4] # We will sort this in a later step + da = xr.DataArray( + data=random_data, + dims=("dim0", "dim1"), + coords={"dim0": dim0, "dim1": dim1}, + name="randovar", + ) + ds = da.to_dataset() + + with self.roundtrip(ds) as ondisk: + inds = np.argsort(dim1) + ds2 = ondisk.isel(dim1=inds) + # Older versions of NetCDF4 raise an exception here, and if so we + # want to ensure we improve (that is, replace) the error message + try: + ds2.randovar.values + except IndexError as err: + assert "first by calling .load" in str(err) + + def test_setncattr_string(self) -> None: + list_of_strings = ["list", "of", "strings"] + one_element_list_of_strings = ["one element"] + one_string = "one string" + attrs = { + "foo": list_of_strings, + "bar": one_element_list_of_strings, + "baz": one_string, + } + ds = Dataset({"x": ("y", [1, 2, 3], attrs)}, attrs=attrs) + + with self.roundtrip(ds) as actual: + for totest in [actual, actual["x"]]: + assert_array_equal(list_of_strings, totest.attrs["foo"]) + assert_array_equal(one_element_list_of_strings, totest.attrs["bar"]) + assert one_string == totest.attrs["baz"] + + @pytest.mark.parametrize( + "compression", + [ + None, + "zlib", + "szip", + "zstd", + "blosc_lz", + "blosc_lz4", + "blosc_lz4hc", + "blosc_zlib", + "blosc_zstd", + ], + ) + @requires_netCDF4_1_6_2_or_above + @pytest.mark.xfail(ON_WINDOWS, reason="new compression not yet implemented") + def test_compression_encoding(self, compression: str | None) -> None: + data = create_test_data(dim_sizes=(20, 80, 10)) + encoding_params: dict[str, Any] = dict(compression=compression, blosc_shuffle=1) + data["var2"].encoding.update(encoding_params) + data["var2"].encoding.update( + { + "chunksizes": (20, 40), + "original_shape": data.var2.shape, + "blosc_shuffle": 1, + "fletcher32": False, + } + ) + with self.roundtrip(data) as actual: + expected_encoding = data["var2"].encoding.copy() + # compression does not appear in the retrieved encoding, that differs + # from the input encoding. shuffle also chantges. Here we modify the + # expected encoding to account for this + compression = expected_encoding.pop("compression") + blosc_shuffle = expected_encoding.pop("blosc_shuffle") + if compression is not None: + if "blosc" in compression and blosc_shuffle: + expected_encoding["blosc"] = { + "compressor": compression, + "shuffle": blosc_shuffle, + } + expected_encoding["shuffle"] = False + elif compression == "szip": + expected_encoding["szip"] = { + "coding": "nn", + "pixels_per_block": 8, + } + expected_encoding["shuffle"] = False + else: + # This will set a key like zlib=true which is what appears in + # the encoding when we read it. + expected_encoding[compression] = True + if compression == "zstd": + expected_encoding["shuffle"] = False + else: + expected_encoding["shuffle"] = False + + actual_encoding = actual["var2"].encoding + assert expected_encoding.items() <= actual_encoding.items() + if ( + encoding_params["compression"] is not None + and "blosc" not in encoding_params["compression"] + ): + # regression test for #156 + expected = data.isel(dim1=0) + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) + + @pytest.mark.skip(reason="https://github.com/Unidata/netcdf4-python/issues/1195") + def test_refresh_from_disk(self) -> None: + super().test_refresh_from_disk() + + +@requires_netCDF4 +class TestNetCDF4AlreadyOpen: + def test_base_case(self) -> None: + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + v = nc.createVariable("x", "int") + v[...] = 42 + + nc = nc4.Dataset(tmp_file, mode="r") + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + def test_group(self) -> None: + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + group = nc.createGroup("g") + v = group.createVariable("x", "int") + v[...] = 42 + + nc = nc4.Dataset(tmp_file, mode="r") + store = backends.NetCDF4DataStore(nc.groups["g"]) + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + nc = nc4.Dataset(tmp_file, mode="r") + store = backends.NetCDF4DataStore(nc, group="g") + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + with nc4.Dataset(tmp_file, mode="r") as nc: + with pytest.raises(ValueError, match="must supply a root"): + backends.NetCDF4DataStore(nc.groups["g"], group="g") + + def test_deepcopy(self) -> None: + # regression test for https://github.com/pydata/xarray/issues/4425 + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("x", 10) + v = nc.createVariable("y", np.int32, ("x",)) + v[:] = np.arange(10) + + h5 = nc4.Dataset(tmp_file, mode="r") + store = backends.NetCDF4DataStore(h5) + with open_dataset(store) as ds: + copied = ds.copy(deep=True) + expected = Dataset({"y": ("x", np.arange(10))}) + assert_identical(expected, copied) + + +@requires_netCDF4 +@requires_dask +@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") +class TestNetCDF4ViaDaskData(TestNetCDF4Data): + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if open_kwargs is None: + open_kwargs = {} + if save_kwargs is None: + save_kwargs = {} + open_kwargs.setdefault("chunks", -1) + with TestNetCDF4Data.roundtrip( + self, data, save_kwargs, open_kwargs, allow_cleanup_failure + ) as ds: + yield ds + + def test_unsorted_index_raises(self) -> None: + # Skip when using dask because dask rewrites indexers to getitem, + # dask first pulls items by block. + pass + + @pytest.mark.skip(reason="caching behavior differs for dask") + def test_dataset_caching(self) -> None: + pass + + def test_write_inconsistent_chunks(self) -> None: + # Construct two variables with the same dimensions, but different + # chunk sizes. + x = da.zeros((100, 100), dtype="f4", chunks=(50, 100)) + x = DataArray(data=x, dims=("lat", "lon"), name="x") + x.encoding["chunksizes"] = (50, 100) + x.encoding["original_shape"] = (100, 100) + y = da.ones((100, 100), dtype="f4", chunks=(100, 50)) + y = DataArray(data=y, dims=("lat", "lon"), name="y") + y.encoding["chunksizes"] = (100, 50) + y.encoding["original_shape"] = (100, 100) + # Put them both into the same dataset + ds = Dataset({"x": x, "y": y}) + with self.roundtrip(ds) as actual: + assert actual["x"].encoding["chunksizes"] == (50, 100) + assert actual["y"].encoding["chunksizes"] == (100, 50) + + +@requires_zarr +class ZarrBase(CFEncodedBase): + DIMENSION_KEY = "_ARRAY_DIMENSIONS" + zarr_version = 2 + version_kwargs: dict[str, Any] = {} + + def create_zarr_target(self): + raise NotImplementedError + + @contextlib.contextmanager + def create_store(self): + with self.create_zarr_target() as store_target: + yield backends.ZarrStore.open_group( + store_target, mode="w", **self.version_kwargs + ) + + def save(self, dataset, store_target, **kwargs): + return dataset.to_zarr(store=store_target, **kwargs, **self.version_kwargs) + + @contextlib.contextmanager + def open(self, store_target, **kwargs): + with xr.open_dataset( + store_target, engine="zarr", **kwargs, **self.version_kwargs + ) as ds: + yield ds + + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + with self.create_zarr_target() as store_target: + self.save(data, store_target, **save_kwargs) + with self.open(store_target, **open_kwargs) as ds: + yield ds + + @pytest.mark.parametrize("consolidated", [False, True, None]) + def test_roundtrip_consolidated(self, consolidated) -> None: + if consolidated and self.zarr_version > 2: + pytest.xfail("consolidated metadata is not supported for zarr v3 yet") + expected = create_test_data() + with self.roundtrip( + expected, + save_kwargs={"consolidated": consolidated}, + open_kwargs={"backend_kwargs": {"consolidated": consolidated}}, + ) as actual: + self.check_dtypes_roundtripped(expected, actual) + assert_identical(expected, actual) + + def test_read_non_consolidated_warning(self) -> None: + if self.zarr_version > 2: + pytest.xfail("consolidated metadata is not supported for zarr v3 yet") + + expected = create_test_data() + with self.create_zarr_target() as store: + expected.to_zarr(store, consolidated=False, **self.version_kwargs) + with pytest.warns( + RuntimeWarning, + match="Failed to open Zarr store with consolidated", + ): + with xr.open_zarr(store, **self.version_kwargs) as ds: + assert_identical(ds, expected) + + def test_non_existent_store(self) -> None: + with pytest.raises(FileNotFoundError, match=r"No such file or directory:"): + xr.open_zarr(f"{uuid.uuid4()}") + + def test_with_chunkstore(self) -> None: + expected = create_test_data() + with ( + self.create_zarr_target() as store_target, + self.create_zarr_target() as chunk_store, + ): + save_kwargs = {"chunk_store": chunk_store} + self.save(expected, store_target, **save_kwargs) + # the chunk store must have been populated with some entries + assert len(chunk_store) > 0 + open_kwargs = {"backend_kwargs": {"chunk_store": chunk_store}} + with self.open(store_target, **open_kwargs) as ds: + assert_equal(ds, expected) + + @requires_dask + def test_auto_chunk(self) -> None: + original = create_test_data().chunk() + + with self.roundtrip(original, open_kwargs={"chunks": None}) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + # there should be no chunks + assert v.chunks is None + + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + # chunk size should be the same as original + assert v.chunks == original[k].chunks + + @requires_dask + @pytest.mark.filterwarnings("ignore:The specified chunks separate:UserWarning") + def test_manual_chunk(self) -> None: + original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3}) + + # Using chunks = None should return non-chunked arrays + open_kwargs: dict[str, Any] = {"chunks": None} + with self.roundtrip(original, open_kwargs=open_kwargs) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + # there should be no chunks + assert v.chunks is None + + # uniform arrays + for i in range(2, 6): + rechunked = original.chunk(chunks=i) + open_kwargs = {"chunks": i} + with self.roundtrip(original, open_kwargs=open_kwargs) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + # chunk size should be the same as rechunked + assert v.chunks == rechunked[k].chunks + + chunks = {"dim1": 2, "dim2": 3, "dim3": 5} + rechunked = original.chunk(chunks=chunks) + + open_kwargs = { + "chunks": chunks, + "backend_kwargs": {"overwrite_encoded_chunks": True}, + } + with self.roundtrip(original, open_kwargs=open_kwargs) as actual: + for k, v in actual.variables.items(): + assert v.chunks == rechunked[k].chunks + + with self.roundtrip(actual) as auto: + # encoding should have changed + for k, v in actual.variables.items(): + assert v.chunks == rechunked[k].chunks + + assert_identical(actual, auto) + assert_identical(actual.load(), auto.load()) + + @requires_dask + def test_warning_on_bad_chunks(self) -> None: + original = create_test_data().chunk({"dim1": 4, "dim2": 3, "dim3": 3}) + + bad_chunks = (2, {"dim2": (3, 3, 2, 1)}) + for chunks in bad_chunks: + kwargs = {"chunks": chunks} + with pytest.warns(UserWarning): + with self.roundtrip(original, open_kwargs=kwargs) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + + good_chunks: tuple[dict[str, Any], ...] = ({"dim2": 3}, {"dim3": (6, 4)}, {}) + for chunks in good_chunks: + kwargs = {"chunks": chunks} + with assert_no_warnings(): + with self.roundtrip(original, open_kwargs=kwargs) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + + @requires_dask + def test_deprecate_auto_chunk(self) -> None: + original = create_test_data().chunk() + with pytest.raises(TypeError): + with self.roundtrip(original, open_kwargs={"auto_chunk": True}) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + # chunk size should be the same as original + assert v.chunks == original[k].chunks + + with pytest.raises(TypeError): + with self.roundtrip(original, open_kwargs={"auto_chunk": False}) as actual: + for k, v in actual.variables.items(): + # only index variables should be in memory + assert v._in_memory == (k in actual.dims) + # there should be no chunks + assert v.chunks is None + + @requires_dask + def test_write_uneven_dask_chunks(self) -> None: + # regression for GH#2225 + original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3}) + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for k, v in actual.data_vars.items(): + assert v.chunks == actual[k].chunks + + def test_chunk_encoding(self) -> None: + # These datasets have no dask chunks. All chunking specified in + # encoding + data = create_test_data() + chunks = (5, 5) + data["var2"].encoding.update({"chunks": chunks}) + + with self.roundtrip(data) as actual: + assert chunks == actual["var2"].encoding["chunks"] + + # expect an error with non-integer chunks + data["var2"].encoding.update({"chunks": (5, 4.5)}) + with pytest.raises(TypeError): + with self.roundtrip(data) as actual: + pass + + @requires_dask + @pytest.mark.skipif( + ON_WINDOWS, + reason="Very flaky on Windows CI. Can re-enable assuming it starts consistently passing.", + ) + def test_chunk_encoding_with_dask(self) -> None: + # These datasets DO have dask chunks. Need to check for various + # interactions between dask and zarr chunks + ds = xr.DataArray((np.arange(12)), dims="x", name="var1").to_dataset() + + # - no encoding specified - + # zarr automatically gets chunk information from dask chunks + ds_chunk4 = ds.chunk({"x": 4}) + with self.roundtrip(ds_chunk4) as actual: + assert (4,) == actual["var1"].encoding["chunks"] + + # should fail if dask_chunks are irregular... + ds_chunk_irreg = ds.chunk({"x": (5, 4, 3)}) + with pytest.raises(ValueError, match=r"uniform chunk sizes."): + with self.roundtrip(ds_chunk_irreg) as actual: + pass + + # should fail if encoding["chunks"] clashes with dask_chunks + badenc = ds.chunk({"x": 4}) + badenc.var1.encoding["chunks"] = (6,) + with pytest.raises(ValueError, match=r"named 'var1' would overlap"): + with self.roundtrip(badenc) as actual: + pass + + # unless... + with self.roundtrip(badenc, save_kwargs={"safe_chunks": False}) as actual: + # don't actually check equality because the data could be corrupted + pass + + # if dask chunks (4) are an integer multiple of zarr chunks (2) it should not fail... + goodenc = ds.chunk({"x": 4}) + goodenc.var1.encoding["chunks"] = (2,) + with self.roundtrip(goodenc) as actual: + pass + + # if initial dask chunks are aligned, size of last dask chunk doesn't matter + goodenc = ds.chunk({"x": (3, 3, 6)}) + goodenc.var1.encoding["chunks"] = (3,) + with self.roundtrip(goodenc) as actual: + pass + + goodenc = ds.chunk({"x": (3, 6, 3)}) + goodenc.var1.encoding["chunks"] = (3,) + with self.roundtrip(goodenc) as actual: + pass + + # ... also if the last chunk is irregular + ds_chunk_irreg = ds.chunk({"x": (5, 5, 2)}) + with self.roundtrip(ds_chunk_irreg) as actual: + assert (5,) == actual["var1"].encoding["chunks"] + # re-save Zarr arrays + with self.roundtrip(ds_chunk_irreg) as original: + with self.roundtrip(original) as actual: + assert_identical(original, actual) + + # but itermediate unaligned chunks are bad + badenc = ds.chunk({"x": (3, 5, 3, 1)}) + badenc.var1.encoding["chunks"] = (3,) + with pytest.raises(ValueError, match=r"would overlap multiple dask chunks"): + with self.roundtrip(badenc) as actual: + pass + + # - encoding specified - + # specify compatible encodings + for chunk_enc in 4, (4,): + ds_chunk4["var1"].encoding.update({"chunks": chunk_enc}) + with self.roundtrip(ds_chunk4) as actual: + assert (4,) == actual["var1"].encoding["chunks"] + + # TODO: remove this failure once synchronized overlapping writes are + # supported by xarray + ds_chunk4["var1"].encoding.update({"chunks": 5}) + with pytest.raises(ValueError, match=r"named 'var1' would overlap"): + with self.roundtrip(ds_chunk4) as actual: + pass + # override option + with self.roundtrip(ds_chunk4, save_kwargs={"safe_chunks": False}) as actual: + # don't actually check equality because the data could be corrupted + pass + + def test_drop_encoding(self): + with open_example_dataset("example_1.nc") as ds: + encodings = {v: {**ds[v].encoding} for v in ds.data_vars} + with self.create_zarr_target() as store: + ds.to_zarr(store, encoding=encodings) + + def test_hidden_zarr_keys(self) -> None: + expected = create_test_data() + with self.create_store() as store: + expected.dump_to_store(store) + zarr_group = store.ds + + # check that a variable hidden attribute is present and correct + # JSON only has a single array type, which maps to list in Python. + # In contrast, dims in xarray is always a tuple. + for var in expected.variables.keys(): + dims = zarr_group[var].attrs[self.DIMENSION_KEY] + assert dims == list(expected[var].dims) + + with xr.decode_cf(store): + # make sure it is hidden + for var in expected.variables.keys(): + assert self.DIMENSION_KEY not in expected[var].attrs + + # put it back and try removing from a variable + del zarr_group.var2.attrs[self.DIMENSION_KEY] + with pytest.raises(KeyError): + with xr.decode_cf(store): + pass + + @pytest.mark.parametrize("group", [None, "group1"]) + def test_write_persistence_modes(self, group) -> None: + original = create_test_data() + + # overwrite mode + with self.roundtrip( + original, + save_kwargs={"mode": "w", "group": group}, + open_kwargs={"group": group}, + ) as actual: + assert_identical(original, actual) + + # don't overwrite mode + with self.roundtrip( + original, + save_kwargs={"mode": "w-", "group": group}, + open_kwargs={"group": group}, + ) as actual: + assert_identical(original, actual) + + # make sure overwriting works as expected + with self.create_zarr_target() as store: + self.save(original, store) + # should overwrite with no error + self.save(original, store, mode="w", group=group) + with self.open(store, group=group) as actual: + assert_identical(original, actual) + with pytest.raises(ValueError): + self.save(original, store, mode="w-") + + # check append mode for normal write + with self.roundtrip( + original, + save_kwargs={"mode": "a", "group": group}, + open_kwargs={"group": group}, + ) as actual: + assert_identical(original, actual) + + # check append mode for append write + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", group=group, **self.version_kwargs) + ds_to_append.to_zarr( + store_target, append_dim="time", group=group, **self.version_kwargs + ) + original = xr.concat([ds, ds_to_append], dim="time") + actual = xr.open_dataset( + store_target, group=group, engine="zarr", **self.version_kwargs + ) + assert_identical(original, actual) + + def test_compressor_encoding(self) -> None: + original = create_test_data() + # specify a custom compressor + import zarr + + blosc_comp = zarr.Blosc(cname="zstd", clevel=3, shuffle=2) + save_kwargs = dict(encoding={"var1": {"compressor": blosc_comp}}) + with self.roundtrip(original, save_kwargs=save_kwargs) as ds: + actual = ds["var1"].encoding["compressor"] + # get_config returns a dictionary of compressor attributes + assert actual.get_config() == blosc_comp.get_config() + + def test_group(self) -> None: + original = create_test_data() + group = "some/random/path" + with self.roundtrip( + original, save_kwargs={"group": group}, open_kwargs={"group": group} + ) as actual: + assert_identical(original, actual) + + def test_zarr_mode_w_overwrites_encoding(self) -> None: + import zarr + + data = Dataset({"foo": ("x", [1.0, 1.0, 1.0])}) + with self.create_zarr_target() as store: + data.to_zarr( + store, **self.version_kwargs, encoding={"foo": {"add_offset": 1}} + ) + np.testing.assert_equal( + zarr.open_group(store, **self.version_kwargs)["foo"], data.foo.data - 1 + ) + data.to_zarr( + store, + **self.version_kwargs, + encoding={"foo": {"add_offset": 0}}, + mode="w", + ) + np.testing.assert_equal( + zarr.open_group(store, **self.version_kwargs)["foo"], data.foo.data + ) + + def test_encoding_kwarg_fixed_width_string(self) -> None: + # not relevant for zarr, since we don't use EncodedStringCoder + pass + + def test_dataset_caching(self) -> None: + super().test_dataset_caching() + + def test_append_write(self) -> None: + super().test_append_write() + + def test_append_with_mode_rplus_success(self) -> None: + original = Dataset({"foo": ("x", [1])}) + modified = Dataset({"foo": ("x", [2])}) + with self.create_zarr_target() as store: + original.to_zarr(store, **self.version_kwargs) + modified.to_zarr(store, mode="r+", **self.version_kwargs) + with self.open(store) as actual: + assert_identical(actual, modified) + + def test_append_with_mode_rplus_fails(self) -> None: + original = Dataset({"foo": ("x", [1])}) + modified = Dataset({"bar": ("x", [2])}) + with self.create_zarr_target() as store: + original.to_zarr(store, **self.version_kwargs) + with pytest.raises( + ValueError, match="dataset contains non-pre-existing variables" + ): + modified.to_zarr(store, mode="r+", **self.version_kwargs) + + def test_append_with_invalid_dim_raises(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + with pytest.raises( + ValueError, match="does not match any existing dataset dimensions" + ): + ds_to_append.to_zarr( + store_target, append_dim="notvalid", **self.version_kwargs + ) + + def test_append_with_no_dims_raises(self) -> None: + with self.create_zarr_target() as store_target: + Dataset({"foo": ("x", [1])}).to_zarr( + store_target, mode="w", **self.version_kwargs + ) + with pytest.raises(ValueError, match="different dimension names"): + Dataset({"foo": ("y", [2])}).to_zarr( + store_target, mode="a", **self.version_kwargs + ) + + def test_append_with_append_dim_not_set_raises(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + with pytest.raises(ValueError, match="different dimension sizes"): + ds_to_append.to_zarr(store_target, mode="a", **self.version_kwargs) + + def test_append_with_mode_not_a_raises(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + with pytest.raises(ValueError, match="cannot set append_dim unless"): + ds_to_append.to_zarr( + store_target, mode="w", append_dim="time", **self.version_kwargs + ) + + def test_append_with_existing_encoding_raises(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + with pytest.raises(ValueError, match="but encoding was provided"): + ds_to_append.to_zarr( + store_target, + append_dim="time", + encoding={"da": {"compressor": None}}, + **self.version_kwargs, + ) + + @pytest.mark.parametrize("dtype", ["U", "S"]) + def test_append_string_length_mismatch_raises(self, dtype) -> None: + ds, ds_to_append = create_append_string_length_mismatch_test_data(dtype) + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + with pytest.raises(ValueError, match="Mismatched dtypes for variable"): + ds_to_append.to_zarr( + store_target, append_dim="time", **self.version_kwargs + ) + + def test_check_encoding_is_consistent_after_append(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + + # check encoding consistency + with self.create_zarr_target() as store_target: + import zarr + + compressor = zarr.Blosc() + encoding = {"da": {"compressor": compressor}} + ds.to_zarr(store_target, mode="w", encoding=encoding, **self.version_kwargs) + ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs) + actual_ds = xr.open_dataset( + store_target, engine="zarr", **self.version_kwargs + ) + actual_encoding = actual_ds["da"].encoding["compressor"] + assert actual_encoding.get_config() == compressor.get_config() + assert_identical( + xr.open_dataset( + store_target, engine="zarr", **self.version_kwargs + ).compute(), + xr.concat([ds, ds_to_append], dim="time"), + ) + + def test_append_with_new_variable(self) -> None: + ds, ds_to_append, ds_with_new_var = create_append_test_data() + + # check append mode for new variable + with self.create_zarr_target() as store_target: + xr.concat([ds, ds_to_append], dim="time").to_zarr( + store_target, mode="w", **self.version_kwargs + ) + ds_with_new_var.to_zarr(store_target, mode="a", **self.version_kwargs) + combined = xr.concat([ds, ds_to_append], dim="time") + combined["new_var"] = ds_with_new_var["new_var"] + assert_identical( + combined, + xr.open_dataset(store_target, engine="zarr", **self.version_kwargs), + ) + + def test_append_with_append_dim_no_overwrite(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + original = xr.concat([ds, ds_to_append], dim="time") + original2 = xr.concat([original, ds_to_append], dim="time") + + # overwrite a coordinate; + # for mode='a-', this will not get written to the store + # because it does not have the append_dim as a dim + lon = ds_to_append.lon.to_numpy().copy() + lon[:] = -999 + ds_to_append["lon"] = lon + ds_to_append.to_zarr( + store_target, mode="a-", append_dim="time", **self.version_kwargs + ) + actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs) + assert_identical(original, actual) + + # by default, mode="a" will overwrite all coordinates. + ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs) + actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs) + lon = original2.lon.to_numpy().copy() + lon[:] = -999 + original2["lon"] = lon + assert_identical(original2, actual) + + @requires_dask + def test_to_zarr_compute_false_roundtrip(self) -> None: + from dask.delayed import Delayed + + original = create_test_data().chunk() + + with self.create_zarr_target() as store: + delayed_obj = self.save(original, store, compute=False) + assert isinstance(delayed_obj, Delayed) + + # make sure target store has not been written to yet + with pytest.raises(AssertionError): + with self.open(store) as actual: + assert_identical(original, actual) + + delayed_obj.compute() + + with self.open(store) as actual: + assert_identical(original, actual) + + @requires_dask + def test_to_zarr_append_compute_false_roundtrip(self) -> None: + from dask.delayed import Delayed + + ds, ds_to_append, _ = create_append_test_data() + ds, ds_to_append = ds.chunk(), ds_to_append.chunk() + + with pytest.warns(SerializationWarning): + with self.create_zarr_target() as store: + delayed_obj = self.save(ds, store, compute=False, mode="w") + assert isinstance(delayed_obj, Delayed) + + with pytest.raises(AssertionError): + with self.open(store) as actual: + assert_identical(ds, actual) + + delayed_obj.compute() + + with self.open(store) as actual: + assert_identical(ds, actual) + + delayed_obj = self.save( + ds_to_append, store, compute=False, append_dim="time" + ) + assert isinstance(delayed_obj, Delayed) + + with pytest.raises(AssertionError): + with self.open(store) as actual: + assert_identical( + xr.concat([ds, ds_to_append], dim="time"), actual + ) + + delayed_obj.compute() + + with self.open(store) as actual: + assert_identical(xr.concat([ds, ds_to_append], dim="time"), actual) + + @pytest.mark.parametrize("chunk", [False, True]) + def test_save_emptydim(self, chunk) -> None: + if chunk and not has_dask: + pytest.skip("requires dask") + ds = Dataset({"x": (("a", "b"), np.empty((5, 0))), "y": ("a", [1, 2, 5, 8, 9])}) + if chunk: + ds = ds.chunk({}) # chunk dataset to save dask array + with self.roundtrip(ds) as ds_reload: + assert_identical(ds, ds_reload) + + @requires_dask + def test_no_warning_from_open_emptydim_with_chunks(self) -> None: + ds = Dataset({"x": (("a", "b"), np.empty((5, 0)))}).chunk({"a": 1}) + with assert_no_warnings(): + with self.roundtrip(ds, open_kwargs=dict(chunks={"a": 1})) as ds_reload: + assert_identical(ds, ds_reload) + + @pytest.mark.parametrize("consolidated", [False, True, None]) + @pytest.mark.parametrize("compute", [False, True]) + @pytest.mark.parametrize("use_dask", [False, True]) + @pytest.mark.parametrize("write_empty", [False, True, None]) + def test_write_region(self, consolidated, compute, use_dask, write_empty) -> None: + if (use_dask or not compute) and not has_dask: + pytest.skip("requires dask") + if consolidated and self.zarr_version > 2: + pytest.xfail("consolidated metadata is not supported for zarr v3 yet") + + zeros = Dataset({"u": (("x",), np.zeros(10))}) + nonzeros = Dataset({"u": (("x",), np.arange(1, 11))}) + + if use_dask: + zeros = zeros.chunk(2) + nonzeros = nonzeros.chunk(2) + + with self.create_zarr_target() as store: + zeros.to_zarr( + store, + consolidated=consolidated, + compute=compute, + encoding={"u": dict(chunks=2)}, + **self.version_kwargs, + ) + if compute: + with xr.open_zarr( + store, consolidated=consolidated, **self.version_kwargs + ) as actual: + assert_identical(actual, zeros) + for i in range(0, 10, 2): + region = {"x": slice(i, i + 2)} + nonzeros.isel(region).to_zarr( + store, + region=region, + consolidated=consolidated, + write_empty_chunks=write_empty, + **self.version_kwargs, + ) + with xr.open_zarr( + store, consolidated=consolidated, **self.version_kwargs + ) as actual: + assert_identical(actual, nonzeros) + + @pytest.mark.parametrize("mode", [None, "r+", "a"]) + def test_write_region_mode(self, mode) -> None: + zeros = Dataset({"u": (("x",), np.zeros(10))}) + nonzeros = Dataset({"u": (("x",), np.arange(1, 11))}) + with self.create_zarr_target() as store: + zeros.to_zarr(store, **self.version_kwargs) + for region in [{"x": slice(5)}, {"x": slice(5, 10)}]: + nonzeros.isel(region).to_zarr( + store, region=region, mode=mode, **self.version_kwargs + ) + with xr.open_zarr(store, **self.version_kwargs) as actual: + assert_identical(actual, nonzeros) + + @requires_dask + def test_write_preexisting_override_metadata(self) -> None: + """Metadata should be overridden if mode="a" but not in mode="r+".""" + original = Dataset( + {"u": (("x",), np.zeros(10), {"variable": "original"})}, + attrs={"global": "original"}, + ) + both_modified = Dataset( + {"u": (("x",), np.ones(10), {"variable": "modified"})}, + attrs={"global": "modified"}, + ) + global_modified = Dataset( + {"u": (("x",), np.ones(10), {"variable": "original"})}, + attrs={"global": "modified"}, + ) + only_new_data = Dataset( + {"u": (("x",), np.ones(10), {"variable": "original"})}, + attrs={"global": "original"}, + ) + + with self.create_zarr_target() as store: + original.to_zarr(store, compute=False, **self.version_kwargs) + both_modified.to_zarr(store, mode="a", **self.version_kwargs) + with self.open(store) as actual: + # NOTE: this arguably incorrect -- we should probably be + # overriding the variable metadata, too. See the TODO note in + # ZarrStore.set_variables. + assert_identical(actual, global_modified) + + with self.create_zarr_target() as store: + original.to_zarr(store, compute=False, **self.version_kwargs) + both_modified.to_zarr(store, mode="r+", **self.version_kwargs) + with self.open(store) as actual: + assert_identical(actual, only_new_data) + + with self.create_zarr_target() as store: + original.to_zarr(store, compute=False, **self.version_kwargs) + # with region, the default mode becomes r+ + both_modified.to_zarr( + store, region={"x": slice(None)}, **self.version_kwargs + ) + with self.open(store) as actual: + assert_identical(actual, only_new_data) + + def test_write_region_errors(self) -> None: + data = Dataset({"u": (("x",), np.arange(5))}) + data2 = Dataset({"u": (("x",), np.array([10, 11]))}) + + @contextlib.contextmanager + def setup_and_verify_store(expected=data): + with self.create_zarr_target() as store: + data.to_zarr(store, **self.version_kwargs) + yield store + with self.open(store) as actual: + assert_identical(actual, expected) + + # verify the base case works + expected = Dataset({"u": (("x",), np.array([10, 11, 2, 3, 4]))}) + with setup_and_verify_store(expected) as store: + data2.to_zarr(store, region={"x": slice(2)}, **self.version_kwargs) + + with setup_and_verify_store() as store: + with pytest.raises( + ValueError, + match=re.escape( + "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None" + ), + ): + data.to_zarr( + store, region={"x": slice(None)}, mode="w", **self.version_kwargs + ) + + with setup_and_verify_store() as store: + with pytest.raises(TypeError, match=r"must be a dict"): + data.to_zarr(store, region=slice(None), **self.version_kwargs) # type: ignore[call-overload] + + with setup_and_verify_store() as store: + with pytest.raises(TypeError, match=r"must be slice objects"): + data2.to_zarr(store, region={"x": [0, 1]}, **self.version_kwargs) # type: ignore[dict-item] + + with setup_and_verify_store() as store: + with pytest.raises(ValueError, match=r"step on all slices"): + data2.to_zarr( + store, region={"x": slice(None, None, 2)}, **self.version_kwargs + ) + + with setup_and_verify_store() as store: + with pytest.raises( + ValueError, + match=r"all keys in ``region`` are not in Dataset dimensions", + ): + data.to_zarr(store, region={"y": slice(None)}, **self.version_kwargs) + + with setup_and_verify_store() as store: + with pytest.raises( + ValueError, + match=r"all variables in the dataset to write must have at least one dimension in common", + ): + data2.assign(v=2).to_zarr( + store, region={"x": slice(2)}, **self.version_kwargs + ) + + with setup_and_verify_store() as store: + with pytest.raises( + ValueError, match=r"cannot list the same dimension in both" + ): + data.to_zarr( + store, + region={"x": slice(None)}, + append_dim="x", + **self.version_kwargs, + ) + + with setup_and_verify_store() as store: + with pytest.raises( + ValueError, + match=r"variable 'u' already exists with different dimension sizes", + ): + data2.to_zarr(store, region={"x": slice(3)}, **self.version_kwargs) + + @requires_dask + def test_encoding_chunksizes(self) -> None: + # regression test for GH2278 + # see also test_encoding_chunksizes_unlimited + nx, ny, nt = 4, 4, 5 + original = xr.Dataset( + {}, coords={"x": np.arange(nx), "y": np.arange(ny), "t": np.arange(nt)} + ) + original["v"] = xr.Variable(("x", "y", "t"), np.zeros((nx, ny, nt))) + original = original.chunk({"t": 1, "x": 2, "y": 2}) + + with self.roundtrip(original) as ds1: + assert_equal(ds1, original) + with self.roundtrip(ds1.isel(t=0)) as ds2: + assert_equal(ds2, original.isel(t=0)) + + @requires_dask + def test_chunk_encoding_with_partial_dask_chunks(self) -> None: + original = xr.Dataset( + {"x": xr.DataArray(np.random.random(size=(6, 8)), dims=("a", "b"))} + ).chunk({"a": 3}) + + with self.roundtrip( + original, save_kwargs={"encoding": {"x": {"chunks": [3, 2]}}} + ) as ds1: + assert_equal(ds1, original) + + @requires_dask + def test_chunk_encoding_with_larger_dask_chunks(self) -> None: + original = xr.Dataset({"a": ("x", [1, 2, 3, 4])}).chunk({"x": 2}) + + with self.roundtrip( + original, save_kwargs={"encoding": {"a": {"chunks": [1]}}} + ) as ds1: + assert_equal(ds1, original) + + @requires_cftime + def test_open_zarr_use_cftime(self) -> None: + ds = create_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, **self.version_kwargs) + ds_a = xr.open_zarr(store_target, **self.version_kwargs) + assert_identical(ds, ds_a) + ds_b = xr.open_zarr(store_target, use_cftime=True, **self.version_kwargs) + assert xr.coding.times.contains_cftime_datetimes(ds_b.time.variable) + + def test_write_read_select_write(self) -> None: + # Test for https://github.com/pydata/xarray/issues/4084 + ds = create_test_data() + + # NOTE: using self.roundtrip, which uses open_dataset, will not trigger the bug. + with self.create_zarr_target() as initial_store: + ds.to_zarr(initial_store, mode="w", **self.version_kwargs) + ds1 = xr.open_zarr(initial_store, **self.version_kwargs) + + # Combination of where+squeeze triggers error on write. + ds_sel = ds1.where(ds1.coords["dim3"] == "a", drop=True).squeeze("dim3") + with self.create_zarr_target() as final_store: + ds_sel.to_zarr(final_store, mode="w", **self.version_kwargs) + + @pytest.mark.parametrize("obj", [Dataset(), DataArray(name="foo")]) + def test_attributes(self, obj) -> None: + obj = obj.copy() + + obj.attrs["good"] = {"key": "value"} + ds = obj if isinstance(obj, Dataset) else obj.to_dataset() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, **self.version_kwargs) + assert_identical(ds, xr.open_zarr(store_target, **self.version_kwargs)) + + obj.attrs["bad"] = DataArray() + ds = obj if isinstance(obj, Dataset) else obj.to_dataset() + with self.create_zarr_target() as store_target: + with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."): + ds.to_zarr(store_target, **self.version_kwargs) + + @requires_dask + @pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"]) + def test_chunked_datetime64_or_timedelta64(self, dtype) -> None: + # Generalized from @malmans2's test in PR #8253 + original = create_test_data().astype(dtype).chunk(1) + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for name, actual_var in actual.variables.items(): + assert original[name].chunks == actual_var.chunks + assert original.chunks == actual.chunks + + @requires_cftime + @requires_dask + def test_chunked_cftime_datetime(self) -> None: + # Based on @malmans2's test in PR #8253 + times = cftime_range("2000", freq="D", periods=3) + original = xr.Dataset(data_vars={"chunked_times": (["time"], times)}) + original = original.chunk({"time": 1}) + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for name, actual_var in actual.variables.items(): + assert original[name].chunks == actual_var.chunks + assert original.chunks == actual.chunks + + +@requires_zarr +@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3") +class TestInstrumentedZarrStore: + methods = [ + "__iter__", + "__contains__", + "__setitem__", + "__getitem__", + "listdir", + "list_prefix", + ] + + @contextlib.contextmanager + def create_zarr_target(self): + import zarr + + if Version(zarr.__version__) < Version("2.18.0"): + pytest.skip("Instrumented tests only work on latest Zarr.") + + store = KVStoreV3({}) + yield store + + def make_patches(self, store): + from unittest.mock import MagicMock + + return { + method: MagicMock( + f"KVStoreV3.{method}", + side_effect=getattr(store, method), + autospec=True, + ) + for method in self.methods + } + + def summarize(self, patches): + summary = {} + for name, patch_ in patches.items(): + count = 0 + for call in patch_.mock_calls: + if "zarr.json" not in call.args: + count += 1 + summary[name.strip("__")] = count + return summary + + def check_requests(self, expected, patches): + summary = self.summarize(patches) + for k in summary: + assert summary[k] <= expected[k], (k, summary) + + def test_append(self) -> None: + original = Dataset({"foo": ("x", [1])}, coords={"x": [0]}) + modified = Dataset({"foo": ("x", [2])}, coords={"x": [1]}) + with self.create_zarr_target() as store: + expected = { + "iter": 2, + "contains": 9, + "setitem": 9, + "getitem": 6, + "listdir": 2, + "list_prefix": 2, + } + patches = self.make_patches(store) + with patch.multiple(KVStoreV3, **patches): + original.to_zarr(store) + self.check_requests(expected, patches) + + patches = self.make_patches(store) + # v2024.03.0: {'iter': 6, 'contains': 2, 'setitem': 5, 'getitem': 10, 'listdir': 6, 'list_prefix': 0} + # 6057128b: {'iter': 5, 'contains': 2, 'setitem': 5, 'getitem': 10, "listdir": 5, "list_prefix": 0} + expected = { + "iter": 2, + "contains": 2, + "setitem": 5, + "getitem": 6, + "listdir": 2, + "list_prefix": 0, + } + with patch.multiple(KVStoreV3, **patches): + modified.to_zarr(store, mode="a", append_dim="x") + self.check_requests(expected, patches) + + patches = self.make_patches(store) + expected = { + "iter": 2, + "contains": 2, + "setitem": 5, + "getitem": 6, + "listdir": 2, + "list_prefix": 0, + } + with patch.multiple(KVStoreV3, **patches): + modified.to_zarr(store, mode="a-", append_dim="x") + self.check_requests(expected, patches) + + with open_dataset(store, engine="zarr") as actual: + assert_identical( + actual, xr.concat([original, modified, modified], dim="x") + ) + + @requires_dask + def test_region_write(self) -> None: + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords={"x": [0, 1, 2]}).chunk() + with self.create_zarr_target() as store: + expected = { + "iter": 2, + "contains": 7, + "setitem": 8, + "getitem": 6, + "listdir": 2, + "list_prefix": 4, + } + patches = self.make_patches(store) + with patch.multiple(KVStoreV3, **patches): + ds.to_zarr(store, mode="w", compute=False) + self.check_requests(expected, patches) + + # v2024.03.0: {'iter': 5, 'contains': 2, 'setitem': 1, 'getitem': 6, 'listdir': 5, 'list_prefix': 0} + # 6057128b: {'iter': 4, 'contains': 2, 'setitem': 1, 'getitem': 5, 'listdir': 4, 'list_prefix': 0} + expected = { + "iter": 2, + "contains": 2, + "setitem": 1, + "getitem": 3, + "listdir": 2, + "list_prefix": 0, + } + patches = self.make_patches(store) + with patch.multiple(KVStoreV3, **patches): + ds.to_zarr(store, region={"x": slice(None)}) + self.check_requests(expected, patches) + + # v2024.03.0: {'iter': 6, 'contains': 4, 'setitem': 1, 'getitem': 11, 'listdir': 6, 'list_prefix': 0} + # 6057128b: {'iter': 4, 'contains': 2, 'setitem': 1, 'getitem': 7, 'listdir': 4, 'list_prefix': 0} + expected = { + "iter": 2, + "contains": 2, + "setitem": 1, + "getitem": 5, + "listdir": 2, + "list_prefix": 0, + } + patches = self.make_patches(store) + with patch.multiple(KVStoreV3, **patches): + ds.to_zarr(store, region="auto") + self.check_requests(expected, patches) + + expected = { + "iter": 1, + "contains": 2, + "setitem": 0, + "getitem": 5, + "listdir": 1, + "list_prefix": 0, + } + patches = self.make_patches(store) + with patch.multiple(KVStoreV3, **patches): + with open_dataset(store, engine="zarr") as actual: + assert_identical(actual, ds) + self.check_requests(expected, patches) + + +@requires_zarr +class TestZarrDictStore(ZarrBase): + @contextlib.contextmanager + def create_zarr_target(self): + if have_zarr_kvstore: + yield KVStore({}) + else: + yield {} + + +@requires_zarr +@pytest.mark.skipif( + ON_WINDOWS, + reason="Very flaky on Windows CI. Can re-enable assuming it starts consistently passing.", +) +class TestZarrDirectoryStore(ZarrBase): + @contextlib.contextmanager + def create_zarr_target(self): + with create_tmp_file(suffix=".zarr") as tmp: + yield tmp + + @contextlib.contextmanager + def create_store(self): + with self.create_zarr_target() as store_target: + group = backends.ZarrStore.open_group(store_target, mode="w") + # older Zarr versions do not have the _store_version attribute + if have_zarr_v3: + # verify that a v2 store was created + assert group.zarr_group.store._store_version == 2 + yield group + + +@requires_zarr +class TestZarrWriteEmpty(TestZarrDirectoryStore): + @contextlib.contextmanager + def temp_dir(self) -> Iterator[tuple[str, str]]: + with tempfile.TemporaryDirectory() as d: + store = os.path.join(d, "test.zarr") + yield d, store + + @contextlib.contextmanager + def roundtrip_dir( + self, + data, + store, + save_kwargs=None, + open_kwargs=None, + allow_cleanup_failure=False, + ) -> Iterator[Dataset]: + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + + data.to_zarr(store, **save_kwargs, **self.version_kwargs) + with xr.open_dataset( + store, engine="zarr", **open_kwargs, **self.version_kwargs + ) as ds: + yield ds + + @pytest.mark.parametrize("consolidated", [True, False, None]) + @pytest.mark.parametrize("write_empty", [True, False, None]) + def test_write_empty( + self, consolidated: bool | None, write_empty: bool | None + ) -> None: + if write_empty is False: + expected = ["0.1.0", "1.1.0"] + else: + expected = [ + "0.0.0", + "0.0.1", + "0.1.0", + "0.1.1", + "1.0.0", + "1.0.1", + "1.1.0", + "1.1.1", + ] + + ds = xr.Dataset( + data_vars={ + "test": ( + ("Z", "Y", "X"), + np.array([np.nan, np.nan, 1.0, np.nan]).reshape((1, 2, 2)), + ) + } + ) + + if has_dask: + ds["test"] = ds["test"].chunk(1) + encoding = None + else: + encoding = {"test": {"chunks": (1, 1, 1)}} + + with self.temp_dir() as (d, store): + ds.to_zarr( + store, + mode="w", + encoding=encoding, + write_empty_chunks=write_empty, + ) + + with self.roundtrip_dir( + ds, + store, + {"mode": "a", "append_dim": "Z", "write_empty_chunks": write_empty}, + ) as a_ds: + expected_ds = xr.concat([ds, ds], dim="Z") + + assert_identical(a_ds, expected_ds) + + ls = listdir(os.path.join(store, "test")) + assert set(expected) == set([file for file in ls if file[0] != "."]) + + def test_avoid_excess_metadata_calls(self) -> None: + """Test that chunk requests do not trigger redundant metadata requests. + + This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls + to retrieve chunk data after initialization do not trigger additional + metadata requests. + + https://github.com/pydata/xarray/issues/8290 + """ + + import zarr + + ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))}) + + # The call to retrieve metadata performs a group lookup. We patch Group.__getitem__ + # so that we can inspect calls to this method - specifically count of calls. + # Use of side_effect means that calls are passed through to the original method + # rather than a mocked method. + Group = zarr.hierarchy.Group + with ( + self.create_zarr_target() as store, + patch.object( + Group, "__getitem__", side_effect=Group.__getitem__, autospec=True + ) as mock, + ): + ds.to_zarr(store, mode="w") + + # We expect this to request array metadata information, so call_count should be == 1, + xrds = xr.open_zarr(store) + call_count = mock.call_count + assert call_count == 1 + + # compute() requests array data, which should not trigger additional metadata requests + # we assert that the number of calls has not increased after fetchhing the array + xrds.test.compute(scheduler="sync") + assert mock.call_count == call_count + + +class ZarrBaseV3(ZarrBase): + zarr_version = 3 + + def test_roundtrip_coordinates_with_space(self): + original = Dataset(coords={"x": 0, "y z": 1}) + with pytest.warns(SerializationWarning): + # v3 stores do not allow spaces in the key name + with pytest.raises(ValueError): + with self.roundtrip(original): + pass + + +@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3") +class TestZarrKVStoreV3(ZarrBaseV3): + @contextlib.contextmanager + def create_zarr_target(self): + yield KVStoreV3({}) + + +@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3") +class TestZarrDirectoryStoreV3(ZarrBaseV3): + @contextlib.contextmanager + def create_zarr_target(self): + with create_tmp_file(suffix=".zr3") as tmp: + yield DirectoryStoreV3(tmp) + + +@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3") +class TestZarrDirectoryStoreV3FromPath(TestZarrDirectoryStoreV3): + # Must specify zarr_version=3 to get a v3 store because create_zarr_target + # is a string path. + version_kwargs = {"zarr_version": 3} + + @contextlib.contextmanager + def create_zarr_target(self): + with create_tmp_file(suffix=".zr3") as tmp: + yield tmp + + +@requires_zarr +@requires_fsspec +def test_zarr_storage_options() -> None: + pytest.importorskip("aiobotocore") + ds = create_test_data() + store_target = "memory://test.zarr" + ds.to_zarr(store_target, storage_options={"test": "zarr_write"}) + ds_a = xr.open_zarr(store_target, storage_options={"test": "zarr_read"}) + assert_identical(ds, ds_a) + + +@requires_scipy +class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only): + engine: T_NetcdfEngine = "scipy" + + @contextlib.contextmanager + def create_store(self): + fobj = BytesIO() + yield backends.ScipyDataStore(fobj, "w") + + def test_to_netcdf_explicit_engine(self) -> None: + # regression test for GH1321 + Dataset({"foo": 42}).to_netcdf(engine="scipy") + + def test_bytes_pickle(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}) + fobj = data.to_netcdf() + with self.open(fobj) as ds: + unpickled = pickle.loads(pickle.dumps(ds)) + assert_identical(unpickled, data) + + +@requires_scipy +class TestScipyFileObject(CFEncodedBase, NetCDF3Only): + engine: T_NetcdfEngine = "scipy" + + @contextlib.contextmanager + def create_store(self): + fobj = BytesIO() + yield backends.ScipyDataStore(fobj, "w") + + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + with create_tmp_file() as tmp_file: + with open(tmp_file, "wb") as f: + self.save(data, f, **save_kwargs) + with open(tmp_file, "rb") as f: + with self.open(f, **open_kwargs) as ds: + yield ds + + @pytest.mark.skip(reason="cannot pickle file objects") + def test_pickle(self) -> None: + pass + + @pytest.mark.skip(reason="cannot pickle file objects") + def test_pickle_dataarray(self) -> None: + pass + + +@requires_scipy +class TestScipyFilePath(CFEncodedBase, NetCDF3Only): + engine: T_NetcdfEngine = "scipy" + + @contextlib.contextmanager + def create_store(self): + with create_tmp_file() as tmp_file: + with backends.ScipyDataStore(tmp_file, mode="w") as store: + yield store + + def test_array_attrs(self) -> None: + ds = Dataset(attrs={"foo": [[1, 2], [3, 4]]}) + with pytest.raises(ValueError, match=r"must be 1-dimensional"): + with self.roundtrip(ds): + pass + + def test_roundtrip_example_1_netcdf_gz(self) -> None: + with open_example_dataset("example_1.nc.gz") as expected: + with open_example_dataset("example_1.nc") as actual: + assert_identical(expected, actual) + + def test_netcdf3_endianness(self) -> None: + # regression test for GH416 + with open_example_dataset("bears.nc", engine="scipy") as expected: + for var in expected.variables.values(): + assert var.dtype.isnative + + @requires_netCDF4 + def test_nc4_scipy(self) -> None: + with create_tmp_file(allow_cleanup_failure=True) as tmp_file: + with nc4.Dataset(tmp_file, "w", format="NETCDF4") as rootgrp: + rootgrp.createGroup("foo") + + with pytest.raises(TypeError, match=r"pip install netcdf4"): + open_dataset(tmp_file, engine="scipy") + + +@requires_netCDF4 +class TestNetCDF3ViaNetCDF4Data(CFEncodedBase, NetCDF3Only): + engine: T_NetcdfEngine = "netcdf4" + file_format: T_NetcdfTypes = "NETCDF3_CLASSIC" + + @contextlib.contextmanager + def create_store(self): + with create_tmp_file() as tmp_file: + with backends.NetCDF4DataStore.open( + tmp_file, mode="w", format="NETCDF3_CLASSIC" + ) as store: + yield store + + def test_encoding_kwarg_vlen_string(self) -> None: + original = Dataset({"x": ["foo", "bar", "baz"]}) + kwargs = dict(encoding={"x": {"dtype": str}}) + with pytest.raises(ValueError, match=r"encoding dtype=str for vlen"): + with self.roundtrip(original, save_kwargs=kwargs): + pass + + +@requires_netCDF4 +class TestNetCDF4ClassicViaNetCDF4Data(CFEncodedBase, NetCDF3Only): + engine: T_NetcdfEngine = "netcdf4" + file_format: T_NetcdfTypes = "NETCDF4_CLASSIC" + + @contextlib.contextmanager + def create_store(self): + with create_tmp_file() as tmp_file: + with backends.NetCDF4DataStore.open( + tmp_file, mode="w", format="NETCDF4_CLASSIC" + ) as store: + yield store + + +@requires_scipy_or_netCDF4 +class TestGenericNetCDFData(CFEncodedBase, NetCDF3Only): + # verify that we can read and write netCDF3 files as long as we have scipy + # or netCDF4-python installed + file_format: T_NetcdfTypes = "NETCDF3_64BIT" + + def test_write_store(self) -> None: + # there's no specific store to test here + pass + + @requires_scipy + def test_engine(self) -> None: + data = create_test_data() + with pytest.raises(ValueError, match=r"unrecognized engine"): + data.to_netcdf("foo.nc", engine="foobar") # type: ignore[call-overload] + with pytest.raises(ValueError, match=r"invalid engine"): + data.to_netcdf(engine="netcdf4") + + with create_tmp_file() as tmp_file: + data.to_netcdf(tmp_file) + with pytest.raises(ValueError, match=r"unrecognized engine"): + open_dataset(tmp_file, engine="foobar") + + netcdf_bytes = data.to_netcdf() + with pytest.raises(ValueError, match=r"unrecognized engine"): + open_dataset(BytesIO(netcdf_bytes), engine="foobar") + + def test_cross_engine_read_write_netcdf3(self) -> None: + data = create_test_data() + valid_engines: set[T_NetcdfEngine] = set() + if has_netCDF4: + valid_engines.add("netcdf4") + if has_scipy: + valid_engines.add("scipy") + + for write_engine in valid_engines: + for format in self.netcdf3_formats: + with create_tmp_file() as tmp_file: + data.to_netcdf(tmp_file, format=format, engine=write_engine) + for read_engine in valid_engines: + with open_dataset(tmp_file, engine=read_engine) as actual: + # hack to allow test to work: + # coord comes back as DataArray rather than coord, + # and so need to loop through here rather than in + # the test function (or we get recursion) + [ + assert_allclose(data[k].variable, actual[k].variable) + for k in data.variables + ] + + def test_encoding_unlimited_dims(self) -> None: + ds = Dataset({"x": ("y", np.arange(10.0))}) + with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=["y"])) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + + # Regression test for https://github.com/pydata/xarray/issues/2134 + with self.roundtrip(ds, save_kwargs=dict(unlimited_dims="y")) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + + ds.encoding = {"unlimited_dims": ["y"]} + with self.roundtrip(ds) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + + # Regression test for https://github.com/pydata/xarray/issues/2134 + ds.encoding = {"unlimited_dims": "y"} + with self.roundtrip(ds) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + + +@requires_h5netcdf +@requires_netCDF4 +@pytest.mark.filterwarnings("ignore:use make_scale(name) instead") +class TestH5NetCDFData(NetCDF4Base): + engine: T_NetcdfEngine = "h5netcdf" + + @contextlib.contextmanager + def create_store(self): + with create_tmp_file() as tmp_file: + yield backends.H5NetCDFStore.open(tmp_file, "w") + + def test_complex(self) -> None: + expected = Dataset({"x": ("y", np.ones(5) + 1j * np.ones(5))}) + save_kwargs = {"invalid_netcdf": True} + with pytest.warns(UserWarning, match="You are writing invalid netcdf features"): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_equal(expected, actual) + + @pytest.mark.parametrize("invalid_netcdf", [None, False]) + def test_complex_error(self, invalid_netcdf) -> None: + import h5netcdf + + expected = Dataset({"x": ("y", np.ones(5) + 1j * np.ones(5))}) + save_kwargs = {"invalid_netcdf": invalid_netcdf} + with pytest.raises( + h5netcdf.CompatibilityError, match="are not a supported NetCDF feature" + ): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_equal(expected, actual) + + def test_numpy_bool_(self) -> None: + # h5netcdf loads booleans as numpy.bool_, this type needs to be supported + # when writing invalid_netcdf datasets in order to support a roundtrip + expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) + save_kwargs = {"invalid_netcdf": True} + with pytest.warns(UserWarning, match="You are writing invalid netcdf features"): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_identical(expected, actual) + + def test_cross_engine_read_write_netcdf4(self) -> None: + # Drop dim3, because its labels include strings. These appear to be + # not properly read with python-netCDF4, which converts them into + # unicode instead of leaving them as bytes. + data = create_test_data().drop_vars("dim3") + data.attrs["foo"] = "bar" + valid_engines: list[T_NetcdfEngine] = ["netcdf4", "h5netcdf"] + for write_engine in valid_engines: + with create_tmp_file() as tmp_file: + data.to_netcdf(tmp_file, engine=write_engine) + for read_engine in valid_engines: + with open_dataset(tmp_file, engine=read_engine) as actual: + assert_identical(data, actual) + + def test_read_byte_attrs_as_unicode(self) -> None: + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, "w") as nc: + nc.foo = b"bar" + with open_dataset(tmp_file) as actual: + expected = Dataset(attrs={"foo": "bar"}) + assert_identical(expected, actual) + + def test_encoding_unlimited_dims(self) -> None: + ds = Dataset({"x": ("y", np.arange(10.0))}) + with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=["y"])) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + ds.encoding = {"unlimited_dims": ["y"]} + with self.roundtrip(ds) as actual: + assert actual.encoding["unlimited_dims"] == set("y") + assert_equal(ds, actual) + + def test_compression_encoding_h5py(self) -> None: + ENCODINGS: tuple[tuple[dict[str, Any], dict[str, Any]], ...] = ( + # h5py style compression with gzip codec will be converted to + # NetCDF4-Python style on round-trip + ( + {"compression": "gzip", "compression_opts": 9}, + {"zlib": True, "complevel": 9}, + ), + # What can't be expressed in NetCDF4-Python style is + # round-tripped unaltered + ( + {"compression": "lzf", "compression_opts": None}, + {"compression": "lzf", "compression_opts": None}, + ), + # If both styles are used together, h5py format takes precedence + ( + { + "compression": "lzf", + "compression_opts": None, + "zlib": True, + "complevel": 9, + }, + {"compression": "lzf", "compression_opts": None}, + ), + ) + + for compr_in, compr_out in ENCODINGS: + data = create_test_data() + compr_common = { + "chunksizes": (5, 5), + "fletcher32": True, + "shuffle": True, + "original_shape": data.var2.shape, + } + data["var2"].encoding.update(compr_in) + data["var2"].encoding.update(compr_common) + compr_out.update(compr_common) + data["scalar"] = ("scalar_dim", np.array([2.0])) + data["scalar"] = data["scalar"][0] + with self.roundtrip(data) as actual: + for k, v in compr_out.items(): + assert v == actual["var2"].encoding[k] + + def test_compression_check_encoding_h5py(self) -> None: + """When mismatched h5py and NetCDF4-Python encodings are expressed + in to_netcdf(encoding=...), must raise ValueError + """ + data = Dataset({"x": ("y", np.arange(10.0))}) + # Compatible encodings are graciously supported + with create_tmp_file() as tmp_file: + data.to_netcdf( + tmp_file, + engine="h5netcdf", + encoding={ + "x": { + "compression": "gzip", + "zlib": True, + "compression_opts": 6, + "complevel": 6, + } + }, + ) + with open_dataset(tmp_file, engine="h5netcdf") as actual: + assert actual.x.encoding["zlib"] is True + assert actual.x.encoding["complevel"] == 6 + + # Incompatible encodings cause a crash + with create_tmp_file() as tmp_file: + with pytest.raises( + ValueError, match=r"'zlib' and 'compression' encodings mismatch" + ): + data.to_netcdf( + tmp_file, + engine="h5netcdf", + encoding={"x": {"compression": "lzf", "zlib": True}}, + ) + + with create_tmp_file() as tmp_file: + with pytest.raises( + ValueError, + match=r"'complevel' and 'compression_opts' encodings mismatch", + ): + data.to_netcdf( + tmp_file, + engine="h5netcdf", + encoding={ + "x": { + "compression": "gzip", + "compression_opts": 5, + "complevel": 6, + } + }, + ) + + def test_dump_encodings_h5py(self) -> None: + # regression test for #709 + ds = Dataset({"x": ("y", np.arange(10.0))}) + + kwargs = {"encoding": {"x": {"compression": "gzip", "compression_opts": 9}}} + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual.x.encoding["zlib"] + assert actual.x.encoding["complevel"] == 9 + + kwargs = {"encoding": {"x": {"compression": "lzf", "compression_opts": None}}} + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert actual.x.encoding["compression"] == "lzf" + assert actual.x.encoding["compression_opts"] is None + + def test_decode_utf8_warning(self) -> None: + title = b"\xc3" + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, "w") as f: + f.title = title + with pytest.warns(UnicodeWarning, match="returning bytes undecoded") as w: + ds = xr.load_dataset(tmp_file, engine="h5netcdf") + assert ds.title == title + assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message) + + +@requires_h5netcdf +@requires_netCDF4 +class TestH5NetCDFAlreadyOpen: + def test_open_dataset_group(self) -> None: + import h5netcdf + + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + group = nc.createGroup("g") + v = group.createVariable("x", "int") + v[...] = 42 + + kwargs = {"decode_vlen_strings": True} + + h5 = h5netcdf.File(tmp_file, mode="r", **kwargs) + store = backends.H5NetCDFStore(h5["g"]) + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + h5 = h5netcdf.File(tmp_file, mode="r", **kwargs) + store = backends.H5NetCDFStore(h5, group="g") + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + def test_deepcopy(self) -> None: + import h5netcdf + + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("x", 10) + v = nc.createVariable("y", np.int32, ("x",)) + v[:] = np.arange(10) + + kwargs = {"decode_vlen_strings": True} + + h5 = h5netcdf.File(tmp_file, mode="r", **kwargs) + store = backends.H5NetCDFStore(h5) + with open_dataset(store) as ds: + copied = ds.copy(deep=True) + expected = Dataset({"y": ("x", np.arange(10))}) + assert_identical(expected, copied) + + +@requires_h5netcdf +class TestH5NetCDFFileObject(TestH5NetCDFData): + engine: T_NetcdfEngine = "h5netcdf" + + def test_open_badbytes(self) -> None: + with pytest.raises(ValueError, match=r"HDF5 as bytes"): + with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"): # type: ignore[arg-type] + pass + with pytest.raises( + ValueError, match=r"match in any of xarray's currently installed IO" + ): + with open_dataset(b"garbage"): # type: ignore[arg-type] + pass + with pytest.raises(ValueError, match=r"can only read bytes"): + with open_dataset(b"garbage", engine="netcdf4"): # type: ignore[arg-type] + pass + with pytest.raises( + ValueError, match=r"not the signature of a valid netCDF4 file" + ): + with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"): + pass + + def test_open_twice(self) -> None: + expected = create_test_data() + expected.attrs["foo"] = "bar" + with create_tmp_file() as tmp_file: + expected.to_netcdf(tmp_file, engine="h5netcdf") + with open(tmp_file, "rb") as f: + with open_dataset(f, engine="h5netcdf"): + with open_dataset(f, engine="h5netcdf"): + pass + + @requires_scipy + def test_open_fileobj(self) -> None: + # open in-memory datasets instead of local file paths + expected = create_test_data().drop_vars("dim3") + expected.attrs["foo"] = "bar" + with create_tmp_file() as tmp_file: + expected.to_netcdf(tmp_file, engine="h5netcdf") + + with open(tmp_file, "rb") as f: + with open_dataset(f, engine="h5netcdf") as actual: + assert_identical(expected, actual) + + f.seek(0) + with open_dataset(f) as actual: + assert_identical(expected, actual) + + f.seek(0) + with BytesIO(f.read()) as bio: + with open_dataset(bio, engine="h5netcdf") as actual: + assert_identical(expected, actual) + + f.seek(0) + with pytest.raises(TypeError, match="not a valid NetCDF 3"): + open_dataset(f, engine="scipy") + + # TODO: this additional open is required since scipy seems to close the file + # when it fails on the TypeError (though didn't when we used + # `raises_regex`?). Ref https://github.com/pydata/xarray/pull/5191 + with open(tmp_file, "rb") as f: + f.seek(8) + open_dataset(f) + + +@requires_h5netcdf +@requires_dask +@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") +class TestH5NetCDFViaDaskData(TestH5NetCDFData): + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + open_kwargs.setdefault("chunks", -1) + with TestH5NetCDFData.roundtrip( + self, data, save_kwargs, open_kwargs, allow_cleanup_failure + ) as ds: + yield ds + + @pytest.mark.skip(reason="caching behavior differs for dask") + def test_dataset_caching(self) -> None: + pass + + def test_write_inconsistent_chunks(self) -> None: + # Construct two variables with the same dimensions, but different + # chunk sizes. + x = da.zeros((100, 100), dtype="f4", chunks=(50, 100)) + x = DataArray(data=x, dims=("lat", "lon"), name="x") + x.encoding["chunksizes"] = (50, 100) + x.encoding["original_shape"] = (100, 100) + y = da.ones((100, 100), dtype="f4", chunks=(100, 50)) + y = DataArray(data=y, dims=("lat", "lon"), name="y") + y.encoding["chunksizes"] = (100, 50) + y.encoding["original_shape"] = (100, 100) + # Put them both into the same dataset + ds = Dataset({"x": x, "y": y}) + with self.roundtrip(ds) as actual: + assert actual["x"].encoding["chunksizes"] == (50, 100) + assert actual["y"].encoding["chunksizes"] == (100, 50) + + +@requires_h5netcdf_ros3 +class TestH5NetCDFDataRos3Driver(TestCommon): + engine: T_NetcdfEngine = "h5netcdf" + test_remote_dataset: str = ( + "https://www.unidata.ucar.edu/software/netcdf/examples/OMI-Aura_L2-example.nc" + ) + + @pytest.mark.filterwarnings("ignore:Duplicate dimension names") + def test_get_variable_list(self) -> None: + with open_dataset( + self.test_remote_dataset, + engine="h5netcdf", + backend_kwargs={"driver": "ros3"}, + ) as actual: + assert "Temperature" in list(actual) + + @pytest.mark.filterwarnings("ignore:Duplicate dimension names") + def test_get_variable_list_empty_driver_kwds(self) -> None: + driver_kwds = { + "secret_id": b"", + "secret_key": b"", + } + backend_kwargs = {"driver": "ros3", "driver_kwds": driver_kwds} + + with open_dataset( + self.test_remote_dataset, engine="h5netcdf", backend_kwargs=backend_kwargs + ) as actual: + assert "Temperature" in list(actual) + + +@pytest.fixture(params=["scipy", "netcdf4", "h5netcdf", "zarr"]) +def readengine(request): + return request.param + + +@pytest.fixture(params=[1, 20]) +def nfiles(request): + return request.param + + +@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize + + +@pytest.fixture(params=[True, False]) +def parallel(request): + return request.param + + +@pytest.fixture(params=[None, 5]) +def chunks(request): + return request.param + + +@pytest.fixture(params=["tmp_path", "ZipStore", "Dict"]) +def tmp_store(request, tmp_path): + if request.param == "tmp_path": + return tmp_path + elif request.param == "ZipStore": + from zarr.storage import ZipStore + + path = tmp_path / "store.zip" + return ZipStore(path) + elif request.param == "Dict": + return dict() + else: + raise ValueError("not supported") + + +# using pytest.mark.skipif does not work so this a work around +def skip_if_not_engine(engine): + if engine == "netcdf4": + pytest.importorskip("netCDF4") + else: + pytest.importorskip(engine) + + +@requires_dask +@pytest.mark.filterwarnings("ignore:use make_scale(name) instead") +@pytest.mark.skip( + reason="Flaky test which can cause the worker to crash (so don't xfail). Very open to contributions fixing this" +) +def test_open_mfdataset_manyfiles( + readengine, nfiles, parallel, chunks, file_cache_maxsize +): + # skip certain combinations + skip_if_not_engine(readengine) + + randdata = np.random.randn(nfiles) + original = Dataset({"foo": ("x", randdata)}) + # test standard open_mfdataset approach with too many files + with create_tmp_files(nfiles) as tmpfiles: + # split into multiple sets of temp files + for ii in original.x.values: + subds = original.isel(x=slice(ii, ii + 1)) + if readengine != "zarr": + subds.to_netcdf(tmpfiles[ii], engine=readengine) + else: # if writeengine == "zarr": + subds.to_zarr(store=tmpfiles[ii]) + + # check that calculation on opened datasets works properly + with open_mfdataset( + tmpfiles, + combine="nested", + concat_dim="x", + engine=readengine, + parallel=parallel, + chunks=chunks if (not chunks and readengine != "zarr") else "auto", + ) as actual: + # check that using open_mfdataset returns dask arrays for variables + assert isinstance(actual["foo"].data, dask_array_type) + + assert_identical(original, actual) + + +@requires_netCDF4 +@requires_dask +def test_open_mfdataset_can_open_path_objects() -> None: + dataset = os.path.join(os.path.dirname(__file__), "data", "example_1.nc") + with open_mfdataset(Path(dataset)) as actual: + assert isinstance(actual, Dataset) + + +@requires_netCDF4 +@requires_dask +def test_open_mfdataset_list_attr() -> None: + """ + Case when an attribute of type list differs across the multiple files + """ + from netCDF4 import Dataset + + with create_tmp_files(2) as nfiles: + for i in range(2): + with Dataset(nfiles[i], "w") as f: + f.createDimension("x", 3) + vlvar = f.createVariable("test_var", np.int32, ("x")) + # here create an attribute as a list + vlvar.test_attr = [f"string a {i}", f"string b {i}"] + vlvar[:] = np.arange(3) + + with open_dataset(nfiles[0]) as ds1: + with open_dataset(nfiles[1]) as ds2: + original = xr.concat([ds1, ds2], dim="x") + with xr.open_mfdataset( + [nfiles[0], nfiles[1]], combine="nested", concat_dim="x" + ) as actual: + assert_identical(actual, original) + + +@requires_scipy_or_netCDF4 +@requires_dask +class TestOpenMFDatasetWithDataVarsAndCoordsKw: + coord_name = "lon" + var_name = "v1" + + @contextlib.contextmanager + def setup_files_and_datasets(self, fuzz=0): + ds1, ds2 = self.gen_datasets_with_common_coord_and_time() + + # to test join='exact' + ds1["x"] = ds1.x + fuzz + + with create_tmp_file() as tmpfile1: + with create_tmp_file() as tmpfile2: + # save data to the temporary files + ds1.to_netcdf(tmpfile1) + ds2.to_netcdf(tmpfile2) + + yield [tmpfile1, tmpfile2], [ds1, ds2] + + def gen_datasets_with_common_coord_and_time(self): + # create coordinate data + nx = 10 + nt = 10 + x = np.arange(nx) + t1 = np.arange(nt) + t2 = np.arange(nt, 2 * nt, 1) + + v1 = np.random.randn(nt, nx) + v2 = np.random.randn(nt, nx) + + ds1 = Dataset( + data_vars={self.var_name: (["t", "x"], v1), self.coord_name: ("x", 2 * x)}, + coords={"t": (["t"], t1), "x": (["x"], x)}, + ) + + ds2 = Dataset( + data_vars={self.var_name: (["t", "x"], v2), self.coord_name: ("x", 2 * x)}, + coords={"t": (["t"], t2), "x": (["x"], x)}, + ) + + return ds1, ds2 + + @pytest.mark.parametrize( + "combine, concat_dim", [("nested", "t"), ("by_coords", None)] + ) + @pytest.mark.parametrize("opt", ["all", "minimal", "different"]) + @pytest.mark.parametrize("join", ["outer", "inner", "left", "right"]) + def test_open_mfdataset_does_same_as_concat( + self, combine, concat_dim, opt, join + ) -> None: + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + if combine == "by_coords": + files.reverse() + with open_mfdataset( + files, data_vars=opt, combine=combine, concat_dim=concat_dim, join=join + ) as ds: + ds_expect = xr.concat([ds1, ds2], data_vars=opt, dim="t", join=join) + assert_identical(ds, ds_expect) + + @pytest.mark.parametrize( + ["combine_attrs", "attrs", "expected", "expect_error"], + ( + pytest.param("drop", [{"a": 1}, {"a": 2}], {}, False, id="drop"), + pytest.param( + "override", [{"a": 1}, {"a": 2}], {"a": 1}, False, id="override" + ), + pytest.param( + "no_conflicts", [{"a": 1}, {"a": 2}], None, True, id="no_conflicts" + ), + pytest.param( + "identical", + [{"a": 1, "b": 2}, {"a": 1, "c": 3}], + None, + True, + id="identical", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": -1, "c": 3}], + {"a": 1, "c": 3}, + False, + id="drop_conflicts", + ), + ), + ) + def test_open_mfdataset_dataset_combine_attrs( + self, combine_attrs, attrs, expected, expect_error + ): + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + # Give the files an inconsistent attribute + for i, f in enumerate(files): + ds = open_dataset(f).load() + ds.attrs = attrs[i] + ds.close() + ds.to_netcdf(f) + + if expect_error: + with pytest.raises(xr.MergeError): + xr.open_mfdataset( + files, + combine="nested", + concat_dim="t", + combine_attrs=combine_attrs, + ) + else: + with xr.open_mfdataset( + files, + combine="nested", + concat_dim="t", + combine_attrs=combine_attrs, + ) as ds: + assert ds.attrs == expected + + def test_open_mfdataset_dataset_attr_by_coords(self) -> None: + """ + Case when an attribute differs across the multiple files + """ + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + # Give the files an inconsistent attribute + for i, f in enumerate(files): + ds = open_dataset(f).load() + ds.attrs["test_dataset_attr"] = 10 + i + ds.close() + ds.to_netcdf(f) + + with xr.open_mfdataset(files, combine="nested", concat_dim="t") as ds: + assert ds.test_dataset_attr == 10 + + def test_open_mfdataset_dataarray_attr_by_coords(self) -> None: + """ + Case when an attribute of a member DataArray differs across the multiple files + """ + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + # Give the files an inconsistent attribute + for i, f in enumerate(files): + ds = open_dataset(f).load() + ds["v1"].attrs["test_dataarray_attr"] = i + ds.close() + ds.to_netcdf(f) + + with xr.open_mfdataset(files, combine="nested", concat_dim="t") as ds: + assert ds["v1"].test_dataarray_attr == 0 + + @pytest.mark.parametrize( + "combine, concat_dim", [("nested", "t"), ("by_coords", None)] + ) + @pytest.mark.parametrize("opt", ["all", "minimal", "different"]) + def test_open_mfdataset_exact_join_raises_error( + self, combine, concat_dim, opt + ) -> None: + with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]): + if combine == "by_coords": + files.reverse() + with pytest.raises( + ValueError, match=r"cannot align objects.*join.*exact.*" + ): + open_mfdataset( + files, + data_vars=opt, + combine=combine, + concat_dim=concat_dim, + join="exact", + ) + + def test_common_coord_when_datavars_all(self) -> None: + opt: Final = "all" + + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + # open the files with the data_var option + with open_mfdataset( + files, data_vars=opt, combine="nested", concat_dim="t" + ) as ds: + coord_shape = ds[self.coord_name].shape + coord_shape1 = ds1[self.coord_name].shape + coord_shape2 = ds2[self.coord_name].shape + + var_shape = ds[self.var_name].shape + + assert var_shape == coord_shape + assert coord_shape1 != coord_shape + assert coord_shape2 != coord_shape + + def test_common_coord_when_datavars_minimal(self) -> None: + opt: Final = "minimal" + + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + # open the files using data_vars option + with open_mfdataset( + files, data_vars=opt, combine="nested", concat_dim="t" + ) as ds: + coord_shape = ds[self.coord_name].shape + coord_shape1 = ds1[self.coord_name].shape + coord_shape2 = ds2[self.coord_name].shape + + var_shape = ds[self.var_name].shape + + assert var_shape != coord_shape + assert coord_shape1 == coord_shape + assert coord_shape2 == coord_shape + + def test_invalid_data_vars_value_should_fail(self) -> None: + with self.setup_files_and_datasets() as (files, _): + with pytest.raises(ValueError): + with open_mfdataset(files, data_vars="minimum", combine="by_coords"): # type: ignore[arg-type] + pass + + # test invalid coord parameter + with pytest.raises(ValueError): + with open_mfdataset(files, coords="minimum", combine="by_coords"): + pass + + +@requires_dask +@requires_scipy +@requires_netCDF4 +class TestDask(DatasetIOBase): + @contextlib.contextmanager + def create_store(self): + yield Dataset() + + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + yield data.chunk() + + # Override methods in DatasetIOBase - not applicable to dask + def test_roundtrip_string_encoded_characters(self) -> None: + pass + + def test_roundtrip_coordinates_with_space(self) -> None: + pass + + def test_roundtrip_numpy_datetime_data(self) -> None: + # Override method in DatasetIOBase - remove not applicable + # save_kwargs + times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"]) + expected = Dataset({"t": ("t", times), "t0": times[0]}) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + + def test_roundtrip_cftime_datetime_data(self) -> None: + # Override method in DatasetIOBase - remove not applicable + # save_kwargs + from xarray.tests.test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({"t": ("t", times), "t0": times[0]}) + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + + with self.roundtrip(expected) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, "s")).all() + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, "s")).all() + + def test_write_store(self) -> None: + # Override method in DatasetIOBase - not applicable to dask + pass + + def test_dataset_caching(self) -> None: + expected = Dataset({"foo": ("x", [5, 6, 7])}) + with self.roundtrip(expected) as actual: + assert not actual.foo.variable._in_memory + actual.foo.values # no caching + assert not actual.foo.variable._in_memory + + def test_open_mfdataset(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + original.isel(x=slice(5)).to_netcdf(tmp1) + original.isel(x=slice(5, 10)).to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested" + ) as actual: + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5),) + assert_identical(original, actual) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", chunks={"x": 3} + ) as actual: + assert actual.foo.variable.data.chunks == ((3, 2, 3, 2),) + + with pytest.raises(OSError, match=r"no files to open"): + open_mfdataset("foo-bar-baz-*.nc") + with pytest.raises(ValueError, match=r"wild-card"): + open_mfdataset("http://some/remote/uri") + + @requires_fsspec + def test_open_mfdataset_no_files(self) -> None: + pytest.importorskip("aiobotocore") + + # glob is attempted as of #4823, but finds no files + with pytest.raises(OSError, match=r"no files"): + open_mfdataset("http://some/remote/uri", engine="zarr") + + def test_open_mfdataset_2d(self) -> None: + original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + with create_tmp_file() as tmp3: + with create_tmp_file() as tmp4: + original.isel(x=slice(5), y=slice(4)).to_netcdf(tmp1) + original.isel(x=slice(5, 10), y=slice(4)).to_netcdf(tmp2) + original.isel(x=slice(5), y=slice(4, 8)).to_netcdf(tmp3) + original.isel(x=slice(5, 10), y=slice(4, 8)).to_netcdf(tmp4) + with open_mfdataset( + [[tmp1, tmp2], [tmp3, tmp4]], + combine="nested", + concat_dim=["y", "x"], + ) as actual: + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5), (4, 4)) + assert_identical(original, actual) + with open_mfdataset( + [[tmp1, tmp2], [tmp3, tmp4]], + combine="nested", + concat_dim=["y", "x"], + chunks={"x": 3, "y": 2}, + ) as actual: + assert actual.foo.variable.data.chunks == ( + (3, 2, 3, 2), + (2, 2, 2, 2), + ) + + def test_open_mfdataset_pathlib(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmps1: + with create_tmp_file() as tmps2: + tmp1 = Path(tmps1) + tmp2 = Path(tmps2) + original.isel(x=slice(5)).to_netcdf(tmp1) + original.isel(x=slice(5, 10)).to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested" + ) as actual: + assert_identical(original, actual) + + def test_open_mfdataset_2d_pathlib(self) -> None: + original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))}) + with create_tmp_file() as tmps1: + with create_tmp_file() as tmps2: + with create_tmp_file() as tmps3: + with create_tmp_file() as tmps4: + tmp1 = Path(tmps1) + tmp2 = Path(tmps2) + tmp3 = Path(tmps3) + tmp4 = Path(tmps4) + original.isel(x=slice(5), y=slice(4)).to_netcdf(tmp1) + original.isel(x=slice(5, 10), y=slice(4)).to_netcdf(tmp2) + original.isel(x=slice(5), y=slice(4, 8)).to_netcdf(tmp3) + original.isel(x=slice(5, 10), y=slice(4, 8)).to_netcdf(tmp4) + with open_mfdataset( + [[tmp1, tmp2], [tmp3, tmp4]], + combine="nested", + concat_dim=["y", "x"], + ) as actual: + assert_identical(original, actual) + + def test_open_mfdataset_2(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + original.isel(x=slice(5)).to_netcdf(tmp1) + original.isel(x=slice(5, 10)).to_netcdf(tmp2) + + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested" + ) as actual: + assert_identical(original, actual) + + def test_attrs_mfdataset(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested" + ) as actual: + # presumes that attributes inherited from + # first dataset loaded + assert actual.test1 == ds1.test1 + # attributes from ds2 are not retained, e.g., + with pytest.raises(AttributeError, match=r"no attribute"): + actual.test2 + + def test_open_mfdataset_attrs_file(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + + def test_open_mfdataset_attrs_file_path(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmps1, tmps2): + tmp1 = Path(tmps1) + tmp2 = Path(tmps2) + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + + def test_open_mfdataset_auto_combine(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + original.isel(x=slice(5)).to_netcdf(tmp1) + original.isel(x=slice(5, 10)).to_netcdf(tmp2) + + with open_mfdataset([tmp2, tmp1], combine="by_coords") as actual: + assert_identical(original, actual) + + def test_open_mfdataset_raise_on_bad_combine_args(self) -> None: + # Regression test for unhelpful error shown in #5230 + original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + original.isel(x=slice(5)).to_netcdf(tmp1) + original.isel(x=slice(5, 10)).to_netcdf(tmp2) + with pytest.raises(ValueError, match="`concat_dim` has no effect"): + open_mfdataset([tmp1, tmp2], concat_dim="x") + + def test_encoding_mfdataset(self) -> None: + original = Dataset( + { + "foo": ("t", np.random.randn(10)), + "t": ("t", pd.date_range(start="2010-01-01", periods=10, freq="1D")), + } + ) + original.t.encoding["units"] = "days since 2010-01-01" + + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + ds1 = original.isel(t=slice(5)) + ds2 = original.isel(t=slice(5, 10)) + ds1.t.encoding["units"] = "days since 2010-01-01" + ds2.t.encoding["units"] = "days since 2000-01-01" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset([tmp1, tmp2], combine="nested") as actual: + assert actual.t.encoding["units"] == original.t.encoding["units"] + assert actual.t.encoding["units"] == ds1.t.encoding["units"] + assert actual.t.encoding["units"] != ds2.t.encoding["units"] + + def test_preprocess_mfdataset(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + + def preprocess(ds): + return ds.assign_coords(z=0) + + expected = preprocess(original) + with open_mfdataset( + tmp, preprocess=preprocess, combine="by_coords" + ) as actual: + assert_identical(expected, actual) + + def test_save_mfdataset_roundtrip(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + save_mfdataset(datasets, [tmp1, tmp2]) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested" + ) as actual: + assert_identical(actual, original) + + def test_save_mfdataset_invalid(self) -> None: + ds = Dataset() + with pytest.raises(ValueError, match=r"cannot use mode"): + save_mfdataset([ds, ds], ["same", "same"]) + with pytest.raises(ValueError, match=r"same length"): + save_mfdataset([ds, ds], ["only one path"]) + + def test_save_mfdataset_invalid_dataarray(self) -> None: + # regression test for GH1555 + da = DataArray([1, 2]) + with pytest.raises(TypeError, match=r"supports writing Dataset"): + save_mfdataset([da], ["dataarray"]) + + def test_save_mfdataset_pathlib_roundtrip(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] + with create_tmp_file() as tmps1: + with create_tmp_file() as tmps2: + tmp1 = Path(tmps1) + tmp2 = Path(tmps2) + save_mfdataset(datasets, [tmp1, tmp2]) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested" + ) as actual: + assert_identical(actual, original) + + def test_save_mfdataset_pass_kwargs(self) -> None: + # create a timeseries to store in a netCDF file + times = [0, 1] + time = xr.DataArray(times, dims=("time",)) + + # create a simple dataset to write using save_mfdataset + test_ds = xr.Dataset() + test_ds["time"] = time + + # make sure the times are written as double and + # turn off fill values + encoding = dict(time=dict(dtype="double")) + unlimited_dims = ["time"] + + # set the output file name + output_path = "test.nc" + + # attempt to write the dataset with the encoding and unlimited args + # passed through + xr.save_mfdataset( + [test_ds], [output_path], encoding=encoding, unlimited_dims=unlimited_dims + ) + + def test_open_and_do_math(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + with open_mfdataset(tmp, combine="by_coords") as ds: + actual = 1.0 * ds + assert_allclose(original, actual, decode_bytes=False) + + def test_open_mfdataset_concat_dim_none(self) -> None: + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + data = Dataset({"x": 0}) + data.to_netcdf(tmp1) + Dataset({"x": np.nan}).to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim=None, combine="nested" + ) as actual: + assert_identical(data, actual) + + def test_open_mfdataset_concat_dim_default_none(self) -> None: + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + data = Dataset({"x": 0}) + data.to_netcdf(tmp1) + Dataset({"x": np.nan}).to_netcdf(tmp2) + with open_mfdataset([tmp1, tmp2], combine="nested") as actual: + assert_identical(data, actual) + + def test_open_dataset(self) -> None: + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + with open_dataset(tmp, chunks={"x": 5}) as actual: + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5),) + assert_identical(original, actual) + with open_dataset(tmp, chunks=5) as actual: + assert_identical(original, actual) + with open_dataset(tmp) as actual: + assert isinstance(actual.foo.variable.data, np.ndarray) + assert_identical(original, actual) + + def test_open_single_dataset(self) -> None: + # Test for issue GH #1988. This makes sure that the + # concat_dim is utilized when specified in open_mfdataset(). + rnddata = np.random.randn(10) + original = Dataset({"foo": ("x", rnddata)}) + dim = DataArray([100], name="baz", dims="baz") + expected = Dataset( + {"foo": (("baz", "x"), rnddata[np.newaxis, :])}, {"baz": [100]} + ) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + with open_mfdataset([tmp], concat_dim=dim, combine="nested") as actual: + assert_identical(expected, actual) + + def test_open_multi_dataset(self) -> None: + # Test for issue GH #1988 and #2647. This makes sure that the + # concat_dim is utilized when specified in open_mfdataset(). + # The additional wrinkle is to ensure that a length greater + # than one is tested as well due to numpy's implicit casting + # of 1-length arrays to booleans in tests, which allowed + # #2647 to still pass the test_open_single_dataset(), + # which is itself still needed as-is because the original + # bug caused one-length arrays to not be used correctly + # in concatenation. + rnddata = np.random.randn(10) + original = Dataset({"foo": ("x", rnddata)}) + dim = DataArray([100, 150], name="baz", dims="baz") + expected = Dataset( + {"foo": (("baz", "x"), np.tile(rnddata[np.newaxis, :], (2, 1)))}, + {"baz": [100, 150]}, + ) + with create_tmp_file() as tmp1, create_tmp_file() as tmp2: + original.to_netcdf(tmp1) + original.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim=dim, combine="nested" + ) as actual: + assert_identical(expected, actual) + + # Flaky test. Very open to contributions on fixing this + @pytest.mark.flaky + def test_dask_roundtrip(self) -> None: + with create_tmp_file() as tmp: + data = create_test_data() + data.to_netcdf(tmp) + chunks = {"dim1": 4, "dim2": 4, "dim3": 4, "time": 10} + with open_dataset(tmp, chunks=chunks) as dask_ds: + assert_identical(data, dask_ds) + with create_tmp_file() as tmp2: + dask_ds.to_netcdf(tmp2) + with open_dataset(tmp2) as on_disk: + assert_identical(data, on_disk) + + def test_deterministic_names(self) -> None: + with create_tmp_file() as tmp: + data = create_test_data() + data.to_netcdf(tmp) + with open_mfdataset(tmp, combine="by_coords") as ds: + original_names = {k: v.data.name for k, v in ds.data_vars.items()} + with open_mfdataset(tmp, combine="by_coords") as ds: + repeat_names = {k: v.data.name for k, v in ds.data_vars.items()} + for var_name, dask_name in original_names.items(): + assert var_name in dask_name + assert dask_name[:13] == "open_dataset-" + assert original_names == repeat_names + + def test_dataarray_compute(self) -> None: + # Test DataArray.compute() on dask backend. + # The test for Dataset.compute() is already in DatasetIOBase; + # however dask is the only tested backend which supports DataArrays + actual = DataArray([1, 2]).chunk() + computed = actual.compute() + assert not actual._in_memory + assert computed._in_memory + assert_allclose(actual, computed, decode_bytes=False) + + def test_save_mfdataset_compute_false_roundtrip(self) -> None: + from dask.delayed import Delayed + + original = Dataset({"foo": ("x", np.random.randn(10))}).chunk() + datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: + delayed_obj = save_mfdataset( + datasets, [tmp1, tmp2], engine=self.engine, compute=False + ) + assert isinstance(delayed_obj, Delayed) + delayed_obj.compute() + with open_mfdataset( + [tmp1, tmp2], combine="nested", concat_dim="x" + ) as actual: + assert_identical(actual, original) + + def test_load_dataset(self) -> None: + with create_tmp_file() as tmp: + original = Dataset({"foo": ("x", np.random.randn(10))}) + original.to_netcdf(tmp) + ds = load_dataset(tmp) + # this would fail if we used open_dataset instead of load_dataset + ds.to_netcdf(tmp) + + def test_load_dataarray(self) -> None: + with create_tmp_file() as tmp: + original = Dataset({"foo": ("x", np.random.randn(10))}) + original.to_netcdf(tmp) + ds = load_dataarray(tmp) + # this would fail if we used open_dataarray instead of + # load_dataarray + ds.to_netcdf(tmp) + + @pytest.mark.skipif( + ON_WINDOWS, + reason="counting number of tasks in graph fails on windows for some reason", + ) + def test_inline_array(self) -> None: + with create_tmp_file() as tmp: + original = Dataset({"foo": ("x", np.random.randn(10))}) + original.to_netcdf(tmp) + chunks = {"time": 10} + + def num_graph_nodes(obj): + return len(obj.__dask_graph__()) + + with ( + open_dataset(tmp, inline_array=False, chunks=chunks) as not_inlined_ds, + open_dataset(tmp, inline_array=True, chunks=chunks) as inlined_ds, + ): + assert num_graph_nodes(inlined_ds) < num_graph_nodes(not_inlined_ds) + + with ( + open_dataarray( + tmp, inline_array=False, chunks=chunks + ) as not_inlined_da, + open_dataarray(tmp, inline_array=True, chunks=chunks) as inlined_da, + ): + assert num_graph_nodes(inlined_da) < num_graph_nodes(not_inlined_da) + + +@requires_scipy_or_netCDF4 +@requires_pydap +@pytest.mark.filterwarnings("ignore:The binary mode of fromstring is deprecated") +class TestPydap: + def convert_to_pydap_dataset(self, original): + from pydap.model import BaseType, DatasetType, GridType + + ds = DatasetType("bears", **original.attrs) + for key, var in original.data_vars.items(): + v = GridType(key) + v[key] = BaseType(key, var.values, dimensions=var.dims, **var.attrs) + for d in var.dims: + v[d] = BaseType(d, var[d].values) + ds[key] = v + # check all dims are stored in ds + for d in original.coords: + ds[d] = BaseType( + d, original[d].values, dimensions=(d,), **original[d].attrs + ) + return ds + + @contextlib.contextmanager + def create_datasets(self, **kwargs): + with open_example_dataset("bears.nc") as expected: + pydap_ds = self.convert_to_pydap_dataset(expected) + actual = open_dataset(PydapDataStore(pydap_ds)) + # TODO solve this workaround: + # netcdf converts string to byte not unicode + expected["bears"] = expected["bears"].astype(str) + yield actual, expected + + def test_cmp_local_file(self) -> None: + with self.create_datasets() as (actual, expected): + assert_equal(actual, expected) + + # global attributes should be global attributes on the dataset + assert "NC_GLOBAL" not in actual.attrs + assert "history" in actual.attrs + + # we don't check attributes exactly with assertDatasetIdentical() + # because the test DAP server seems to insert some extra + # attributes not found in the netCDF file. + assert actual.attrs.keys() == expected.attrs.keys() + + with self.create_datasets() as (actual, expected): + assert_equal(actual[{"l": 2}], expected[{"l": 2}]) + + with self.create_datasets() as (actual, expected): + assert_equal(actual.isel(i=0, j=-1), expected.isel(i=0, j=-1)) + + with self.create_datasets() as (actual, expected): + assert_equal(actual.isel(j=slice(1, 2)), expected.isel(j=slice(1, 2))) + + with self.create_datasets() as (actual, expected): + indexers = {"i": [1, 0, 0], "j": [1, 2, 0, 1]} + assert_equal(actual.isel(**indexers), expected.isel(**indexers)) + + with self.create_datasets() as (actual, expected): + indexers2 = { + "i": DataArray([0, 1, 0], dims="a"), + "j": DataArray([0, 2, 1], dims="a"), + } + assert_equal(actual.isel(**indexers2), expected.isel(**indexers2)) + + def test_compatible_to_netcdf(self) -> None: + # make sure it can be saved as a netcdf + with self.create_datasets() as (actual, expected): + with create_tmp_file() as tmp_file: + actual.to_netcdf(tmp_file) + with open_dataset(tmp_file) as actual2: + actual2["bears"] = actual2["bears"].astype(str) + assert_equal(actual2, expected) + + @requires_dask + def test_dask(self) -> None: + with self.create_datasets(chunks={"j": 2}) as (actual, expected): + assert_equal(actual, expected) + + +@network +@requires_scipy_or_netCDF4 +@requires_pydap +class TestPydapOnline(TestPydap): + @contextlib.contextmanager + def create_datasets(self, **kwargs): + url = "http://test.opendap.org/opendap/hyrax/data/nc/bears.nc" + actual = open_dataset(url, engine="pydap", **kwargs) + with open_example_dataset("bears.nc") as expected: + # workaround to restore string which is converted to byte + expected["bears"] = expected["bears"].astype(str) + yield actual, expected + + def test_session(self) -> None: + from pydap.cas.urs import setup_session + + session = setup_session("XarrayTestUser", "Xarray2017") + with mock.patch("pydap.client.open_url") as mock_func: + xr.backends.PydapDataStore.open("http://test.url", session=session) + mock_func.assert_called_with( + url="http://test.url", + application=None, + session=session, + output_grid=True, + timeout=120, + ) + + +class TestEncodingInvalid: + def test_extract_nc4_variable_encoding(self) -> None: + var = xr.Variable(("x",), [1, 2, 3], {}, {"foo": "bar"}) + with pytest.raises(ValueError, match=r"unexpected encoding"): + _extract_nc4_variable_encoding(var, raise_on_invalid=True) + + var = xr.Variable(("x",), [1, 2, 3], {}, {"chunking": (2, 1)}) + encoding = _extract_nc4_variable_encoding(var) + assert {} == encoding + + # regression test + var = xr.Variable(("x",), [1, 2, 3], {}, {"shuffle": True}) + encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True) + assert {"shuffle": True} == encoding + + # Variables with unlim dims must be chunked on output. + var = xr.Variable(("x",), [1, 2, 3], {}, {"contiguous": True}) + encoding = _extract_nc4_variable_encoding(var, unlimited_dims=("x",)) + assert {} == encoding + + @requires_netCDF4 + def test_extract_nc4_variable_encoding_netcdf4(self): + # New netCDF4 1.6.0 compression argument. + var = xr.Variable(("x",), [1, 2, 3], {}, {"compression": "szlib"}) + _extract_nc4_variable_encoding(var, backend="netCDF4", raise_on_invalid=True) + + def test_extract_h5nc_encoding(self) -> None: + # not supported with h5netcdf (yet) + var = xr.Variable(("x",), [1, 2, 3], {}, {"least_sigificant_digit": 2}) + with pytest.raises(ValueError, match=r"unexpected encoding"): + _extract_nc4_variable_encoding(var, raise_on_invalid=True) + + +class MiscObject: + pass + + +@requires_netCDF4 +class TestValidateAttrs: + def test_validating_attrs(self) -> None: + def new_dataset(): + return Dataset({"data": ("y", np.arange(10.0))}, {"y": np.arange(10)}) + + def new_dataset_and_dataset_attrs(): + ds = new_dataset() + return ds, ds.attrs + + def new_dataset_and_data_attrs(): + ds = new_dataset() + return ds, ds.data.attrs + + def new_dataset_and_coord_attrs(): + ds = new_dataset() + return ds, ds.coords["y"].attrs + + for new_dataset_and_attrs in [ + new_dataset_and_dataset_attrs, + new_dataset_and_data_attrs, + new_dataset_and_coord_attrs, + ]: + ds, attrs = new_dataset_and_attrs() + + attrs[123] = "test" + with pytest.raises(TypeError, match=r"Invalid name for attr: 123"): + ds.to_netcdf("test.nc") + + ds, attrs = new_dataset_and_attrs() + attrs[MiscObject()] = "test" + with pytest.raises(TypeError, match=r"Invalid name for attr: "): + ds.to_netcdf("test.nc") + + ds, attrs = new_dataset_and_attrs() + attrs[""] = "test" + with pytest.raises(ValueError, match=r"Invalid name for attr '':"): + ds.to_netcdf("test.nc") + + # This one should work + ds, attrs = new_dataset_and_attrs() + attrs["test"] = "test" + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = {"a": 5} + with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"): + ds.to_netcdf("test.nc") + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = MiscObject() + with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"): + ds.to_netcdf("test.nc") + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = 5 + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = 3.14 + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = [1, 2, 3, 4] + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = (1.9, 2.5) + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = np.arange(5) + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = "This is a string" + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + ds, attrs = new_dataset_and_attrs() + attrs["test"] = "" + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) + + +@requires_scipy_or_netCDF4 +class TestDataArrayToNetCDF: + def test_dataarray_to_netcdf_no_name(self) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4))) + + with create_tmp_file() as tmp: + original_da.to_netcdf(tmp) + + with open_dataarray(tmp) as loaded_da: + assert_identical(original_da, loaded_da) + + def test_dataarray_to_netcdf_with_name(self) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4)), name="test") + + with create_tmp_file() as tmp: + original_da.to_netcdf(tmp) + + with open_dataarray(tmp) as loaded_da: + assert_identical(original_da, loaded_da) + + def test_dataarray_to_netcdf_coord_name_clash(self) -> None: + original_da = DataArray( + np.arange(12).reshape((3, 4)), dims=["x", "y"], name="x" + ) + + with create_tmp_file() as tmp: + original_da.to_netcdf(tmp) + + with open_dataarray(tmp) as loaded_da: + assert_identical(original_da, loaded_da) + + def test_open_dataarray_options(self) -> None: + data = DataArray(np.arange(5), coords={"y": ("x", range(5))}, dims=["x"]) + + with create_tmp_file() as tmp: + data.to_netcdf(tmp) + + expected = data.drop_vars("y") + with open_dataarray(tmp, drop_variables=["y"]) as loaded: + assert_identical(expected, loaded) + + @requires_scipy + def test_dataarray_to_netcdf_return_bytes(self) -> None: + # regression test for GH1410 + data = xr.DataArray([1, 2, 3]) + output = data.to_netcdf() + assert isinstance(output, bytes) + + def test_dataarray_to_netcdf_no_name_pathlib(self) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4))) + + with create_tmp_file() as tmps: + tmp = Path(tmps) + original_da.to_netcdf(tmp) + + with open_dataarray(tmp) as loaded_da: + assert_identical(original_da, loaded_da) + + +@requires_zarr +class TestDataArrayToZarr: + def test_dataarray_to_zarr_no_name(self, tmp_store) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4))) + + original_da.to_zarr(tmp_store) + + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + def test_dataarray_to_zarr_with_name(self, tmp_store) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4)), name="test") + + original_da.to_zarr(tmp_store) + + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + def test_dataarray_to_zarr_coord_name_clash(self, tmp_store) -> None: + original_da = DataArray( + np.arange(12).reshape((3, 4)), dims=["x", "y"], name="x" + ) + + original_da.to_zarr(tmp_store) + + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + def test_open_dataarray_options(self, tmp_store) -> None: + data = DataArray(np.arange(5), coords={"y": ("x", range(5))}, dims=["x"]) + + data.to_zarr(tmp_store) + + expected = data.drop_vars("y") + with open_dataarray(tmp_store, engine="zarr", drop_variables=["y"]) as loaded: + assert_identical(expected, loaded) + + @requires_dask + def test_dataarray_to_zarr_compute_false(self, tmp_store) -> None: + from dask.delayed import Delayed + + original_da = DataArray(np.arange(12).reshape((3, 4))) + + output = original_da.to_zarr(tmp_store, compute=False) + assert isinstance(output, Delayed) + output.compute() + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + +@requires_scipy_or_netCDF4 +def test_no_warning_from_dask_effective_get() -> None: + with create_tmp_file() as tmpfile: + with assert_no_warnings(): + ds = Dataset() + ds.to_netcdf(tmpfile) + + +@requires_scipy_or_netCDF4 +def test_source_encoding_always_present() -> None: + # Test for GH issue #2550. + rnddata = np.random.randn(10) + original = Dataset({"foo": ("x", rnddata)}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + with open_dataset(tmp) as ds: + assert ds.encoding["source"] == tmp + + +@requires_scipy_or_netCDF4 +def test_source_encoding_always_present_with_pathlib() -> None: + # Test for GH issue #5888. + rnddata = np.random.randn(10) + original = Dataset({"foo": ("x", rnddata)}) + with create_tmp_file() as tmp: + original.to_netcdf(tmp) + with open_dataset(Path(tmp)) as ds: + assert ds.encoding["source"] == tmp + + +def _assert_no_dates_out_of_range_warning(record): + undesired_message = "dates out of range" + for warning in record: + assert undesired_message not in str(warning.message) + + +@requires_scipy_or_netCDF4 +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_use_cftime_standard_calendar_default_in_range(calendar) -> None: + x = [0, 1] + time = [0, 720] + units_date = "2000-01-01" + units = "days since 2000-01-01" + original = DataArray(x, [("time", time)], name="x").to_dataset() + for v in ["x", "time"]: + original[v].attrs["units"] = units + original[v].attrs["calendar"] = calendar + + x_timedeltas = np.array(x).astype("timedelta64[D]") + time_timedeltas = np.array(time).astype("timedelta64[D]") + decoded_x = np.datetime64(units_date, "ns") + x_timedeltas + decoded_time = np.datetime64(units_date, "ns") + time_timedeltas + expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x") + expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time") + + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with warnings.catch_warnings(record=True) as record: + with open_dataset(tmp_file) as ds: + assert_identical(expected_x, ds.x) + assert_identical(expected_time, ds.time) + _assert_no_dates_out_of_range_warning(record) + + +@requires_cftime +@requires_scipy_or_netCDF4 +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2500]) +def test_use_cftime_standard_calendar_default_out_of_range( + calendar, units_year +) -> None: + import cftime + + x = [0, 1] + time = [0, 720] + units = f"days since {units_year}-01-01" + original = DataArray(x, [("time", time)], name="x").to_dataset() + for v in ["x", "time"]: + original[v].attrs["units"] = units + original[v].attrs["calendar"] = calendar + + decoded_x = cftime.num2date(x, units, calendar, only_use_cftime_datetimes=True) + decoded_time = cftime.num2date( + time, units, calendar, only_use_cftime_datetimes=True + ) + expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x") + expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time") + + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with pytest.warns(SerializationWarning): + with open_dataset(tmp_file) as ds: + assert_identical(expected_x, ds.x) + assert_identical(expected_time, ds.time) + + +@requires_cftime +@requires_scipy_or_netCDF4 +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) +def test_use_cftime_true(calendar, units_year) -> None: + import cftime + + x = [0, 1] + time = [0, 720] + units = f"days since {units_year}-01-01" + original = DataArray(x, [("time", time)], name="x").to_dataset() + for v in ["x", "time"]: + original[v].attrs["units"] = units + original[v].attrs["calendar"] = calendar + + decoded_x = cftime.num2date(x, units, calendar, only_use_cftime_datetimes=True) + decoded_time = cftime.num2date( + time, units, calendar, only_use_cftime_datetimes=True + ) + expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x") + expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time") + + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with warnings.catch_warnings(record=True) as record: + with open_dataset(tmp_file, use_cftime=True) as ds: + assert_identical(expected_x, ds.x) + assert_identical(expected_time, ds.time) + _assert_no_dates_out_of_range_warning(record) + + +@requires_scipy_or_netCDF4 +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.xfail( + has_numpy_2, reason="https://github.com/pandas-dev/pandas/issues/56996" +) +def test_use_cftime_false_standard_calendar_in_range(calendar) -> None: + x = [0, 1] + time = [0, 720] + units_date = "2000-01-01" + units = "days since 2000-01-01" + original = DataArray(x, [("time", time)], name="x").to_dataset() + for v in ["x", "time"]: + original[v].attrs["units"] = units + original[v].attrs["calendar"] = calendar + + x_timedeltas = np.array(x).astype("timedelta64[D]") + time_timedeltas = np.array(time).astype("timedelta64[D]") + decoded_x = np.datetime64(units_date, "ns") + x_timedeltas + decoded_time = np.datetime64(units_date, "ns") + time_timedeltas + expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x") + expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time") + + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with warnings.catch_warnings(record=True) as record: + with open_dataset(tmp_file, use_cftime=False) as ds: + assert_identical(expected_x, ds.x) + assert_identical(expected_time, ds.time) + _assert_no_dates_out_of_range_warning(record) + + +@requires_scipy_or_netCDF4 +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2500]) +def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year) -> None: + x = [0, 1] + time = [0, 720] + units = f"days since {units_year}-01-01" + original = DataArray(x, [("time", time)], name="x").to_dataset() + for v in ["x", "time"]: + original[v].attrs["units"] = units + original[v].attrs["calendar"] = calendar + + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with pytest.raises((OutOfBoundsDatetime, ValueError)): + open_dataset(tmp_file, use_cftime=False) + + +@requires_scipy_or_netCDF4 +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) +def test_use_cftime_false_nonstandard_calendar(calendar, units_year) -> None: + x = [0, 1] + time = [0, 720] + units = f"days since {units_year}" + original = DataArray(x, [("time", time)], name="x").to_dataset() + for v in ["x", "time"]: + original[v].attrs["units"] = units + original[v].attrs["calendar"] = calendar + + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with pytest.raises((OutOfBoundsDatetime, ValueError)): + open_dataset(tmp_file, use_cftime=False) + + +@pytest.mark.parametrize("engine", ["netcdf4", "scipy"]) +def test_invalid_netcdf_raises(engine) -> None: + data = create_test_data() + with pytest.raises(ValueError, match=r"unrecognized option 'invalid_netcdf'"): + data.to_netcdf("foo.nc", engine=engine, invalid_netcdf=True) + + +@requires_zarr +def test_encode_zarr_attr_value() -> None: + # array -> list + arr = np.array([1, 2, 3]) + expected1 = [1, 2, 3] + actual1 = backends.zarr.encode_zarr_attr_value(arr) + assert isinstance(actual1, list) + assert actual1 == expected1 + + # scalar array -> scalar + sarr = np.array(1)[()] + expected2 = 1 + actual2 = backends.zarr.encode_zarr_attr_value(sarr) + assert isinstance(actual2, int) + assert actual2 == expected2 + + # string -> string (no change) + expected3 = "foo" + actual3 = backends.zarr.encode_zarr_attr_value(expected3) + assert isinstance(actual3, str) + assert actual3 == expected3 + + +@requires_zarr +def test_extract_zarr_variable_encoding() -> None: + var = xr.Variable("x", [1, 2]) + actual = backends.zarr.extract_zarr_variable_encoding(var) + assert "chunks" in actual + assert actual["chunks"] is None + + var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)}) + actual = backends.zarr.extract_zarr_variable_encoding(var) + assert actual["chunks"] == (1,) + + # does not raise on invalid + var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) + actual = backends.zarr.extract_zarr_variable_encoding(var) + + # raises on invalid + var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) + with pytest.raises(ValueError, match=r"unexpected encoding parameters"): + actual = backends.zarr.extract_zarr_variable_encoding( + var, raise_on_invalid=True + ) + + +@requires_zarr +@requires_fsspec +@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") +def test_open_fsspec() -> None: + import fsspec + import zarr + + if not hasattr(zarr.storage, "FSStore") or not hasattr( + zarr.storage.FSStore, "getitems" + ): + pytest.skip("zarr too old") + + ds = open_dataset(os.path.join(os.path.dirname(__file__), "data", "example_1.nc")) + + m = fsspec.filesystem("memory") + mm = m.get_mapper("out1.zarr") + ds.to_zarr(mm) # old interface + ds0 = ds.copy() + ds0["time"] = ds.time + pd.to_timedelta("1 day") + mm = m.get_mapper("out2.zarr") + ds0.to_zarr(mm) # old interface + + # single dataset + url = "memory://out2.zarr" + ds2 = open_dataset(url, engine="zarr") + xr.testing.assert_equal(ds0, ds2) + + # single dataset with caching + url = "simplecache::memory://out2.zarr" + ds2 = open_dataset(url, engine="zarr") + xr.testing.assert_equal(ds0, ds2) + + # open_mfdataset requires dask + if has_dask: + # multi dataset + url = "memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + + # multi dataset with caching + url = "simplecache::memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + + +@requires_h5netcdf +@requires_netCDF4 +def test_load_single_value_h5netcdf(tmp_path: Path) -> None: + """Test that numeric single-element vector attributes are handled fine. + + At present (h5netcdf v0.8.1), the h5netcdf exposes single-valued numeric variable + attributes as arrays of length 1, as opposed to scalars for the NetCDF4 + backend. This was leading to a ValueError upon loading a single value from + a file, see #4471. Test that loading causes no failure. + """ + ds = xr.Dataset( + { + "test": xr.DataArray( + np.array([0]), dims=("x",), attrs={"scale_factor": 1, "add_offset": 0} + ) + } + ) + ds.to_netcdf(tmp_path / "test.nc") + with xr.open_dataset(tmp_path / "test.nc", engine="h5netcdf") as ds2: + ds2["test"][0].load() + + +@requires_zarr +@requires_dask +@pytest.mark.parametrize( + "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}] +) +def test_open_dataset_chunking_zarr(chunks, tmp_path: Path) -> None: + encoded_chunks = 100 + dask_arr = da.from_array( + np.ones((500, 500), dtype="float64"), chunks=encoded_chunks + ) + ds = xr.Dataset( + { + "test": xr.DataArray( + dask_arr, + dims=("x", "y"), + ) + } + ) + ds["test"].encoding["chunks"] = encoded_chunks + ds.to_zarr(tmp_path / "test.zarr") + + with dask.config.set({"array.chunk-size": "1MiB"}): + expected = ds.chunk(chunks) + with open_dataset( + tmp_path / "test.zarr", engine="zarr", chunks=chunks + ) as actual: + xr.testing.assert_chunks_equal(actual, expected) + + +@requires_zarr +@requires_dask +@pytest.mark.parametrize( + "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}] +) +@pytest.mark.filterwarnings("ignore:The specified chunks separate") +def test_chunking_consintency(chunks, tmp_path: Path) -> None: + encoded_chunks: dict[str, Any] = {} + dask_arr = da.from_array( + np.ones((500, 500), dtype="float64"), chunks=encoded_chunks + ) + ds = xr.Dataset( + { + "test": xr.DataArray( + dask_arr, + dims=("x", "y"), + ) + } + ) + ds["test"].encoding["chunks"] = encoded_chunks + ds.to_zarr(tmp_path / "test.zarr") + ds.to_netcdf(tmp_path / "test.nc") + + with dask.config.set({"array.chunk-size": "1MiB"}): + expected = ds.chunk(chunks) + with xr.open_dataset( + tmp_path / "test.zarr", engine="zarr", chunks=chunks + ) as actual: + xr.testing.assert_chunks_equal(actual, expected) + + with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual: + xr.testing.assert_chunks_equal(actual, expected) + + +def _check_guess_can_open_and_open(entrypoint, obj, engine, expected): + assert entrypoint.guess_can_open(obj) + with open_dataset(obj, engine=engine) as actual: + assert_identical(expected, actual) + + +@requires_netCDF4 +def test_netcdf4_entrypoint(tmp_path: Path) -> None: + entrypoint = NetCDF4BackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, format="NETCDF3_CLASSIC") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + path = tmp_path / "bar" + ds.to_netcdf(path, format="NETCDF4_CLASSIC") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + assert entrypoint.guess_can_open("http://something/remote") + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + + path = tmp_path / "baz" + with open(path, "wb") as f: + f.write(b"not-a-netcdf-file") + assert not entrypoint.guess_can_open(path) + + +@requires_scipy +def test_scipy_entrypoint(tmp_path: Path) -> None: + entrypoint = ScipyBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="scipy") + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds) + + contents = ds.to_netcdf(engine="scipy") + _check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds) + _check_guess_can_open_and_open( + entrypoint, BytesIO(contents), engine="scipy", expected=ds + ) + + path = tmp_path / "foo.nc.gz" + with gzip.open(path, mode="wb") as f: + f.write(contents) + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc.gz") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") # type: ignore[arg-type] + + +@requires_h5netcdf +def test_h5netcdf_entrypoint(tmp_path: Path) -> None: + entrypoint = H5netcdfBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="h5netcdf") + _check_guess_can_open_and_open(entrypoint, path, engine="h5netcdf", expected=ds) + _check_guess_can_open_and_open( + entrypoint, str(path), engine="h5netcdf", expected=ds + ) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + + +@requires_netCDF4 +@pytest.mark.parametrize("str_type", (str, np.str_)) +def test_write_file_from_np_str(str_type, tmpdir) -> None: + # https://github.com/pydata/xarray/pull/5264 + scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]] + years = range(2015, 2100 + 1) + tdf = pd.DataFrame( + data=np.random.random((len(scenarios), len(years))), + columns=years, + index=scenarios, + ) + tdf.index.name = "scenario" + tdf.columns.name = "year" + tdf = tdf.stack() + tdf.name = "tas" + + txr = tdf.to_xarray() + + txr.to_netcdf(tmpdir.join("test.nc")) + + +@requires_zarr +@requires_netCDF4 +class TestNCZarr: + @property + def netcdfc_version(self): + return Version(nc4.getlibversion().split()[0].split("-development")[0]) + + def _create_nczarr(self, filename): + if self.netcdfc_version < Version("4.8.1"): + pytest.skip("requires netcdf-c>=4.8.1") + if platform.system() == "Windows" and self.netcdfc_version == Version("4.8.1"): + # Bug in netcdf-c==4.8.1 (typo: Nan instead of NaN) + # https://github.com/Unidata/netcdf-c/issues/2265 + pytest.skip("netcdf-c==4.8.1 has issues on Windows") + + ds = create_test_data() + # Drop dim3: netcdf-c does not support dtype=' None: + with create_tmp_file(suffix=".zarr") as tmp: + expected = self._create_nczarr(tmp) + actual = xr.open_zarr(tmp, consolidated=False) + assert_identical(expected, actual) + + def test_overwriting_nczarr(self) -> None: + with create_tmp_file(suffix=".zarr") as tmp: + ds = self._create_nczarr(tmp) + expected = ds[["var1"]] + expected.to_zarr(tmp, mode="w") + actual = xr.open_zarr(tmp, consolidated=False) + assert_identical(expected, actual) + + @pytest.mark.parametrize("mode", ["a", "r+"]) + @pytest.mark.filterwarnings("ignore:.*non-consolidated metadata.*") + def test_raise_writing_to_nczarr(self, mode) -> None: + if self.netcdfc_version > Version("4.8.1"): + pytest.skip("netcdf-c>4.8.1 adds the _ARRAY_DIMENSIONS attribute") + + with create_tmp_file(suffix=".zarr") as tmp: + ds = self._create_nczarr(tmp) + with pytest.raises( + KeyError, match="missing the attribute `_ARRAY_DIMENSIONS`," + ): + ds.to_zarr(tmp, mode=mode) + + +@requires_netCDF4 +@requires_dask +def test_pickle_open_mfdataset_dataset(): + with open_example_mfdataset(["bears.nc"]) as ds: + assert_identical(ds, pickle.loads(pickle.dumps(ds))) + + +@requires_zarr +def test_zarr_closing_internal_zip_store(): + store_name = "tmp.zarr.zip" + original_da = DataArray(np.arange(12).reshape((3, 4))) + original_da.to_zarr(store_name, mode="w") + + with open_dataarray(store_name, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + +@requires_zarr +class TestZarrRegionAuto: + def test_zarr_region_auto_all(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + ds_region.to_zarr(tmp_path / "test.zarr", region="auto") + + ds_updated = xr.open_zarr(tmp_path / "test.zarr") + + expected = ds.copy() + expected["test"][2:4, 6:8] += 1 + assert_identical(ds_updated, expected) + + def test_zarr_region_auto_mixed(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + ds_region.to_zarr( + tmp_path / "test.zarr", region={"x": "auto", "y": slice(6, 8)} + ) + + ds_updated = xr.open_zarr(tmp_path / "test.zarr") + + expected = ds.copy() + expected["test"][2:4, 6:8] += 1 + assert_identical(ds_updated, expected) + + def test_zarr_region_auto_noncontiguous(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_region = 1 + ds.isel(x=[0, 2, 3], y=[5, 6]) + with pytest.raises(ValueError): + ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"}) + + def test_zarr_region_auto_new_coord_vals(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + x = np.arange(5, 55, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + with pytest.raises(KeyError): + ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"}) + + def test_zarr_region_index_write(self, tmp_path): + from xarray.backends.zarr import ZarrStore + + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + + region_slice = dict(x=slice(2, 4), y=slice(6, 8)) + ds_region = 1 + ds.isel(region_slice) + + ds.to_zarr(tmp_path / "test.zarr") + + region: Mapping[str, slice] | Literal["auto"] + for region in [region_slice, "auto"]: # type: ignore + with patch.object( + ZarrStore, + "set_variables", + side_effect=ZarrStore.set_variables, + autospec=True, + ) as mock: + ds_region.to_zarr(tmp_path / "test.zarr", region=region, mode="r+") + + # should write the data vars but never the index vars with auto mode + for call in mock.call_args_list: + written_variables = call.args[1].keys() + assert "test" in written_variables + assert "x" not in written_variables + assert "y" not in written_variables + + def test_zarr_region_append(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + x_new = np.arange(40, 70, 10) + data_new = np.ones((3, 10)) + ds_new = xr.Dataset( + { + "test": xr.DataArray( + data_new, + dims=("x", "y"), + coords={"x": x_new, "y": y}, + ) + } + ) + + # Don't allow auto region detection in append mode due to complexities in + # implementing the overlap logic and lack of safety with parallel writes + with pytest.raises(ValueError): + ds_new.to_zarr( + tmp_path / "test.zarr", mode="a", append_dim="x", region="auto" + ) + + +@requires_zarr +def test_zarr_region(tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_transposed = ds.transpose("y", "x") + + ds_region = 1 + ds_transposed.isel(x=[0], y=[0]) + ds_region.to_zarr( + tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)} + ) + + # Write without region + ds_transposed.to_zarr(tmp_path / "test.zarr", mode="r+") + + +@requires_zarr +@requires_dask +def test_zarr_region_chunk_partial(tmp_path): + """ + Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`. + """ + ds = ( + xr.DataArray(np.arange(120).reshape(4, 3, -1), dims=list("abc")) + .rename("var1") + .to_dataset() + ) + + ds.chunk(5).to_zarr(tmp_path / "foo.zarr", compute=False, mode="w") + with pytest.raises(ValueError): + for r in range(ds.sizes["a"]): + ds.chunk(3).isel(a=[r]).to_zarr( + tmp_path / "foo.zarr", region=dict(a=slice(r, r + 1)) + ) + + +@requires_zarr +@requires_dask +def test_zarr_append_chunk_partial(tmp_path): + t_coords = np.array([np.datetime64("2020-01-01").astype("datetime64[ns]")]) + data = np.ones((10, 10)) + + da = xr.DataArray( + data.reshape((-1, 10, 10)), + dims=["time", "x", "y"], + coords={"time": t_coords}, + name="foo", + ) + da.to_zarr(tmp_path / "foo.zarr", mode="w", encoding={"foo": {"chunks": (5, 5, 1)}}) + + new_time = np.array([np.datetime64("2021-01-01").astype("datetime64[ns]")]) + + da2 = xr.DataArray( + data.reshape((-1, 10, 10)), + dims=["time", "x", "y"], + coords={"time": new_time}, + name="foo", + ) + with pytest.raises(ValueError, match="encoding was provided"): + da2.to_zarr( + tmp_path / "foo.zarr", + append_dim="time", + mode="a", + encoding={"foo": {"chunks": (1, 1, 1)}}, + ) + + # chunking with dask sidesteps the encoding check, so we need a different check + with pytest.raises(ValueError, match="Specified zarr chunks"): + da2.chunk({"x": 1, "y": 1, "time": 1}).to_zarr( + tmp_path / "foo.zarr", append_dim="time", mode="a" + ) + + +@requires_zarr +@requires_dask +def test_zarr_region_chunk_partial_offset(tmp_path): + # https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545 + store = tmp_path / "foo.zarr" + data = np.ones((30,)) + da = xr.DataArray(data, dims=["x"], coords={"x": range(30)}, name="foo").chunk(x=10) + da.to_zarr(store, compute=False) + + da.isel(x=slice(10)).chunk(x=(10,)).to_zarr(store, region="auto") + + da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr( + store, safe_chunks=False, region="auto" + ) + + # This write is unsafe, and should raise an error, but does not. + # with pytest.raises(ValueError): + # da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_api.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_api.py new file mode 100644 index 0000000..d4f8b7e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_api.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from numbers import Number + +import numpy as np +import pytest + +import xarray as xr +from xarray.backends.api import _get_default_engine +from xarray.tests import ( + assert_identical, + assert_no_warnings, + requires_dask, + requires_netCDF4, + requires_scipy, +) + + +@requires_netCDF4 +@requires_scipy +def test__get_default_engine() -> None: + engine_remote = _get_default_engine("http://example.org/test.nc", allow_remote=True) + assert engine_remote == "netcdf4" + + engine_gz = _get_default_engine("/example.gz") + assert engine_gz == "scipy" + + engine_default = _get_default_engine("/example") + assert engine_default == "netcdf4" + + +def test_custom_engine() -> None: + expected = xr.Dataset( + dict(a=2 * np.arange(5)), coords=dict(x=("x", np.arange(5), dict(units="s"))) + ) + + class CustomBackend(xr.backends.BackendEntrypoint): + def open_dataset( + self, + filename_or_obj, + drop_variables=None, + **kwargs, + ) -> xr.Dataset: + return expected.copy(deep=True) + + actual = xr.open_dataset("fake_filename", engine=CustomBackend) + assert_identical(expected, actual) + + +def test_multiindex() -> None: + # GH7139 + # Check that we properly handle backends that change index variables + dataset = xr.Dataset(coords={"coord1": ["A", "B"], "coord2": [1, 2]}) + dataset = dataset.stack(z=["coord1", "coord2"]) + + class MultiindexBackend(xr.backends.BackendEntrypoint): + def open_dataset( + self, + filename_or_obj, + drop_variables=None, + **kwargs, + ) -> xr.Dataset: + return dataset.copy(deep=True) + + loaded = xr.open_dataset("fake_filename", engine=MultiindexBackend) + assert_identical(dataset, loaded) + + +class PassThroughBackendEntrypoint(xr.backends.BackendEntrypoint): + """Access an object passed to the `open_dataset` method.""" + + def open_dataset(self, dataset, *, drop_variables=None): + """Return the first argument.""" + return dataset + + +def explicit_chunks(chunks, shape): + """Return explicit chunks, expanding any integer member to a tuple of integers.""" + # Emulate `dask.array.core.normalize_chunks` but for simpler inputs. + return tuple( + ( + ( + (size // chunk) * (chunk,) + + ((size % chunk,) if size % chunk or size == 0 else ()) + ) + if isinstance(chunk, Number) + else chunk + ) + for chunk, size in zip(chunks, shape) + ) + + +@requires_dask +class TestPreferredChunks: + """Test behaviors related to the backend's preferred chunks.""" + + var_name = "data" + + def create_dataset(self, shape, pref_chunks): + """Return a dataset with a variable with the given shape and preferred chunks.""" + dims = tuple(f"dim_{idx}" for idx in range(len(shape))) + return xr.Dataset( + { + self.var_name: xr.Variable( + dims, + np.empty(shape, dtype=np.dtype("V1")), + encoding={"preferred_chunks": dict(zip(dims, pref_chunks))}, + ) + } + ) + + def check_dataset(self, initial, final, expected_chunks): + assert_identical(initial, final) + assert final[self.var_name].chunks == expected_chunks + + @pytest.mark.parametrize( + "shape,pref_chunks", + [ + # Represent preferred chunking with int. + ((5,), (2,)), + # Represent preferred chunking with tuple. + ((5,), ((2, 2, 1),)), + # Represent preferred chunking with int in two dims. + ((5, 6), (4, 2)), + # Represent preferred chunking with tuple in second dim. + ((5, 6), (4, (2, 2, 2))), + ], + ) + @pytest.mark.parametrize("request_with_empty_map", [False, True]) + def test_honor_chunks(self, shape, pref_chunks, request_with_empty_map): + """Honor the backend's preferred chunks when opening a dataset.""" + initial = self.create_dataset(shape, pref_chunks) + # To keep the backend's preferred chunks, the `chunks` argument must be an + # empty mapping or map dimensions to `None`. + chunks = ( + {} + if request_with_empty_map + else dict.fromkeys(initial[self.var_name].dims, None) + ) + final = xr.open_dataset( + initial, engine=PassThroughBackendEntrypoint, chunks=chunks + ) + self.check_dataset(initial, final, explicit_chunks(pref_chunks, shape)) + + @pytest.mark.parametrize( + "shape,pref_chunks,req_chunks", + [ + # Preferred chunking is int; requested chunking is int. + ((5,), (2,), (3,)), + # Preferred chunking is int; requested chunking is tuple. + ((5,), (2,), ((2, 1, 1, 1),)), + # Preferred chunking is tuple; requested chunking is int. + ((5,), ((2, 2, 1),), (3,)), + # Preferred chunking is tuple; requested chunking is tuple. + ((5,), ((2, 2, 1),), ((2, 1, 1, 1),)), + # Split chunks along a dimension other than the first. + ((1, 5), (1, 2), (1, 3)), + ], + ) + def test_split_chunks(self, shape, pref_chunks, req_chunks): + """Warn when the requested chunks separate the backend's preferred chunks.""" + initial = self.create_dataset(shape, pref_chunks) + with pytest.warns(UserWarning): + final = xr.open_dataset( + initial, + engine=PassThroughBackendEntrypoint, + chunks=dict(zip(initial[self.var_name].dims, req_chunks)), + ) + self.check_dataset(initial, final, explicit_chunks(req_chunks, shape)) + + @pytest.mark.parametrize( + "shape,pref_chunks,req_chunks", + [ + # Keep preferred chunks using int representation. + ((5,), (2,), (2,)), + # Keep preferred chunks using tuple representation. + ((5,), (2,), ((2, 2, 1),)), + # Join chunks, leaving a final short chunk. + ((5,), (2,), (4,)), + # Join all chunks with an int larger than the dimension size. + ((5,), (2,), (6,)), + # Join one chunk using tuple representation. + ((5,), (1,), ((1, 1, 2, 1),)), + # Join one chunk using int representation. + ((5,), ((1, 1, 2, 1),), (2,)), + # Join multiple chunks using tuple representation. + ((5,), ((1, 1, 2, 1),), ((2, 3),)), + # Join chunks in multiple dimensions. + ((5, 5), (2, (1, 1, 2, 1)), (4, (2, 3))), + ], + ) + def test_join_chunks(self, shape, pref_chunks, req_chunks): + """Don't warn when the requested chunks join or keep the preferred chunks.""" + initial = self.create_dataset(shape, pref_chunks) + with assert_no_warnings(): + final = xr.open_dataset( + initial, + engine=PassThroughBackendEntrypoint, + chunks=dict(zip(initial[self.var_name].dims, req_chunks)), + ) + self.check_dataset(initial, final, explicit_chunks(req_chunks, shape)) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_common.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_common.py new file mode 100644 index 0000000..c7dba36 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_common.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import pytest + +from xarray.backends.common import robust_getitem + + +class DummyFailure(Exception): + pass + + +class DummyArray: + def __init__(self, failures): + self.failures = failures + + def __getitem__(self, key): + if self.failures: + self.failures -= 1 + raise DummyFailure + return "success" + + +def test_robust_getitem() -> None: + array = DummyArray(failures=2) + with pytest.raises(DummyFailure): + array[...] + result = robust_getitem(array, ..., catch=DummyFailure, initial_delay=1) + assert result == "success" + + array = DummyArray(failures=3) + with pytest.raises(DummyFailure): + robust_getitem(array, ..., catch=DummyFailure, initial_delay=1, max_retries=2) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_datatree.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_datatree.py new file mode 100644 index 0000000..4e819ee --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_datatree.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from xarray.backends.api import open_datatree +from xarray.testing import assert_equal +from xarray.tests import ( + requires_h5netcdf, + requires_netCDF4, + requires_zarr, +) + +if TYPE_CHECKING: + from xarray.backends.api import T_NetcdfEngine + + +class DatatreeIOBase: + engine: T_NetcdfEngine | None = None + + def test_to_netcdf(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.nc" + original_dt = simple_datatree + original_dt.to_netcdf(filepath, engine=self.engine) + + roundtrip_dt = open_datatree(filepath, engine=self.engine) + assert_equal(original_dt, roundtrip_dt) + + def test_netcdf_encoding(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.nc" + original_dt = simple_datatree + + # add compression + comp = dict(zlib=True, complevel=9) + enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}} + + original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + roundtrip_dt = open_datatree(filepath, engine=self.engine) + + assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"] + assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"] + + enc["/not/a/group"] = {"foo": "bar"} # type: ignore + with pytest.raises(ValueError, match="unexpected encoding group.*"): + original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + + +@requires_netCDF4 +class TestNetCDF4DatatreeIO(DatatreeIOBase): + engine: T_NetcdfEngine | None = "netcdf4" + + +@requires_h5netcdf +class TestH5NetCDFDatatreeIO(DatatreeIOBase): + engine: T_NetcdfEngine | None = "h5netcdf" + + +@requires_zarr +class TestZarrDatatreeIO: + engine = "zarr" + + def test_to_zarr(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.zarr" + original_dt = simple_datatree + original_dt.to_zarr(filepath) + + roundtrip_dt = open_datatree(filepath, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + + def test_zarr_encoding(self, tmpdir, simple_datatree): + import zarr + + filepath = tmpdir / "test.zarr" + original_dt = simple_datatree + + comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)} + enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}} + original_dt.to_zarr(filepath, encoding=enc) + roundtrip_dt = open_datatree(filepath, engine="zarr") + + print(roundtrip_dt["/set2/a"].encoding) + assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"] + + enc["/not/a/group"] = {"foo": "bar"} # type: ignore + with pytest.raises(ValueError, match="unexpected encoding group.*"): + original_dt.to_zarr(filepath, encoding=enc, engine="zarr") + + def test_to_zarr_zip_store(self, tmpdir, simple_datatree): + from zarr.storage import ZipStore + + filepath = tmpdir / "test.zarr.zip" + original_dt = simple_datatree + store = ZipStore(filepath) + original_dt.to_zarr(store) + + roundtrip_dt = open_datatree(store, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + + def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.zarr" + zmetadata = filepath / ".zmetadata" + s1zmetadata = filepath / "set1" / ".zmetadata" + filepath = str(filepath) # casting to str avoids a pathlib bug in xarray + original_dt = simple_datatree + original_dt.to_zarr(filepath, consolidated=False) + assert not zmetadata.exists() + assert not s1zmetadata.exists() + + with pytest.warns(RuntimeWarning, match="consolidated"): + roundtrip_dt = open_datatree(filepath, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + + def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree): + import zarr + + simple_datatree.to_zarr(tmpdir) + + # with default settings, to_zarr should not overwrite an existing dir + with pytest.raises(zarr.errors.ContainsGroupError): + simple_datatree.to_zarr(tmpdir) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_file_manager.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_file_manager.py new file mode 100644 index 0000000..cede3e6 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import gc +import pickle +import threading +from unittest import mock + +import pytest + +from xarray.backends.file_manager import CachingFileManager +from xarray.backends.lru_cache import LRUCache +from xarray.core.options import set_options +from xarray.tests import assert_no_warnings + + +@pytest.fixture(params=[1, 2, 3, None]) +def file_cache(request): + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) + + +def test_file_manager_mock_write(file_cache) -> None: + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) + + manager = CachingFileManager(opener, "filename", lock=lock, cache=file_cache) + f = manager.acquire() + f.write("contents") + manager.close() + + assert not file_cache + opener.assert_called_once_with("filename") + mock_file.write.assert_called_once_with("contents") + mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) + + +@pytest.mark.parametrize("warn_for_unclosed_files", [True, False]) +def test_file_manager_autoclose(warn_for_unclosed_files) -> None: + mock_file = mock.Mock() + opener = mock.Mock(return_value=mock_file) + cache: dict = {} + + manager = CachingFileManager(opener, "filename", cache=cache) + manager.acquire() + assert cache + + # can no longer use pytest.warns(None) + if warn_for_unclosed_files: + ctx = pytest.warns(RuntimeWarning) + else: + ctx = assert_no_warnings() # type: ignore + + with set_options(warn_for_unclosed_files=warn_for_unclosed_files): + with ctx: + del manager + gc.collect() + + assert not cache + mock_file.close.assert_called_once_with() + + +def test_file_manager_autoclose_while_locked() -> None: + opener = mock.Mock() + lock = threading.Lock() + cache: dict = {} + + manager = CachingFileManager(opener, "filename", lock=lock, cache=cache) + manager.acquire() + assert cache + + lock.acquire() + + with set_options(warn_for_unclosed_files=False): + del manager + gc.collect() + + # can't clear the cache while locked, but also don't block in __del__ + assert cache + + +def test_file_manager_repr() -> None: + opener = mock.Mock() + manager = CachingFileManager(opener, "my-file") + assert "my-file" in repr(manager) + + +def test_file_manager_cache_and_refcounts() -> None: + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + cache: dict = {} + ref_counts: dict = {} + + manager = CachingFileManager(opener, "filename", cache=cache, ref_counts=ref_counts) + assert ref_counts[manager._key] == 1 + + assert not cache + manager.acquire() + assert len(cache) == 1 + + with set_options(warn_for_unclosed_files=False): + del manager + gc.collect() + + assert not ref_counts + assert not cache + + +def test_file_manager_cache_repeated_open() -> None: + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + cache: dict = {} + + manager = CachingFileManager(opener, "filename", cache=cache) + manager.acquire() + assert len(cache) == 1 + + manager2 = CachingFileManager(opener, "filename", cache=cache) + manager2.acquire() + assert len(cache) == 2 + + with set_options(warn_for_unclosed_files=False): + del manager + gc.collect() + + assert len(cache) == 1 + + with set_options(warn_for_unclosed_files=False): + del manager2 + gc.collect() + + assert not cache + + +def test_file_manager_cache_with_pickle(tmpdir) -> None: + path = str(tmpdir.join("testing.txt")) + with open(path, "w") as f: + f.write("data") + cache: dict = {} + + with mock.patch("xarray.backends.file_manager.FILE_CACHE", cache): + assert not cache + + manager = CachingFileManager(open, path, mode="r") + manager.acquire() + assert len(cache) == 1 + + manager2 = pickle.loads(pickle.dumps(manager)) + manager2.acquire() + assert len(cache) == 1 + + with set_options(warn_for_unclosed_files=False): + del manager + gc.collect() + # assert len(cache) == 1 + + with set_options(warn_for_unclosed_files=False): + del manager2 + gc.collect() + assert not cache + + +def test_file_manager_write_consecutive(tmpdir, file_cache) -> None: + path1 = str(tmpdir.join("testing1.txt")) + path2 = str(tmpdir.join("testing2.txt")) + manager1 = CachingFileManager(open, path1, mode="w", cache=file_cache) + manager2 = CachingFileManager(open, path2, mode="w", cache=file_cache) + f1a = manager1.acquire() + f1a.write("foo") + f1a.flush() + f2 = manager2.acquire() + f2.write("bar") + f2.flush() + f1b = manager1.acquire() + f1b.write("baz") + assert (getattr(file_cache, "maxsize", float("inf")) > 1) == (f1a is f1b) + manager1.close() + manager2.close() + + with open(path1) as f: + assert f.read() == "foobaz" + with open(path2) as f: + assert f.read() == "bar" + + +def test_file_manager_write_concurrent(tmpdir, file_cache) -> None: + path = str(tmpdir.join("testing.txt")) + manager = CachingFileManager(open, path, mode="w", cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write("foo") + f1.flush() + f2.write("bar") + f2.flush() + f3.write("baz") + f3.flush() + manager.close() + + with open(path) as f: + assert f.read() == "foobarbaz" + + +def test_file_manager_write_pickle(tmpdir, file_cache) -> None: + path = str(tmpdir.join("testing.txt")) + manager = CachingFileManager(open, path, mode="w", cache=file_cache) + f = manager.acquire() + f.write("foo") + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write("bar") + manager2.close() + manager.close() + + with open(path) as f: + assert f.read() == "foobar" + + +def test_file_manager_read(tmpdir, file_cache) -> None: + path = str(tmpdir.join("testing.txt")) + + with open(path, "w") as f: + f.write("foobar") + + manager = CachingFileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == "foobar" + manager.close() + + +def test_file_manager_acquire_context(tmpdir, file_cache) -> None: + path = str(tmpdir.join("testing.txt")) + + with open(path, "w") as f: + f.write("foobar") + + class AcquisitionError(Exception): + pass + + manager = CachingFileManager(open, path, cache=file_cache) + with pytest.raises(AcquisitionError): + with manager.acquire_context() as f: + assert f.read() == "foobar" + raise AcquisitionError + assert not file_cache # file was *not* already open + + with manager.acquire_context() as f: + assert f.read() == "foobar" + + with pytest.raises(AcquisitionError): + with manager.acquire_context() as f: + f.seek(0) + assert f.read() == "foobar" + raise AcquisitionError + assert file_cache # file *was* already open + + manager.close() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_locks.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_locks.py new file mode 100644 index 0000000..341a9c4 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_locks.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import threading + +from xarray.backends import locks + + +def test_threaded_lock() -> None: + lock1 = locks._get_threaded_lock("foo") + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock("foo") + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock("bar") + assert lock1 is not lock3 diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_lru_cache.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 0000000..5735e03 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import Any +from unittest import mock + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) + cache["x"] = 1 + cache["y"] = 2 + + assert cache["x"] == 1 + assert cache["y"] == 2 + assert len(cache) == 2 + assert dict(cache) == {"x": 1, "y": 2} + assert list(cache.keys()) == ["x", "y"] + assert list(cache.items()) == [("x", 1), ("y", 2)] + + cache["z"] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [("y", 2), ("z", 3)] + + +def test_trivial() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=0) + cache["x"] = 1 + assert len(cache) == 0 + + +def test_invalid() -> None: + with pytest.raises(TypeError): + LRUCache(maxsize=None) # type: ignore + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) + cache["x"] = 1 + cache["y"] = 2 + assert list(cache) == ["x", "y"] + assert "x" in cache # contains + assert list(cache) == ["y", "x"] + assert cache["y"] == 2 # getitem + assert list(cache) == ["x", "y"] + cache["x"] = 3 # setitem + assert list(cache.items()) == [("y", 2), ("x", 3)] + + +def test_del() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) + cache["x"] = 1 + cache["y"] = 2 + del cache["x"] + assert dict(cache) == {"y": 2} + + +def test_on_evict() -> None: + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache["x"] = 1 + cache["y"] = 2 + on_evict.assert_called_once_with("x", 1) + + +def test_on_evict_trivial() -> None: + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache["x"] = 1 + on_evict.assert_called_once_with("x", 1) + + +def test_resize() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache["w"] = 0 + cache["x"] = 1 + cache["y"] = 2 + assert list(cache.items()) == [("x", 1), ("y", 2)] + cache.maxsize = 10 + cache["z"] = 3 + assert list(cache.items()) == [("x", 1), ("y", 2), ("z", 3)] + cache.maxsize = 1 + assert list(cache.items()) == [("z", 3)] + + with pytest.raises(ValueError): + cache.maxsize = -1 diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_calendar_ops.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_calendar_ops.py new file mode 100644 index 0000000..7d22937 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_calendar_ops.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from xarray import DataArray, infer_freq +from xarray.coding.calendar_ops import convert_calendar, interp_calendar +from xarray.coding.cftime_offsets import date_range +from xarray.testing import assert_identical +from xarray.tests import requires_cftime + +cftime = pytest.importorskip("cftime") + + +@pytest.mark.parametrize( + "source, target, use_cftime, freq", + [ + ("standard", "noleap", None, "D"), + ("noleap", "proleptic_gregorian", True, "D"), + ("noleap", "all_leap", None, "D"), + ("all_leap", "proleptic_gregorian", False, "4h"), + ], +) +def test_convert_calendar(source, target, use_cftime, freq): + src = DataArray( + date_range("2004-01-01", "2004-12-31", freq=freq, calendar=source), + dims=("time",), + name="time", + ) + da_src = DataArray( + np.linspace(0, 1, src.size), dims=("time",), coords={"time": src} + ) + + conv = convert_calendar(da_src, target, use_cftime=use_cftime) + + assert conv.time.dt.calendar == target + + if source != "noleap": + expected_times = date_range( + "2004-01-01", + "2004-12-31", + freq=freq, + use_cftime=use_cftime, + calendar=target, + ) + else: + expected_times_pre_leap = date_range( + "2004-01-01", + "2004-02-28", + freq=freq, + use_cftime=use_cftime, + calendar=target, + ) + expected_times_post_leap = date_range( + "2004-03-01", + "2004-12-31", + freq=freq, + use_cftime=use_cftime, + calendar=target, + ) + expected_times = expected_times_pre_leap.append(expected_times_post_leap) + np.testing.assert_array_equal(conv.time, expected_times) + + +@pytest.mark.parametrize( + "source,target,freq", + [ + ("standard", "360_day", "D"), + ("360_day", "proleptic_gregorian", "D"), + ("proleptic_gregorian", "360_day", "4h"), + ], +) +@pytest.mark.parametrize("align_on", ["date", "year"]) +def test_convert_calendar_360_days(source, target, freq, align_on): + src = DataArray( + date_range("2004-01-01", "2004-12-30", freq=freq, calendar=source), + dims=("time",), + name="time", + ) + da_src = DataArray( + np.linspace(0, 1, src.size), dims=("time",), coords={"time": src} + ) + + conv = convert_calendar(da_src, target, align_on=align_on) + + assert conv.time.dt.calendar == target + + if align_on == "date": + np.testing.assert_array_equal( + conv.time.resample(time="ME").last().dt.day, + [30, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30], + ) + elif target == "360_day": + np.testing.assert_array_equal( + conv.time.resample(time="ME").last().dt.day, + [30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 29], + ) + else: + np.testing.assert_array_equal( + conv.time.resample(time="ME").last().dt.day, + [30, 29, 30, 30, 31, 30, 30, 31, 30, 31, 29, 31], + ) + if source == "360_day" and align_on == "year": + assert conv.size == 360 if freq == "D" else 360 * 4 + else: + assert conv.size == 359 if freq == "D" else 359 * 4 + + +def test_convert_calendar_360_days_random(): + da_std = DataArray( + np.linspace(0, 1, 366), + dims=("time",), + coords={ + "time": date_range( + "2004-01-01", + "2004-12-31", + freq="D", + calendar="standard", + use_cftime=False, + ) + }, + ) + da_360 = DataArray( + np.linspace(0, 1, 360), + dims=("time",), + coords={ + "time": date_range("2004-01-01", "2004-12-30", freq="D", calendar="360_day") + }, + ) + + conv = convert_calendar(da_std, "360_day", align_on="random") + conv2 = convert_calendar(da_std, "360_day", align_on="random") + assert (conv != conv2).any() + + conv = convert_calendar(da_360, "standard", use_cftime=False, align_on="random") + assert np.datetime64("2004-02-29") not in conv.time + conv2 = convert_calendar(da_360, "standard", use_cftime=False, align_on="random") + assert (conv2 != conv).any() + + # Ensure that added days are evenly distributed in the 5 fifths of each year + conv = convert_calendar(da_360, "noleap", align_on="random", missing=np.nan) + conv = conv.where(conv.isnull(), drop=True) + nandoys = conv.time.dt.dayofyear[:366] + assert all(nandoys < np.array([74, 147, 220, 293, 366])) + assert all(nandoys > np.array([0, 73, 146, 219, 292])) + + +@requires_cftime +@pytest.mark.parametrize( + "source,target,freq", + [ + ("standard", "noleap", "D"), + ("noleap", "proleptic_gregorian", "4h"), + ("noleap", "all_leap", "ME"), + ("360_day", "noleap", "D"), + ("noleap", "360_day", "D"), + ], +) +def test_convert_calendar_missing(source, target, freq): + src = DataArray( + date_range( + "2004-01-01", + "2004-12-31" if source != "360_day" else "2004-12-30", + freq=freq, + calendar=source, + ), + dims=("time",), + name="time", + ) + da_src = DataArray( + np.linspace(0, 1, src.size), dims=("time",), coords={"time": src} + ) + out = convert_calendar(da_src, target, missing=np.nan, align_on="date") + + expected_freq = freq + assert infer_freq(out.time) == expected_freq + + expected = date_range( + "2004-01-01", + "2004-12-31" if target != "360_day" else "2004-12-30", + freq=freq, + calendar=target, + ) + np.testing.assert_array_equal(out.time, expected) + + if freq != "ME": + out_without_missing = convert_calendar(da_src, target, align_on="date") + expected_nan = out.isel(time=~out.time.isin(out_without_missing.time)) + assert expected_nan.isnull().all() + + expected_not_nan = out.sel(time=out_without_missing.time) + assert_identical(expected_not_nan, out_without_missing) + + +@requires_cftime +def test_convert_calendar_errors(): + src_nl = DataArray( + date_range("0000-01-01", "0000-12-31", freq="D", calendar="noleap"), + dims=("time",), + name="time", + ) + # no align_on for conversion to 360_day + with pytest.raises(ValueError, match="Argument `align_on` must be specified"): + convert_calendar(src_nl, "360_day") + + # Standard doesn't support year 0 + with pytest.raises( + ValueError, match="Source time coordinate contains dates with year 0" + ): + convert_calendar(src_nl, "standard") + + # no align_on for conversion from 360 day + src_360 = convert_calendar(src_nl, "360_day", align_on="year") + with pytest.raises(ValueError, match="Argument `align_on` must be specified"): + convert_calendar(src_360, "noleap") + + # Datetime objects + da = DataArray([0, 1, 2], dims=("x",), name="x") + with pytest.raises(ValueError, match="Coordinate x must contain datetime objects."): + convert_calendar(da, "standard", dim="x") + + +def test_convert_calendar_same_calendar(): + src = DataArray( + date_range("2000-01-01", periods=12, freq="6h", use_cftime=False), + dims=("time",), + name="time", + ) + out = convert_calendar(src, "proleptic_gregorian") + assert src is out + + +@pytest.mark.parametrize( + "source,target", + [ + ("standard", "noleap"), + ("noleap", "proleptic_gregorian"), + ("standard", "360_day"), + ("360_day", "proleptic_gregorian"), + ("noleap", "all_leap"), + ("360_day", "noleap"), + ], +) +def test_interp_calendar(source, target): + src = DataArray( + date_range("2004-01-01", "2004-07-30", freq="D", calendar=source), + dims=("time",), + name="time", + ) + tgt = DataArray( + date_range("2004-01-01", "2004-07-30", freq="D", calendar=target), + dims=("time",), + name="time", + ) + da_src = DataArray( + np.linspace(0, 1, src.size), dims=("time",), coords={"time": src} + ) + conv = interp_calendar(da_src, tgt) + + assert_identical(tgt.time, conv.time) + + np.testing.assert_almost_equal(conv.max(), 1, 2) + assert conv.min() == 0 + + +@requires_cftime +def test_interp_calendar_errors(): + src_nl = DataArray( + [1] * 100, + dims=("time",), + coords={ + "time": date_range("0000-01-01", periods=100, freq="MS", calendar="noleap") + }, + ) + tgt_360 = date_range("0001-01-01", "0001-12-30", freq="MS", calendar="standard") + + with pytest.raises( + ValueError, match="Source time coordinate contains dates with year 0" + ): + interp_calendar(src_nl, tgt_360) + + da1 = DataArray([0, 1, 2], dims=("x",), name="x") + da2 = da1 + 1 + + with pytest.raises( + ValueError, match="Both 'source.x' and 'target' must contain datetime objects." + ): + interp_calendar(da1, da2, dim="x") diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_cftime_offsets.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_cftime_offsets.py new file mode 100644 index 0000000..eabb7d2 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_cftime_offsets.py @@ -0,0 +1,1789 @@ +from __future__ import annotations + +from itertools import product +from typing import Callable, Literal + +import numpy as np +import pandas as pd +import pytest + +from xarray import CFTimeIndex +from xarray.coding.cftime_offsets import ( + _MONTH_ABBREVIATIONS, + BaseCFTimeOffset, + Day, + Hour, + Microsecond, + Millisecond, + Minute, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + Second, + Tick, + YearBegin, + YearEnd, + _days_in_month, + _legacy_to_new_freq, + _new_to_legacy_freq, + cftime_range, + date_range, + date_range_like, + get_date_type, + to_cftime_datetime, + to_offset, +) +from xarray.coding.frequencies import infer_freq +from xarray.core.dataarray import DataArray +from xarray.tests import ( + _CFTIME_CALENDARS, + assert_no_warnings, + has_cftime, + has_pandas_ge_2_2, + requires_cftime, + requires_pandas_3, +) + +cftime = pytest.importorskip("cftime") + + +def _id_func(param): + """Called on each parameter passed to pytest.mark.parametrize""" + return str(param) + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.mark.parametrize( + ("offset", "expected_n"), + [ + (BaseCFTimeOffset(), 1), + (YearBegin(), 1), + (YearEnd(), 1), + (QuarterBegin(), 1), + (QuarterEnd(), 1), + (Tick(), 1), + (Day(), 1), + (Hour(), 1), + (Minute(), 1), + (Second(), 1), + (Millisecond(), 1), + (Microsecond(), 1), + (BaseCFTimeOffset(n=2), 2), + (YearBegin(n=2), 2), + (YearEnd(n=2), 2), + (QuarterBegin(n=2), 2), + (QuarterEnd(n=2), 2), + (Tick(n=2), 2), + (Day(n=2), 2), + (Hour(n=2), 2), + (Minute(n=2), 2), + (Second(n=2), 2), + (Millisecond(n=2), 2), + (Microsecond(n=2), 2), + ], + ids=_id_func, +) +def test_cftime_offset_constructor_valid_n(offset, expected_n): + assert offset.n == expected_n + + +@pytest.mark.parametrize( + ("offset", "invalid_n"), + [ + (BaseCFTimeOffset, 1.5), + (YearBegin, 1.5), + (YearEnd, 1.5), + (QuarterBegin, 1.5), + (QuarterEnd, 1.5), + (MonthBegin, 1.5), + (MonthEnd, 1.5), + (Tick, 1.5), + (Day, 1.5), + (Hour, 1.5), + (Minute, 1.5), + (Second, 1.5), + (Millisecond, 1.5), + (Microsecond, 1.5), + ], + ids=_id_func, +) +def test_cftime_offset_constructor_invalid_n(offset, invalid_n): + with pytest.raises(TypeError): + offset(n=invalid_n) + + +@pytest.mark.parametrize( + ("offset", "expected_month"), + [ + (YearBegin(), 1), + (YearEnd(), 12), + (YearBegin(month=5), 5), + (YearEnd(month=5), 5), + (QuarterBegin(), 3), + (QuarterEnd(), 3), + (QuarterBegin(month=5), 5), + (QuarterEnd(month=5), 5), + ], + ids=_id_func, +) +def test_year_offset_constructor_valid_month(offset, expected_month): + assert offset.month == expected_month + + +@pytest.mark.parametrize( + ("offset", "invalid_month", "exception"), + [ + (YearBegin, 0, ValueError), + (YearEnd, 0, ValueError), + (YearBegin, 13, ValueError), + (YearEnd, 13, ValueError), + (YearBegin, 1.5, TypeError), + (YearEnd, 1.5, TypeError), + (QuarterBegin, 0, ValueError), + (QuarterEnd, 0, ValueError), + (QuarterBegin, 1.5, TypeError), + (QuarterEnd, 1.5, TypeError), + (QuarterBegin, 13, ValueError), + (QuarterEnd, 13, ValueError), + ], + ids=_id_func, +) +def test_year_offset_constructor_invalid_month(offset, invalid_month, exception): + with pytest.raises(exception): + offset(month=invalid_month) + + +@pytest.mark.parametrize( + ("offset", "expected"), + [ + (BaseCFTimeOffset(), None), + (MonthBegin(), "MS"), + (MonthEnd(), "ME"), + (YearBegin(), "YS-JAN"), + (YearEnd(), "YE-DEC"), + (QuarterBegin(), "QS-MAR"), + (QuarterEnd(), "QE-MAR"), + (Day(), "D"), + (Hour(), "h"), + (Minute(), "min"), + (Second(), "s"), + (Millisecond(), "ms"), + (Microsecond(), "us"), + ], + ids=_id_func, +) +def test_rule_code(offset, expected): + assert offset.rule_code() == expected + + +@pytest.mark.parametrize( + ("offset", "expected"), + [ + (BaseCFTimeOffset(), ""), + (YearBegin(), ""), + (QuarterBegin(), ""), + ], + ids=_id_func, +) +def test_str_and_repr(offset, expected): + assert str(offset) == expected + assert repr(offset) == expected + + +@pytest.mark.parametrize( + "offset", + [BaseCFTimeOffset(), MonthBegin(), QuarterBegin(), YearBegin()], + ids=_id_func, +) +def test_to_offset_offset_input(offset): + assert to_offset(offset) == offset + + +@pytest.mark.parametrize( + ("freq", "expected"), + [ + ("M", MonthEnd()), + ("2M", MonthEnd(n=2)), + ("ME", MonthEnd()), + ("2ME", MonthEnd(n=2)), + ("MS", MonthBegin()), + ("2MS", MonthBegin(n=2)), + ("D", Day()), + ("2D", Day(n=2)), + ("H", Hour()), + ("2H", Hour(n=2)), + ("h", Hour()), + ("2h", Hour(n=2)), + ("T", Minute()), + ("2T", Minute(n=2)), + ("min", Minute()), + ("2min", Minute(n=2)), + ("S", Second()), + ("2S", Second(n=2)), + ("L", Millisecond(n=1)), + ("2L", Millisecond(n=2)), + ("ms", Millisecond(n=1)), + ("2ms", Millisecond(n=2)), + ("U", Microsecond(n=1)), + ("2U", Microsecond(n=2)), + ("us", Microsecond(n=1)), + ("2us", Microsecond(n=2)), + # negative + ("-2M", MonthEnd(n=-2)), + ("-2ME", MonthEnd(n=-2)), + ("-2MS", MonthBegin(n=-2)), + ("-2D", Day(n=-2)), + ("-2H", Hour(n=-2)), + ("-2h", Hour(n=-2)), + ("-2T", Minute(n=-2)), + ("-2min", Minute(n=-2)), + ("-2S", Second(n=-2)), + ("-2L", Millisecond(n=-2)), + ("-2ms", Millisecond(n=-2)), + ("-2U", Microsecond(n=-2)), + ("-2us", Microsecond(n=-2)), + ], + ids=_id_func, +) +@pytest.mark.filterwarnings("ignore::FutureWarning") # Deprecation of "M" etc. +def test_to_offset_sub_annual(freq, expected): + assert to_offset(freq) == expected + + +_ANNUAL_OFFSET_TYPES = { + "A": YearEnd, + "AS": YearBegin, + "Y": YearEnd, + "YS": YearBegin, + "YE": YearEnd, +} + + +@pytest.mark.parametrize( + ("month_int", "month_label"), list(_MONTH_ABBREVIATIONS.items()) + [(0, "")] +) +@pytest.mark.parametrize("multiple", [None, 2, -1]) +@pytest.mark.parametrize("offset_str", ["AS", "A", "YS", "Y"]) +@pytest.mark.filterwarnings("ignore::FutureWarning") # Deprecation of "A" etc. +def test_to_offset_annual(month_label, month_int, multiple, offset_str): + freq = offset_str + offset_type = _ANNUAL_OFFSET_TYPES[offset_str] + if month_label: + freq = "-".join([freq, month_label]) + if multiple: + freq = f"{multiple}{freq}" + result = to_offset(freq) + + if multiple and month_int: + expected = offset_type(n=multiple, month=month_int) + elif multiple: + expected = offset_type(n=multiple) + elif month_int: + expected = offset_type(month=month_int) + else: + expected = offset_type() + assert result == expected + + +_QUARTER_OFFSET_TYPES = {"Q": QuarterEnd, "QS": QuarterBegin, "QE": QuarterEnd} + + +@pytest.mark.parametrize( + ("month_int", "month_label"), list(_MONTH_ABBREVIATIONS.items()) + [(0, "")] +) +@pytest.mark.parametrize("multiple", [None, 2, -1]) +@pytest.mark.parametrize("offset_str", ["QS", "Q", "QE"]) +@pytest.mark.filterwarnings("ignore::FutureWarning") # Deprecation of "Q" etc. +def test_to_offset_quarter(month_label, month_int, multiple, offset_str): + freq = offset_str + offset_type = _QUARTER_OFFSET_TYPES[offset_str] + if month_label: + freq = "-".join([freq, month_label]) + if multiple: + freq = f"{multiple}{freq}" + result = to_offset(freq) + + if multiple and month_int: + expected = offset_type(n=multiple, month=month_int) + elif multiple: + if month_int: + expected = offset_type(n=multiple) + else: + if offset_type == QuarterBegin: + expected = offset_type(n=multiple, month=1) + elif offset_type == QuarterEnd: + expected = offset_type(n=multiple, month=12) + elif month_int: + expected = offset_type(month=month_int) + else: + if offset_type == QuarterBegin: + expected = offset_type(month=1) + elif offset_type == QuarterEnd: + expected = offset_type(month=12) + assert result == expected + + +@pytest.mark.parametrize("freq", ["Z", "7min2", "AM", "M-", "AS-", "QS-", "1H1min"]) +def test_invalid_to_offset_str(freq): + with pytest.raises(ValueError): + to_offset(freq) + + +@pytest.mark.parametrize( + ("argument", "expected_date_args"), + [("2000-01-01", (2000, 1, 1)), ((2000, 1, 1), (2000, 1, 1))], + ids=_id_func, +) +def test_to_cftime_datetime(calendar, argument, expected_date_args): + date_type = get_date_type(calendar) + expected = date_type(*expected_date_args) + if isinstance(argument, tuple): + argument = date_type(*argument) + result = to_cftime_datetime(argument, calendar=calendar) + assert result == expected + + +def test_to_cftime_datetime_error_no_calendar(): + with pytest.raises(ValueError): + to_cftime_datetime("2000") + + +def test_to_cftime_datetime_error_type_error(): + with pytest.raises(TypeError): + to_cftime_datetime(1) + + +_EQ_TESTS_A = [ + BaseCFTimeOffset(), + YearBegin(), + YearEnd(), + YearBegin(month=2), + YearEnd(month=2), + QuarterBegin(), + QuarterEnd(), + QuarterBegin(month=2), + QuarterEnd(month=2), + MonthBegin(), + MonthEnd(), + Day(), + Hour(), + Minute(), + Second(), + Millisecond(), + Microsecond(), +] +_EQ_TESTS_B = [ + BaseCFTimeOffset(n=2), + YearBegin(n=2), + YearEnd(n=2), + YearBegin(n=2, month=2), + YearEnd(n=2, month=2), + QuarterBegin(n=2), + QuarterEnd(n=2), + QuarterBegin(n=2, month=2), + QuarterEnd(n=2, month=2), + MonthBegin(n=2), + MonthEnd(n=2), + Day(n=2), + Hour(n=2), + Minute(n=2), + Second(n=2), + Millisecond(n=2), + Microsecond(n=2), +] + + +@pytest.mark.parametrize(("a", "b"), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func) +def test_neq(a, b): + assert a != b + + +_EQ_TESTS_B_COPY = [ + BaseCFTimeOffset(n=2), + YearBegin(n=2), + YearEnd(n=2), + YearBegin(n=2, month=2), + YearEnd(n=2, month=2), + QuarterBegin(n=2), + QuarterEnd(n=2), + QuarterBegin(n=2, month=2), + QuarterEnd(n=2, month=2), + MonthBegin(n=2), + MonthEnd(n=2), + Day(n=2), + Hour(n=2), + Minute(n=2), + Second(n=2), + Millisecond(n=2), + Microsecond(n=2), +] + + +@pytest.mark.parametrize(("a", "b"), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func) +def test_eq(a, b): + assert a == b + + +_MUL_TESTS = [ + (BaseCFTimeOffset(), 3, BaseCFTimeOffset(n=3)), + (BaseCFTimeOffset(), -3, BaseCFTimeOffset(n=-3)), + (YearEnd(), 3, YearEnd(n=3)), + (YearBegin(), 3, YearBegin(n=3)), + (QuarterEnd(), 3, QuarterEnd(n=3)), + (QuarterBegin(), 3, QuarterBegin(n=3)), + (MonthEnd(), 3, MonthEnd(n=3)), + (MonthBegin(), 3, MonthBegin(n=3)), + (Tick(), 3, Tick(n=3)), + (Day(), 3, Day(n=3)), + (Hour(), 3, Hour(n=3)), + (Minute(), 3, Minute(n=3)), + (Second(), 3, Second(n=3)), + (Millisecond(), 3, Millisecond(n=3)), + (Microsecond(), 3, Microsecond(n=3)), + (Day(), 0.5, Hour(n=12)), + (Hour(), 0.5, Minute(n=30)), + (Hour(), -0.5, Minute(n=-30)), + (Minute(), 0.5, Second(n=30)), + (Second(), 0.5, Millisecond(n=500)), + (Millisecond(), 0.5, Microsecond(n=500)), +] + + +@pytest.mark.parametrize(("offset", "multiple", "expected"), _MUL_TESTS, ids=_id_func) +def test_mul(offset, multiple, expected): + assert offset * multiple == expected + + +@pytest.mark.parametrize(("offset", "multiple", "expected"), _MUL_TESTS, ids=_id_func) +def test_rmul(offset, multiple, expected): + assert multiple * offset == expected + + +def test_mul_float_multiple_next_higher_resolution(): + """Test more than one iteration through _next_higher_resolution is required.""" + assert 1e-6 * Second() == Microsecond() + assert 1e-6 / 60 * Minute() == Microsecond() + + +@pytest.mark.parametrize( + "offset", + [YearBegin(), YearEnd(), QuarterBegin(), QuarterEnd(), MonthBegin(), MonthEnd()], + ids=_id_func, +) +def test_nonTick_offset_multiplied_float_error(offset): + """Test that the appropriate error is raised if a non-Tick offset is + multiplied by a float.""" + with pytest.raises(TypeError, match="unsupported operand type"): + offset * 0.5 + + +def test_Microsecond_multiplied_float_error(): + """Test that the appropriate error is raised if a Tick offset is multiplied + by a float which causes it not to be representable by a + microsecond-precision timedelta.""" + with pytest.raises( + ValueError, match="Could not convert to integer offset at any resolution" + ): + Microsecond() * 0.5 + + +@pytest.mark.parametrize( + ("offset", "expected"), + [ + (BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), + (YearEnd(), YearEnd(n=-1)), + (YearBegin(), YearBegin(n=-1)), + (QuarterEnd(), QuarterEnd(n=-1)), + (QuarterBegin(), QuarterBegin(n=-1)), + (MonthEnd(), MonthEnd(n=-1)), + (MonthBegin(), MonthBegin(n=-1)), + (Day(), Day(n=-1)), + (Hour(), Hour(n=-1)), + (Minute(), Minute(n=-1)), + (Second(), Second(n=-1)), + (Millisecond(), Millisecond(n=-1)), + (Microsecond(), Microsecond(n=-1)), + ], + ids=_id_func, +) +def test_neg(offset, expected): + assert -offset == expected + + +_ADD_TESTS = [ + (Day(n=2), (1, 1, 3)), + (Hour(n=2), (1, 1, 1, 2)), + (Minute(n=2), (1, 1, 1, 0, 2)), + (Second(n=2), (1, 1, 1, 0, 0, 2)), + (Millisecond(n=2), (1, 1, 1, 0, 0, 0, 2000)), + (Microsecond(n=2), (1, 1, 1, 0, 0, 0, 2)), +] + + +@pytest.mark.parametrize(("offset", "expected_date_args"), _ADD_TESTS, ids=_id_func) +def test_add_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = offset + initial + assert result == expected + + +@pytest.mark.parametrize(("offset", "expected_date_args"), _ADD_TESTS, ids=_id_func) +def test_radd_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = initial + offset + assert result == expected + + +@pytest.mark.parametrize( + ("offset", "expected_date_args"), + [ + (Day(n=2), (1, 1, 1)), + (Hour(n=2), (1, 1, 2, 22)), + (Minute(n=2), (1, 1, 2, 23, 58)), + (Second(n=2), (1, 1, 2, 23, 59, 58)), + (Millisecond(n=2), (1, 1, 2, 23, 59, 59, 998000)), + (Microsecond(n=2), (1, 1, 2, 23, 59, 59, 999998)), + ], + ids=_id_func, +) +def test_rsub_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 3) + expected = date_type(*expected_date_args) + result = initial - offset + assert result == expected + + +@pytest.mark.parametrize("offset", _EQ_TESTS_A, ids=_id_func) +def test_sub_error(offset, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + with pytest.raises(TypeError): + offset - initial + + +@pytest.mark.parametrize(("a", "b"), zip(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func) +def test_minus_offset(a, b): + result = b - a + expected = a + assert result == expected + + +@pytest.mark.parametrize( + ("a", "b"), + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) # type: ignore[arg-type] + + [(YearEnd(month=1), YearEnd(month=2))], + ids=_id_func, +) +def test_minus_offset_error(a, b): + with pytest.raises(TypeError): + b - a + + +def test_days_in_month_non_december(calendar): + date_type = get_date_type(calendar) + reference = date_type(1, 4, 1) + assert _days_in_month(reference) == 30 + + +def test_days_in_month_december(calendar): + if calendar == "360_day": + expected = 30 + else: + expected = 31 + date_type = get_date_type(calendar) + reference = date_type(1, 12, 5) + assert _days_in_month(reference) == expected + + +@pytest.mark.parametrize( + ("initial_date_args", "offset", "expected_date_args"), + [ + ((1, 1, 1), MonthBegin(), (1, 2, 1)), + ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), + ((1, 1, 7), MonthBegin(), (1, 2, 1)), + ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), + ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), + ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), + ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), + ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), + ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), + ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), + ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_month_begin(calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ("initial_date_args", "offset", "expected_year_month", "expected_sub_day"), + [ + ((1, 1, 1), MonthEnd(), (1, 1), ()), + ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), + ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), + ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), + ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_month_end( + calendar, initial_date_args, offset, expected_year_month, expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ( + "initial_year_month", + "initial_sub_day", + "offset", + "expected_year_month", + "expected_sub_day", + ), + [ + ((1, 1), (), MonthEnd(), (1, 2), ()), + ((1, 1), (), MonthEnd(n=2), (1, 3), ()), + ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), + ((1, 3), (), MonthEnd(n=-2), (1, 1), ()), + ((1, 2), (), MonthEnd(n=14), (2, 4), ()), + ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), + ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), + ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_month_end_onOffset( + calendar, + initial_year_month, + initial_sub_day, + offset, + expected_year_month, + expected_sub_day, +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = ( + initial_year_month + (_days_in_month(reference),) + initial_sub_day + ) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ("initial_date_args", "offset", "expected_date_args"), + [ + ((1, 1, 1), YearBegin(), (2, 1, 1)), + ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), + ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), + ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), + ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), + ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), + ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_year_begin(calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ("initial_date_args", "offset", "expected_year_month", "expected_sub_day"), + [ + ((1, 1, 1), YearEnd(), (1, 12), ()), + ((1, 1, 1), YearEnd(n=2), (2, 12), ()), + ((1, 1, 1), YearEnd(month=1), (1, 1), ()), + ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), + ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_year_end( + calendar, initial_date_args, offset, expected_year_month, expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ( + "initial_year_month", + "initial_sub_day", + "offset", + "expected_year_month", + "expected_sub_day", + ), + [ + ((1, 12), (), YearEnd(), (2, 12), ()), + ((1, 12), (), YearEnd(n=2), (3, 12), ()), + ((2, 12), (), YearEnd(n=-1), (1, 12), ()), + ((3, 12), (), YearEnd(n=-2), (1, 12), ()), + ((1, 1), (), YearEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), + ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_year_end_onOffset( + calendar, + initial_year_month, + initial_sub_day, + offset, + expected_year_month, + expected_sub_day, +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = ( + initial_year_month + (_days_in_month(reference),) + initial_sub_day + ) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ("initial_date_args", "offset", "expected_date_args"), + [ + ((1, 1, 1), QuarterBegin(), (1, 3, 1)), + ((1, 1, 1), QuarterBegin(n=2), (1, 6, 1)), + ((1, 1, 1), QuarterBegin(month=2), (1, 2, 1)), + ((1, 1, 7), QuarterBegin(n=2), (1, 6, 1)), + ((2, 2, 1), QuarterBegin(n=-1), (1, 12, 1)), + ((1, 3, 2), QuarterBegin(n=-1), (1, 3, 1)), + ((1, 1, 1, 5, 5, 5, 5), QuarterBegin(), (1, 3, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), QuarterBegin(n=-1), (1, 12, 1, 5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_quarter_begin(calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ("initial_date_args", "offset", "expected_year_month", "expected_sub_day"), + [ + ((1, 1, 1), QuarterEnd(), (1, 3), ()), + ((1, 1, 1), QuarterEnd(n=2), (1, 6), ()), + ((1, 1, 1), QuarterEnd(month=1), (1, 1), ()), + ((2, 3, 1), QuarterEnd(n=-1), (1, 12), ()), + ((1, 3, 1), QuarterEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), QuarterEnd(), (1, 3), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), QuarterEnd(n=2), (1, 6), (5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_quarter_end( + calendar, initial_date_args, offset, expected_year_month, expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ( + "initial_year_month", + "initial_sub_day", + "offset", + "expected_year_month", + "expected_sub_day", + ), + [ + ((1, 12), (), QuarterEnd(), (2, 3), ()), + ((1, 12), (), QuarterEnd(n=2), (2, 6), ()), + ((1, 12), (), QuarterEnd(n=-1), (1, 9), ()), + ((1, 12), (), QuarterEnd(n=-2), (1, 6), ()), + ((1, 1), (), QuarterEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), QuarterEnd(), (2, 3), (5, 5, 5, 5)), + ((1, 12), (5, 5, 5, 5), QuarterEnd(n=-1), (1, 9), (5, 5, 5, 5)), + ], + ids=_id_func, +) +def test_add_quarter_end_onOffset( + calendar, + initial_year_month, + initial_sub_day, + offset, + expected_year_month, + expected_sub_day, +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = ( + initial_year_month + (_days_in_month(reference),) + initial_sub_day + ) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = ( + expected_year_month + (_days_in_month(reference),) + expected_sub_day + ) + expected = date_type(*expected_date_args) + assert result == expected + + +# Note for all sub-monthly offsets, pandas always returns True for onOffset +@pytest.mark.parametrize( + ("date_args", "offset", "expected"), + [ + ((1, 1, 1), MonthBegin(), True), + ((1, 1, 1, 1), MonthBegin(), True), + ((1, 1, 5), MonthBegin(), False), + ((1, 1, 5), MonthEnd(), False), + ((1, 3, 1), QuarterBegin(), True), + ((1, 3, 1, 1), QuarterBegin(), True), + ((1, 3, 5), QuarterBegin(), False), + ((1, 12, 1), QuarterEnd(), False), + ((1, 1, 1), YearBegin(), True), + ((1, 1, 1, 1), YearBegin(), True), + ((1, 1, 5), YearBegin(), False), + ((1, 12, 1), YearEnd(), False), + ((1, 1, 1), Day(), True), + ((1, 1, 1, 1), Day(), True), + ((1, 1, 1), Hour(), True), + ((1, 1, 1), Minute(), True), + ((1, 1, 1), Second(), True), + ((1, 1, 1), Millisecond(), True), + ((1, 1, 1), Microsecond(), True), + ], + ids=_id_func, +) +def test_onOffset(calendar, date_args, offset, expected): + date_type = get_date_type(calendar) + date = date_type(*date_args) + result = offset.onOffset(date) + assert result == expected + + +@pytest.mark.parametrize( + ("year_month_args", "sub_day_args", "offset"), + [ + ((1, 1), (), MonthEnd()), + ((1, 1), (1,), MonthEnd()), + ((1, 12), (), QuarterEnd()), + ((1, 1), (), QuarterEnd(month=1)), + ((1, 12), (), YearEnd()), + ((1, 1), (), YearEnd(month=1)), + ], + ids=_id_func, +) +def test_onOffset_month_or_quarter_or_year_end( + calendar, year_month_args, sub_day_args, offset +): + date_type = get_date_type(calendar) + reference_args = year_month_args + (1,) + reference = date_type(*reference_args) + date_args = year_month_args + (_days_in_month(reference),) + sub_day_args + date = date_type(*date_args) + result = offset.onOffset(date) + assert result + + +@pytest.mark.parametrize( + ("offset", "initial_date_args", "partial_expected_date_args"), + [ + (YearBegin(), (1, 3, 1), (2, 1)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (2, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(), (1, 3, 1), (1, 12)), + (YearEnd(n=2), (1, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (QuarterBegin(), (1, 3, 2), (1, 6)), + (QuarterBegin(), (1, 4, 1), (1, 6)), + (QuarterBegin(n=2), (1, 4, 1), (1, 6)), + (QuarterBegin(n=2, month=2), (1, 4, 1), (1, 5)), + (QuarterEnd(), (1, 3, 1), (1, 3)), + (QuarterEnd(n=2), (1, 3, 1), (1, 3)), + (QuarterEnd(n=2, month=2), (1, 3, 1), (1, 5)), + (QuarterEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 4)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 4)), + (MonthEnd(), (1, 3, 2), (1, 3)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (MonthEnd(n=2), (1, 3, 2), (1, 3)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + (Millisecond(), (1, 3, 2, 1, 1, 1, 1000), (1, 3, 2, 1, 1, 1, 1000)), + (Microsecond(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + ], + ids=_id_func, +) +def test_rollforward(calendar, offset, initial_date_args, partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, QuarterBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = partial_expected_date_args + (_days_in_month(reference),) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollforward(initial) + assert result == expected + + +@pytest.mark.parametrize( + ("offset", "initial_date_args", "partial_expected_date_args"), + [ + (YearBegin(), (1, 3, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), + (YearEnd(), (2, 3, 1), (1, 12)), + (YearEnd(n=2), (2, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (YearEnd(month=4), (1, 4, 30), (1, 4)), + (QuarterBegin(), (1, 3, 2), (1, 3)), + (QuarterBegin(), (1, 4, 1), (1, 3)), + (QuarterBegin(n=2), (1, 4, 1), (1, 3)), + (QuarterBegin(n=2, month=2), (1, 4, 1), (1, 2)), + (QuarterEnd(), (2, 3, 1), (1, 12)), + (QuarterEnd(n=2), (2, 3, 1), (1, 12)), + (QuarterEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (QuarterEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 3)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthEnd(), (1, 3, 2), (1, 2)), + (MonthEnd(n=2), (1, 3, 2), (1, 2)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + (Millisecond(), (1, 3, 2, 1, 1, 1, 1000), (1, 3, 2, 1, 1, 1, 1000)), + (Microsecond(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + ], + ids=_id_func, +) +def test_rollback(calendar, offset, initial_date_args, partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, QuarterBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, QuarterEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = partial_expected_date_args + (_days_in_month(reference),) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollback(initial) + assert result == expected + + +_CFTIME_RANGE_TESTS = [ + ( + "0001-01-01", + "0001-01-04", + None, + "D", + "neither", + False, + [(1, 1, 2), (1, 1, 3)], + ), + ( + "0001-01-01", + "0001-01-04", + None, + "D", + None, + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-01", + "0001-01-04", + None, + "D", + "both", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-01", + "0001-01-04", + None, + "D", + "left", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3)], + ), + ( + "0001-01-01", + "0001-01-04", + None, + "D", + "right", + False, + [(1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-01T01:00:00", + "0001-01-04", + None, + "D", + "both", + False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)], + ), + ( + "0001-01-01 01:00:00", + "0001-01-04", + None, + "D", + "both", + False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)], + ), + ( + "0001-01-01T01:00:00", + "0001-01-04", + None, + "D", + "both", + True, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-01", + None, + 4, + "D", + "both", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + None, + "0001-01-04", + 4, + "D", + "both", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + (1, 1, 1), + "0001-01-04", + None, + "D", + "both", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + (1, 1, 1), + (1, 1, 4), + None, + "D", + "both", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-01-30", + "0011-02-01", + None, + "3YS-JUN", + "both", + False, + [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)], + ), + ("0001-01-04", "0001-01-01", None, "D", "both", False, []), + ( + "0010", + None, + 4, + YearBegin(n=-2), + "both", + False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)], + ), + ( + "0010", + None, + 4, + "-2YS", + "both", + False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)], + ), + ( + "0001-01-01", + "0001-01-04", + 4, + None, + "both", + False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)], + ), + ( + "0001-06-01", + None, + 4, + "3QS-JUN", + "both", + False, + [(1, 6, 1), (2, 3, 1), (2, 12, 1), (3, 9, 1)], + ), + ( + "0001-06-01", + None, + 4, + "-1MS", + "both", + False, + [(1, 6, 1), (1, 5, 1), (1, 4, 1), (1, 3, 1)], + ), + ( + "0001-01-30", + None, + 4, + "-1D", + "both", + False, + [(1, 1, 30), (1, 1, 29), (1, 1, 28), (1, 1, 27)], + ), +] + + +@pytest.mark.parametrize( + ("start", "end", "periods", "freq", "inclusive", "normalize", "expected_date_args"), + _CFTIME_RANGE_TESTS, + ids=_id_func, +) +def test_cftime_range( + start, end, periods, freq, inclusive, normalize, calendar, expected_date_args +): + date_type = get_date_type(calendar) + expected_dates = [date_type(*args) for args in expected_date_args] + + if isinstance(start, tuple): + start = date_type(*start) + if isinstance(end, tuple): + end = date_type(*end) + + result = cftime_range( + start=start, + end=end, + periods=periods, + freq=freq, + inclusive=inclusive, + normalize=normalize, + calendar=calendar, + ) + resulting_dates = result.values + + assert isinstance(result, CFTimeIndex) + + if freq is not None: + np.testing.assert_equal(resulting_dates, expected_dates) + else: + # If we create a linear range of dates using cftime.num2date + # we will not get exact round number dates. This is because + # datetime arithmetic in cftime is accurate approximately to + # 1 millisecond (see https://unidata.github.io/cftime/api.html). + deltas = resulting_dates - expected_dates + deltas = np.array([delta.total_seconds() for delta in deltas]) + assert np.max(np.abs(deltas)) < 0.001 + + +def test_cftime_range_name(): + result = cftime_range(start="2000", periods=4, name="foo") + assert result.name == "foo" + + result = cftime_range(start="2000", periods=4) + assert result.name is None + + +@pytest.mark.parametrize( + ("start", "end", "periods", "freq", "inclusive"), + [ + (None, None, 5, "YE", None), + ("2000", None, None, "YE", None), + (None, "2000", None, "YE", None), + (None, None, None, None, None), + ("2000", "2001", None, "YE", "up"), + ("2000", "2001", 5, "YE", None), + ], +) +def test_invalid_cftime_range_inputs( + start: str | None, + end: str | None, + periods: int | None, + freq: str | None, + inclusive: Literal["up", None], +) -> None: + with pytest.raises(ValueError): + cftime_range(start, end, periods, freq, inclusive=inclusive) # type: ignore[arg-type] + + +def test_invalid_cftime_arg() -> None: + with pytest.warns( + FutureWarning, match="Following pandas, the `closed` parameter is deprecated" + ): + cftime_range("2000", "2001", None, "YE", closed="left") + + +_CALENDAR_SPECIFIC_MONTH_END_TESTS = [ + ("noleap", [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("all_leap", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("360_day", [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ("standard", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("gregorian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("julian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), +] + + +@pytest.mark.parametrize( + ("calendar", "expected_month_day"), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, + ids=_id_func, +) +def test_calendar_specific_month_end( + calendar: str, expected_month_day: list[tuple[int, int]] +) -> None: + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start="2000-02", end="2001", freq="2ME", calendar=calendar + ).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day] + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + ("calendar", "expected_month_day"), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, + ids=_id_func, +) +def test_calendar_specific_month_end_negative_freq( + calendar: str, expected_month_day: list[tuple[int, int]] +) -> None: + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start="2001", + end="2000", + freq="-2ME", + calendar=calendar, + ).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day[::-1]] + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + ("calendar", "start", "end", "expected_number_of_days"), + [ + ("noleap", "2000", "2001", 365), + ("all_leap", "2000", "2001", 366), + ("360_day", "2000", "2001", 360), + ("standard", "2000", "2001", 366), + ("gregorian", "2000", "2001", 366), + ("julian", "2000", "2001", 366), + ("noleap", "2001", "2002", 365), + ("all_leap", "2001", "2002", 366), + ("360_day", "2001", "2002", 360), + ("standard", "2001", "2002", 365), + ("gregorian", "2001", "2002", 365), + ("julian", "2001", "2002", 365), + ], +) +def test_calendar_year_length( + calendar: str, start: str, end: str, expected_number_of_days: int +) -> None: + result = cftime_range(start, end, freq="D", inclusive="left", calendar=calendar) + assert len(result) == expected_number_of_days + + +@pytest.mark.parametrize("freq", ["YE", "ME", "D"]) +def test_dayofweek_after_cftime_range(freq: str) -> None: + result = cftime_range("2000-02-01", periods=3, freq=freq).dayofweek + # TODO: remove once requiring pandas 2.2+ + freq = _new_to_legacy_freq(freq) + expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofweek + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["YE", "ME", "D"]) +def test_dayofyear_after_cftime_range(freq: str) -> None: + result = cftime_range("2000-02-01", periods=3, freq=freq).dayofyear + # TODO: remove once requiring pandas 2.2+ + freq = _new_to_legacy_freq(freq) + expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofyear + np.testing.assert_array_equal(result, expected) + + +def test_cftime_range_standard_calendar_refers_to_gregorian() -> None: + from cftime import DatetimeGregorian + + (result,) = cftime_range("2000", periods=1) + assert isinstance(result, DatetimeGregorian) + + +@pytest.mark.parametrize( + "start,calendar,use_cftime,expected_type", + [ + ("1990-01-01", "standard", None, pd.DatetimeIndex), + ("1990-01-01", "proleptic_gregorian", True, CFTimeIndex), + ("1990-01-01", "noleap", None, CFTimeIndex), + ("1990-01-01", "gregorian", False, pd.DatetimeIndex), + ("1400-01-01", "standard", None, CFTimeIndex), + ("3400-01-01", "standard", None, CFTimeIndex), + ], +) +def test_date_range( + start: str, calendar: str, use_cftime: bool | None, expected_type +) -> None: + dr = date_range( + start, periods=14, freq="D", calendar=calendar, use_cftime=use_cftime + ) + + assert isinstance(dr, expected_type) + + +def test_date_range_errors() -> None: + with pytest.raises(ValueError, match="Date range is invalid"): + date_range( + "1400-01-01", periods=1, freq="D", calendar="standard", use_cftime=False + ) + + with pytest.raises(ValueError, match="Date range is invalid"): + date_range( + "2480-01-01", + periods=1, + freq="D", + calendar="proleptic_gregorian", + use_cftime=False, + ) + + with pytest.raises(ValueError, match="Invalid calendar "): + date_range( + "1900-01-01", periods=1, freq="D", calendar="noleap", use_cftime=False + ) + + +@requires_cftime +@pytest.mark.parametrize( + "start,freq,cal_src,cal_tgt,use_cftime,exp0,exp_pd", + [ + ("2020-02-01", "4ME", "standard", "noleap", None, "2020-02-28", False), + ("2020-02-01", "ME", "noleap", "gregorian", True, "2020-02-29", True), + ("2020-02-01", "QE-DEC", "noleap", "gregorian", True, "2020-03-31", True), + ("2020-02-01", "YS-FEB", "noleap", "gregorian", True, "2020-02-01", True), + ("2020-02-01", "YE-FEB", "noleap", "gregorian", True, "2020-02-29", True), + ("2020-02-01", "-1YE-FEB", "noleap", "gregorian", True, "2019-02-28", True), + ("2020-02-28", "3h", "all_leap", "gregorian", False, "2020-02-28", True), + ("2020-03-30", "ME", "360_day", "gregorian", False, "2020-03-31", True), + ("2020-03-31", "ME", "gregorian", "360_day", None, "2020-03-30", False), + ("2020-03-31", "-1ME", "gregorian", "360_day", None, "2020-03-30", False), + ], +) +def test_date_range_like(start, freq, cal_src, cal_tgt, use_cftime, exp0, exp_pd): + expected_freq = freq + + source = date_range(start, periods=12, freq=freq, calendar=cal_src) + + out = date_range_like(source, cal_tgt, use_cftime=use_cftime) + + assert len(out) == 12 + + assert infer_freq(out) == expected_freq + + assert out[0].isoformat().startswith(exp0) + + if exp_pd: + assert isinstance(out, pd.DatetimeIndex) + else: + assert isinstance(out, CFTimeIndex) + assert out.calendar == cal_tgt + + +@requires_cftime +@pytest.mark.parametrize( + "freq", ("YE", "YS", "YE-MAY", "MS", "ME", "QS", "h", "min", "s") +) +@pytest.mark.parametrize("use_cftime", (True, False)) +def test_date_range_like_no_deprecation(freq, use_cftime): + # ensure no internal warnings + # TODO: remove once freq string deprecation is finished + + source = date_range("2000", periods=3, freq=freq, use_cftime=False) + + with assert_no_warnings(): + date_range_like(source, "standard", use_cftime=use_cftime) + + +def test_date_range_like_same_calendar(): + src = date_range("2000-01-01", periods=12, freq="6h", use_cftime=False) + out = date_range_like(src, "standard", use_cftime=False) + assert src is out + + +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") +def test_date_range_like_errors(): + src = date_range("1899-02-03", periods=20, freq="D", use_cftime=False) + src = src[np.arange(20) != 10] # Remove 1 day so the frequency is not inferable. + + with pytest.raises( + ValueError, + match="`date_range_like` was unable to generate a range as the source frequency was not inferable.", + ): + date_range_like(src, "gregorian") + + src = DataArray( + np.array( + [["1999-01-01", "1999-01-02"], ["1999-01-03", "1999-01-04"]], + dtype=np.datetime64, + ), + dims=("x", "y"), + ) + with pytest.raises( + ValueError, + match="'source' must be a 1D array of datetime objects for inferring its range.", + ): + date_range_like(src, "noleap") + + da = DataArray([1, 2, 3, 4], dims=("time",)) + with pytest.raises( + ValueError, + match="'source' must be a 1D array of datetime objects for inferring its range.", + ): + date_range_like(da, "noleap") + + +def as_timedelta_not_implemented_error(): + tick = Tick() + with pytest.raises(NotImplementedError): + tick.as_timedelta() + + +@pytest.mark.parametrize("function", [cftime_range, date_range]) +def test_cftime_or_date_range_closed_and_inclusive_error(function: Callable) -> None: + if function == cftime_range and not has_cftime: + pytest.skip("requires cftime") + + with pytest.raises(ValueError, match="Following pandas, deprecated"): + function("2000", periods=3, closed=None, inclusive="right") + + +@pytest.mark.parametrize("function", [cftime_range, date_range]) +def test_cftime_or_date_range_invalid_inclusive_value(function: Callable) -> None: + if function == cftime_range and not has_cftime: + pytest.skip("requires cftime") + + with pytest.raises(ValueError, match="nclusive"): + function("2000", periods=3, inclusive="foo") + + +@pytest.mark.parametrize( + "function", + [ + pytest.param(cftime_range, id="cftime", marks=requires_cftime), + pytest.param(date_range, id="date"), + ], +) +@pytest.mark.parametrize( + ("closed", "inclusive"), [(None, "both"), ("left", "left"), ("right", "right")] +) +def test_cftime_or_date_range_closed( + function: Callable, + closed: Literal["left", "right", None], + inclusive: Literal["left", "right", "both"], +) -> None: + with pytest.warns(FutureWarning, match="Following pandas"): + result_closed = function("2000-01-01", "2000-01-04", freq="D", closed=closed) + result_inclusive = function( + "2000-01-01", "2000-01-04", freq="D", inclusive=inclusive + ) + np.testing.assert_equal(result_closed.values, result_inclusive.values) + + +@pytest.mark.parametrize("function", [cftime_range, date_range]) +def test_cftime_or_date_range_inclusive_None(function) -> None: + if function == cftime_range and not has_cftime: + pytest.skip("requires cftime") + + result_None = function("2000-01-01", "2000-01-04") + result_both = function("2000-01-01", "2000-01-04", inclusive="both") + np.testing.assert_equal(result_None.values, result_both.values) + + +@pytest.mark.parametrize( + "freq", ["A", "AS", "Q", "M", "H", "T", "S", "L", "U", "Y", "A-MAY"] +) +def test_to_offset_deprecation_warning(freq): + # Test for deprecations outlined in GitHub issue #8394 + with pytest.warns(FutureWarning, match="is deprecated"): + to_offset(freq) + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.parametrize( + "freq, expected", + ( + ["Y", "YE"], + ["A", "YE"], + ["Q", "QE"], + ["M", "ME"], + ["AS", "YS"], + ["YE", "YE"], + ["QE", "QE"], + ["ME", "ME"], + ["YS", "YS"], + ), +) +@pytest.mark.parametrize("n", ("", "2")) +def test_legacy_to_new_freq(freq, expected, n): + freq = f"{n}{freq}" + result = _legacy_to_new_freq(freq) + + expected = f"{n}{expected}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.parametrize("year_alias", ("YE", "Y", "A")) +@pytest.mark.parametrize("n", ("", "2")) +def test_legacy_to_new_freq_anchored(year_alias, n): + for month in _MONTH_ABBREVIATIONS.values(): + freq = f"{n}{year_alias}-{month}" + result = _legacy_to_new_freq(freq) + + expected = f"{n}YE-{month}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.filterwarnings("ignore:'[AY]' is deprecated") +@pytest.mark.parametrize( + "freq, expected", + (["A", "A"], ["YE", "A"], ["Y", "A"], ["QE", "Q"], ["ME", "M"], ["YS", "AS"]), +) +@pytest.mark.parametrize("n", ("", "2")) +def test_new_to_legacy_freq(freq, expected, n): + freq = f"{n}{freq}" + result = _new_to_legacy_freq(freq) + + expected = f"{n}{expected}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.filterwarnings("ignore:'[AY]-.{3}' is deprecated") +@pytest.mark.parametrize("year_alias", ("A", "Y", "YE")) +@pytest.mark.parametrize("n", ("", "2")) +def test_new_to_legacy_freq_anchored(year_alias, n): + for month in _MONTH_ABBREVIATIONS.values(): + freq = f"{n}{year_alias}-{month}" + result = _new_to_legacy_freq(freq) + + expected = f"{n}A-{month}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only for pandas lt 2.2") +@pytest.mark.parametrize( + "freq, expected", + ( + # pandas-only freq strings are passed through + ("BH", "BH"), + ("CBH", "CBH"), + ("N", "N"), + ), +) +def test_legacy_to_new_freq_pd_freq_passthrough(freq, expected): + + result = _legacy_to_new_freq(freq) + assert result == expected + + +@pytest.mark.filterwarnings("ignore:'.' is deprecated ") +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only for pandas lt 2.2") +@pytest.mark.parametrize( + "freq, expected", + ( + # these are each valid in pandas lt 2.2 + ("T", "T"), + ("min", "min"), + ("S", "S"), + ("s", "s"), + ("L", "L"), + ("ms", "ms"), + ("U", "U"), + ("us", "us"), + # pandas-only freq strings are passed through + ("bh", "bh"), + ("cbh", "cbh"), + ("ns", "ns"), + ), +) +def test_new_to_legacy_freq_pd_freq_passthrough(freq, expected): + + result = _new_to_legacy_freq(freq) + assert result == expected + + +@pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex with:") +@pytest.mark.parametrize("start", ("2000", "2001")) +@pytest.mark.parametrize("end", ("2000", "2001")) +@pytest.mark.parametrize( + "freq", + ( + "MS", + pytest.param("-1MS", marks=requires_pandas_3), + "YS", + pytest.param("-1YS", marks=requires_pandas_3), + "ME", + pytest.param("-1ME", marks=requires_pandas_3), + "YE", + pytest.param("-1YE", marks=requires_pandas_3), + ), +) +def test_cftime_range_same_as_pandas(start, end, freq): + result = date_range(start, end, freq=freq, calendar="standard", use_cftime=True) + result = result.to_datetimeindex() + expected = date_range(start, end, freq=freq, use_cftime=False) + + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex with:") +@pytest.mark.parametrize( + "start, end, periods", + [ + ("2022-01-01", "2022-01-10", 2), + ("2022-03-01", "2022-03-31", 2), + ("2022-01-01", "2022-01-10", None), + ("2022-03-01", "2022-03-31", None), + ], +) +def test_cftime_range_no_freq(start, end, periods): + """ + Test whether cftime_range produces the same result as Pandas + when freq is not provided, but start, end and periods are. + """ + # Generate date ranges using cftime_range + result = cftime_range(start=start, end=end, periods=periods) + result = result.to_datetimeindex() + expected = pd.date_range(start=start, end=end, periods=periods) + + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize( + "start, end, periods", + [ + ("2022-01-01", "2022-01-10", 2), + ("2022-03-01", "2022-03-31", 2), + ("2022-01-01", "2022-01-10", None), + ("2022-03-01", "2022-03-31", None), + ], +) +def test_date_range_no_freq(start, end, periods): + """ + Test whether date_range produces the same result as Pandas + when freq is not provided, but start, end and periods are. + """ + # Generate date ranges using date_range + result = date_range(start=start, end=end, periods=periods) + expected = pd.date_range(start=start, end=end, periods=periods) + + np.testing.assert_array_equal(result, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_cftimeindex.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_cftimeindex.py new file mode 100644 index 0000000..f6eb15f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_cftimeindex.py @@ -0,0 +1,1382 @@ +from __future__ import annotations + +import pickle +from datetime import timedelta +from textwrap import dedent + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.coding.cftimeindex import ( + CFTimeIndex, + _parse_array_of_cftime_strings, + _parse_iso8601_with_reso, + _parsed_string_to_bounds, + assert_all_valid_date_type, + parse_iso8601_like, +) +from xarray.tests import ( + assert_array_equal, + assert_identical, + has_cftime, + requires_cftime, +) +from xarray.tests.test_coding_times import ( + _ALL_CALENDARS, + _NON_STANDARD_CALENDARS, + _all_cftime_date_types, +) + +# cftime 1.5.2 renames "gregorian" to "standard" +standard_or_gregorian = "" +if has_cftime: + standard_or_gregorian = "standard" + + +def date_dict(year=None, month=None, day=None, hour=None, minute=None, second=None): + return dict( + year=year, month=month, day=day, hour=hour, minute=minute, second=second + ) + + +ISO8601_LIKE_STRING_TESTS = { + "year": ("1999", date_dict(year="1999")), + "month": ("199901", date_dict(year="1999", month="01")), + "month-dash": ("1999-01", date_dict(year="1999", month="01")), + "day": ("19990101", date_dict(year="1999", month="01", day="01")), + "day-dash": ("1999-01-01", date_dict(year="1999", month="01", day="01")), + "hour": ("19990101T12", date_dict(year="1999", month="01", day="01", hour="12")), + "hour-dash": ( + "1999-01-01T12", + date_dict(year="1999", month="01", day="01", hour="12"), + ), + "hour-space-separator": ( + "1999-01-01 12", + date_dict(year="1999", month="01", day="01", hour="12"), + ), + "minute": ( + "19990101T1234", + date_dict(year="1999", month="01", day="01", hour="12", minute="34"), + ), + "minute-dash": ( + "1999-01-01T12:34", + date_dict(year="1999", month="01", day="01", hour="12", minute="34"), + ), + "minute-space-separator": ( + "1999-01-01 12:34", + date_dict(year="1999", month="01", day="01", hour="12", minute="34"), + ), + "second": ( + "19990101T123456", + date_dict( + year="1999", month="01", day="01", hour="12", minute="34", second="56" + ), + ), + "second-dash": ( + "1999-01-01T12:34:56", + date_dict( + year="1999", month="01", day="01", hour="12", minute="34", second="56" + ), + ), + "second-space-separator": ( + "1999-01-01 12:34:56", + date_dict( + year="1999", month="01", day="01", hour="12", minute="34", second="56" + ), + ), +} + + +@pytest.mark.parametrize( + ("string", "expected"), + list(ISO8601_LIKE_STRING_TESTS.values()), + ids=list(ISO8601_LIKE_STRING_TESTS.keys()), +) +def test_parse_iso8601_like(string, expected): + result = parse_iso8601_like(string) + assert result == expected + + with pytest.raises(ValueError): + parse_iso8601_like(string + "3") + parse_iso8601_like(string + ".3") + + +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", +] + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def date_type(request): + return _all_cftime_date_types()[request.param] + + +@pytest.fixture +def index(date_type): + dates = [ + date_type(1, 1, 1), + date_type(1, 2, 1), + date_type(2, 1, 1), + date_type(2, 2, 1), + ] + return CFTimeIndex(dates) + + +@pytest.fixture +def monotonic_decreasing_index(date_type): + dates = [ + date_type(2, 2, 1), + date_type(2, 1, 1), + date_type(1, 2, 1), + date_type(1, 1, 1), + ] + return CFTimeIndex(dates) + + +@pytest.fixture +def length_one_index(date_type): + dates = [date_type(1, 1, 1)] + return CFTimeIndex(dates) + + +@pytest.fixture +def da(index): + return xr.DataArray([1, 2, 3, 4], coords=[index], dims=["time"]) + + +@pytest.fixture +def series(index): + return pd.Series([1, 2, 3, 4], index=index) + + +@pytest.fixture +def df(index): + return pd.DataFrame([1, 2, 3, 4], index=index) + + +@pytest.fixture +def feb_days(date_type): + import cftime + + if date_type is cftime.DatetimeAllLeap: + return 29 + elif date_type is cftime.Datetime360Day: + return 30 + else: + return 28 + + +@pytest.fixture +def dec_days(date_type): + import cftime + + if date_type is cftime.Datetime360Day: + return 30 + else: + return 31 + + +@pytest.fixture +def index_with_name(date_type): + dates = [ + date_type(1, 1, 1), + date_type(1, 2, 1), + date_type(2, 1, 1), + date_type(2, 2, 1), + ] + return CFTimeIndex(dates, name="foo") + + +@requires_cftime +@pytest.mark.parametrize(("name", "expected_name"), [("bar", "bar"), (None, "foo")]) +def test_constructor_with_name(index_with_name, name, expected_name): + result = CFTimeIndex(index_with_name, name=name).name + assert result == expected_name + + +@requires_cftime +def test_assert_all_valid_date_type(date_type, index): + import cftime + + if date_type is cftime.DatetimeNoLeap: + mixed_date_types = np.array( + [date_type(1, 1, 1), cftime.DatetimeAllLeap(1, 2, 1)] + ) + else: + mixed_date_types = np.array( + [date_type(1, 1, 1), cftime.DatetimeNoLeap(1, 2, 1)] + ) + with pytest.raises(TypeError): + assert_all_valid_date_type(mixed_date_types) + + with pytest.raises(TypeError): + assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)])) + + assert_all_valid_date_type(np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) + + +@requires_cftime +@pytest.mark.parametrize( + ("field", "expected"), + [ + ("year", [1, 1, 2, 2]), + ("month", [1, 2, 1, 2]), + ("day", [1, 1, 1, 1]), + ("hour", [0, 0, 0, 0]), + ("minute", [0, 0, 0, 0]), + ("second", [0, 0, 0, 0]), + ("microsecond", [0, 0, 0, 0]), + ], +) +def test_cftimeindex_field_accessors(index, field, expected): + result = getattr(index, field) + expected = np.array(expected, dtype=np.int64) + assert_array_equal(result, expected) + assert result.dtype == expected.dtype + + +@requires_cftime +@pytest.mark.parametrize( + ("field"), + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "dayofyear", + "dayofweek", + "days_in_month", + ], +) +def test_empty_cftimeindex_field_accessors(field): + index = CFTimeIndex([]) + result = getattr(index, field) + expected = np.array([], dtype=np.int64) + assert_array_equal(result, expected) + assert result.dtype == expected.dtype + + +@requires_cftime +def test_cftimeindex_dayofyear_accessor(index): + result = index.dayofyear + expected = np.array([date.dayofyr for date in index], dtype=np.int64) + assert_array_equal(result, expected) + assert result.dtype == expected.dtype + + +@requires_cftime +def test_cftimeindex_dayofweek_accessor(index): + result = index.dayofweek + expected = np.array([date.dayofwk for date in index], dtype=np.int64) + assert_array_equal(result, expected) + assert result.dtype == expected.dtype + + +@requires_cftime +def test_cftimeindex_days_in_month_accessor(index): + result = index.days_in_month + expected = np.array([date.daysinmonth for date in index], dtype=np.int64) + assert_array_equal(result, expected) + assert result.dtype == expected.dtype + + +@requires_cftime +@pytest.mark.parametrize( + ("string", "date_args", "reso"), + [ + ("1999", (1999, 1, 1), "year"), + ("199902", (1999, 2, 1), "month"), + ("19990202", (1999, 2, 2), "day"), + ("19990202T01", (1999, 2, 2, 1), "hour"), + ("19990202T0101", (1999, 2, 2, 1, 1), "minute"), + ("19990202T010156", (1999, 2, 2, 1, 1, 56), "second"), + ], +) +def test_parse_iso8601_with_reso(date_type, string, date_args, reso): + expected_date = date_type(*date_args) + expected_reso = reso + result_date, result_reso = _parse_iso8601_with_reso(date_type, string) + assert result_date == expected_date + assert result_reso == expected_reso + + +@requires_cftime +def test_parse_string_to_bounds_year(date_type, dec_days): + parsed = date_type(2, 2, 10, 6, 2, 8, 1) + expected_start = date_type(2, 1, 1) + expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds(date_type, "year", parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@requires_cftime +def test_parse_string_to_bounds_month_feb(date_type, feb_days): + parsed = date_type(2, 2, 10, 6, 2, 8, 1) + expected_start = date_type(2, 2, 1) + expected_end = date_type(2, 2, feb_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds(date_type, "month", parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@requires_cftime +def test_parse_string_to_bounds_month_dec(date_type, dec_days): + parsed = date_type(2, 12, 1) + expected_start = date_type(2, 12, 1) + expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds(date_type, "month", parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@requires_cftime +@pytest.mark.parametrize( + ("reso", "ex_start_args", "ex_end_args"), + [ + ("day", (2, 2, 10), (2, 2, 10, 23, 59, 59, 999999)), + ("hour", (2, 2, 10, 6), (2, 2, 10, 6, 59, 59, 999999)), + ("minute", (2, 2, 10, 6, 2), (2, 2, 10, 6, 2, 59, 999999)), + ("second", (2, 2, 10, 6, 2, 8), (2, 2, 10, 6, 2, 8, 999999)), + ], +) +def test_parsed_string_to_bounds_sub_monthly( + date_type, reso, ex_start_args, ex_end_args +): + parsed = date_type(2, 2, 10, 6, 2, 8, 123456) + expected_start = date_type(*ex_start_args) + expected_end = date_type(*ex_end_args) + + result_start, result_end = _parsed_string_to_bounds(date_type, reso, parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@requires_cftime +def test_parsed_string_to_bounds_raises(date_type): + with pytest.raises(KeyError): + _parsed_string_to_bounds(date_type, "a", date_type(1, 1, 1)) + + +@requires_cftime +def test_get_loc(date_type, index): + result = index.get_loc("0001") + assert result == slice(0, 2) + + result = index.get_loc(date_type(1, 2, 1)) + assert result == 1 + + result = index.get_loc("0001-02-01") + assert result == slice(1, 2) + + with pytest.raises(KeyError, match=r"1234"): + index.get_loc("1234") + + +@requires_cftime +def test_get_slice_bound(date_type, index): + result = index.get_slice_bound("0001", "left") + expected = 0 + assert result == expected + + result = index.get_slice_bound("0001", "right") + expected = 2 + assert result == expected + + result = index.get_slice_bound(date_type(1, 3, 1), "left") + expected = 2 + assert result == expected + + result = index.get_slice_bound(date_type(1, 3, 1), "right") + expected = 2 + assert result == expected + + +@requires_cftime +def test_get_slice_bound_decreasing_index(date_type, monotonic_decreasing_index): + result = monotonic_decreasing_index.get_slice_bound("0001", "left") + expected = 2 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound("0001", "right") + expected = 4 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound(date_type(1, 3, 1), "left") + expected = 2 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound(date_type(1, 3, 1), "right") + expected = 2 + assert result == expected + + +@requires_cftime +def test_get_slice_bound_length_one_index(date_type, length_one_index): + result = length_one_index.get_slice_bound("0001", "left") + expected = 0 + assert result == expected + + result = length_one_index.get_slice_bound("0001", "right") + expected = 1 + assert result == expected + + result = length_one_index.get_slice_bound(date_type(1, 3, 1), "left") + expected = 1 + assert result == expected + + result = length_one_index.get_slice_bound(date_type(1, 3, 1), "right") + expected = 1 + assert result == expected + + +@requires_cftime +def test_string_slice_length_one_index(length_one_index): + da = xr.DataArray([1], coords=[length_one_index], dims=["time"]) + result = da.sel(time=slice("0001", "0001")) + assert_identical(result, da) + + +@requires_cftime +def test_date_type_property(date_type, index): + assert index.date_type is date_type + + +@requires_cftime +def test_contains(date_type, index): + assert "0001-01-01" in index + assert "0001" in index + assert "0003" not in index + assert date_type(1, 1, 1) in index + assert date_type(3, 1, 1) not in index + + +@requires_cftime +def test_groupby(da): + result = da.groupby("time.month").sum("time") + expected = xr.DataArray([4, 6], coords=[[1, 2]], dims=["month"]) + assert_identical(result, expected) + + +SEL_STRING_OR_LIST_TESTS = { + "string": "0001", + "string-slice": slice("0001-01-01", "0001-12-30"), + "bool-list": [True, True, False, False], +} + + +@requires_cftime +@pytest.mark.parametrize( + "sel_arg", + list(SEL_STRING_OR_LIST_TESTS.values()), + ids=list(SEL_STRING_OR_LIST_TESTS.keys()), +) +def test_sel_string_or_list(da, index, sel_arg): + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=["time"]) + result = da.sel(time=sel_arg) + assert_identical(result, expected) + + +@requires_cftime +def test_sel_date_slice_or_list(da, index, date_type): + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=["time"]) + result = da.sel(time=slice(date_type(1, 1, 1), date_type(1, 12, 30))) + assert_identical(result, expected) + + result = da.sel(time=[date_type(1, 1, 1), date_type(1, 2, 1)]) + assert_identical(result, expected) + + +@requires_cftime +def test_sel_date_scalar(da, date_type, index): + expected = xr.DataArray(1).assign_coords(time=index[0]) + result = da.sel(time=date_type(1, 1, 1)) + assert_identical(result, expected) + + +@requires_cftime +def test_sel_date_distant_date(da, date_type, index): + expected = xr.DataArray(4).assign_coords(time=index[3]) + result = da.sel(time=date_type(2000, 1, 1), method="nearest") + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [ + {"method": "nearest"}, + {"method": "nearest", "tolerance": timedelta(days=70)}, + {"method": "nearest", "tolerance": timedelta(days=1800000)}, + ], +) +def test_sel_date_scalar_nearest(da, date_type, index, sel_kwargs): + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "pad"}, {"method": "pad", "tolerance": timedelta(days=365)}], +) +def test_sel_date_scalar_pad(da, date_type, index, sel_kwargs): + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "backfill"}, {"method": "backfill", "tolerance": timedelta(days=365)}], +) +def test_sel_date_scalar_backfill(da, date_type, index, sel_kwargs): + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [ + {"method": "pad", "tolerance": timedelta(days=20)}, + {"method": "backfill", "tolerance": timedelta(days=20)}, + {"method": "nearest", "tolerance": timedelta(days=20)}, + ], +) +def test_sel_date_scalar_tolerance_raises(da, date_type, sel_kwargs): + with pytest.raises(KeyError): + da.sel(time=date_type(1, 5, 1), **sel_kwargs) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "nearest"}, {"method": "nearest", "tolerance": timedelta(days=70)}], +) +def test_sel_date_list_nearest(da, date_type, index, sel_kwargs): + expected = xr.DataArray([2, 2], coords=[[index[1], index[1]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray([2, 3], coords=[[index[1], index[2]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 12, 1)], **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray([3, 3], coords=[[index[2], index[2]]], dims=["time"]) + result = da.sel(time=[date_type(1, 11, 1), date_type(1, 12, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "pad"}, {"method": "pad", "tolerance": timedelta(days=365)}], +) +def test_sel_date_list_pad(da, date_type, index, sel_kwargs): + expected = xr.DataArray([2, 2], coords=[[index[1], index[1]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [{"method": "backfill"}, {"method": "backfill", "tolerance": timedelta(days=365)}], +) +def test_sel_date_list_backfill(da, date_type, index, sel_kwargs): + expected = xr.DataArray([3, 3], coords=[[index[2], index[2]]], dims=["time"]) + result = da.sel(time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "sel_kwargs", + [ + {"method": "pad", "tolerance": timedelta(days=20)}, + {"method": "backfill", "tolerance": timedelta(days=20)}, + {"method": "nearest", "tolerance": timedelta(days=20)}, + ], +) +def test_sel_date_list_tolerance_raises(da, date_type, sel_kwargs): + with pytest.raises(KeyError): + da.sel(time=[date_type(1, 2, 1), date_type(1, 5, 1)], **sel_kwargs) + + +@requires_cftime +def test_isel(da, index): + expected = xr.DataArray(1).assign_coords(time=index[0]) + result = da.isel(time=0) + assert_identical(result, expected) + + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=["time"]) + result = da.isel(time=[0, 1]) + assert_identical(result, expected) + + +@pytest.fixture +def scalar_args(date_type): + return [date_type(1, 1, 1)] + + +@pytest.fixture +def range_args(date_type): + return [ + "0001", + slice("0001-01-01", "0001-12-30"), + slice(None, "0001-12-30"), + slice(date_type(1, 1, 1), date_type(1, 12, 30)), + slice(None, date_type(1, 12, 30)), + ] + + +@requires_cftime +def test_indexing_in_series_getitem(series, index, scalar_args, range_args): + for arg in scalar_args: + assert series[arg] == 1 + + expected = pd.Series([1, 2], index=index[:2]) + for arg in range_args: + assert series[arg].equals(expected) + + +@requires_cftime +def test_indexing_in_series_loc(series, index, scalar_args, range_args): + for arg in scalar_args: + assert series.loc[arg] == 1 + + expected = pd.Series([1, 2], index=index[:2]) + for arg in range_args: + assert series.loc[arg].equals(expected) + + +@requires_cftime +def test_indexing_in_series_iloc(series, index): + expected = 1 + assert series.iloc[0] == expected + + expected = pd.Series([1, 2], index=index[:2]) + assert series.iloc[:2].equals(expected) + + +@requires_cftime +def test_series_dropna(index): + series = pd.Series([0.0, 1.0, np.nan, np.nan], index=index) + expected = series.iloc[:2] + result = series.dropna() + assert result.equals(expected) + + +@requires_cftime +def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): + expected = pd.Series([1], name=index[0]) + for arg in scalar_args: + result = df.loc[arg] + assert result.equals(expected) + + expected = pd.DataFrame([1, 2], index=index[:2]) + for arg in range_args: + result = df.loc[arg] + assert result.equals(expected) + + +@requires_cftime +def test_indexing_in_dataframe_iloc(df, index): + expected = pd.Series([1], name=index[0]) + result = df.iloc[0] + assert result.equals(expected) + assert result.equals(expected) + + expected = pd.DataFrame([1, 2], index=index[:2]) + result = df.iloc[:2] + assert result.equals(expected) + + +@requires_cftime +def test_concat_cftimeindex(date_type): + da1 = xr.DataArray( + [1.0, 2.0], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]], dims=["time"] + ) + da2 = xr.DataArray( + [3.0, 4.0], coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]], dims=["time"] + ) + da = xr.concat([da1, da2], dim="time") + + assert isinstance(da.xindexes["time"].to_pandas_index(), CFTimeIndex) + + +@requires_cftime +def test_empty_cftimeindex(): + index = CFTimeIndex([]) + assert index.date_type is None + + +@requires_cftime +def test_cftimeindex_add(index): + date_type = index.date_type + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_add_timedeltaindex(calendar) -> None: + a = xr.cftime_range("2000", periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = a + deltas + expected = a.shift(2, "D") + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +@pytest.mark.parametrize("n", [2.0, 1.5]) +@pytest.mark.parametrize( + "freq,units", + [ + ("D", "D"), + ("h", "h"), + ("min", "min"), + ("s", "s"), + ("ms", "ms"), + ], +) +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_shift_float(n, freq, units, calendar) -> None: + a = xr.cftime_range("2000", periods=3, calendar=calendar, freq="D") + result = a + pd.Timedelta(n, units) + expected = a.shift(n, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +def test_cftimeindex_shift_float_us() -> None: + a = xr.cftime_range("2000", periods=3, freq="D") + with pytest.raises( + ValueError, match="Could not convert to integer offset at any resolution" + ): + a.shift(2.5, "us") + + +@requires_cftime +@pytest.mark.parametrize("freq", ["YS", "YE", "QS", "QE", "MS", "ME"]) +def test_cftimeindex_shift_float_fails_for_non_tick_freqs(freq) -> None: + a = xr.cftime_range("2000", periods=3, freq="D") + with pytest.raises(TypeError, match="unsupported operand type"): + a.shift(2.5, freq) + + +@requires_cftime +def test_cftimeindex_radd(index): + date_type = index.date_type + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] + expected = CFTimeIndex(expected_dates) + result = timedelta(days=1) + index + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_timedeltaindex_add_cftimeindex(calendar) -> None: + a = xr.cftime_range("2000", periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = deltas + a + expected = a.shift(2, "D") + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +def test_cftimeindex_sub_timedelta(index): + date_type = index.date_type + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +@pytest.mark.parametrize( + "other", + [np.array(4 * [timedelta(days=1)]), np.array(timedelta(days=1))], + ids=["1d-array", "scalar-array"], +) +def test_cftimeindex_sub_timedelta_array(index, other): + date_type = index.date_type + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - other + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_cftimeindex(calendar) -> None: + a = xr.cftime_range("2000", periods=5, calendar=calendar) + b = a.shift(2, "D") + result = b - a + expected = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_cftime_datetime(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + result = a - a[0] + expected = pd.TimedeltaIndex([timedelta(days=i) for i in range(5)]) + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftime_datetime_sub_cftimeindex(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + result = a[0] - a + expected = pd.TimedeltaIndex([timedelta(days=-i) for i in range(5)]) + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_distant_cftime_datetime_sub_cftimeindex(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + with pytest.raises(ValueError, match="difference exceeds"): + a.date_type(1, 1, 1) - a + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_timedeltaindex(calendar) -> None: + a = xr.cftime_range("2000", periods=5, calendar=calendar) + deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) + result = a - deltas + expected = a.shift(-2, "D") + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_index_of_cftime_datetimes(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + b = pd.Index(a.values) + expected = a - a + result = a - b + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_not_implemented(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + with pytest.raises(TypeError, match="unsupported operand"): + a - 1 + + +@requires_cftime +def test_cftimeindex_rsub(index): + with pytest.raises(TypeError): + timedelta(days=1) - index + + +@requires_cftime +@pytest.mark.parametrize("freq", ["D", timedelta(days=1)]) +def test_cftimeindex_shift(index, freq) -> None: + date_type = index.date_type + expected_dates = [ + date_type(1, 1, 3), + date_type(1, 2, 3), + date_type(2, 1, 3), + date_type(2, 2, 3), + ] + expected = CFTimeIndex(expected_dates) + result = index.shift(2, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +def test_cftimeindex_shift_invalid_n() -> None: + index = xr.cftime_range("2000", periods=3) + with pytest.raises(TypeError): + index.shift("a", "D") + + +@requires_cftime +def test_cftimeindex_shift_invalid_freq() -> None: + index = xr.cftime_range("2000", periods=3) + with pytest.raises(TypeError): + index.shift(1, 1) + + +@requires_cftime +@pytest.mark.parametrize( + ("calendar", "expected"), + [ + ("noleap", "noleap"), + ("365_day", "noleap"), + ("360_day", "360_day"), + ("julian", "julian"), + ("gregorian", standard_or_gregorian), + ("standard", standard_or_gregorian), + ("proleptic_gregorian", "proleptic_gregorian"), + ], +) +def test_cftimeindex_calendar_property(calendar, expected): + index = xr.cftime_range(start="2000", periods=3, calendar=calendar) + assert index.calendar == expected + + +@requires_cftime +def test_empty_cftimeindex_calendar_property(): + index = CFTimeIndex([]) + assert index.calendar is None + + +@requires_cftime +@pytest.mark.parametrize( + "calendar", + [ + "noleap", + "365_day", + "360_day", + "julian", + "gregorian", + "standard", + "proleptic_gregorian", + ], +) +def test_cftimeindex_freq_property_none_size_lt_3(calendar): + for periods in range(3): + index = xr.cftime_range(start="2000", periods=periods, calendar=calendar) + assert index.freq is None + + +@requires_cftime +@pytest.mark.parametrize( + ("calendar", "expected"), + [ + ("noleap", "noleap"), + ("365_day", "noleap"), + ("360_day", "360_day"), + ("julian", "julian"), + ("gregorian", standard_or_gregorian), + ("standard", standard_or_gregorian), + ("proleptic_gregorian", "proleptic_gregorian"), + ], +) +def test_cftimeindex_calendar_repr(calendar, expected): + """Test that cftimeindex has calendar property in repr.""" + index = xr.cftime_range(start="2000", periods=3, calendar=calendar) + repr_str = index.__repr__() + assert f" calendar='{expected}'" in repr_str + assert "2000-01-01 00:00:00, 2000-01-02 00:00:00" in repr_str + + +@requires_cftime +@pytest.mark.parametrize("periods", [2, 40]) +def test_cftimeindex_periods_repr(periods): + """Test that cftimeindex has periods property in repr.""" + index = xr.cftime_range(start="2000", periods=periods) + repr_str = index.__repr__() + assert f" length={periods}" in repr_str + + +@requires_cftime +@pytest.mark.parametrize("calendar", ["noleap", "360_day", "standard"]) +@pytest.mark.parametrize("freq", ["D", "h"]) +def test_cftimeindex_freq_in_repr(freq, calendar): + """Test that cftimeindex has frequency property in repr.""" + index = xr.cftime_range(start="2000", periods=3, freq=freq, calendar=calendar) + repr_str = index.__repr__() + assert f", freq='{freq}'" in repr_str + + +@requires_cftime +@pytest.mark.parametrize( + "periods,expected", + [ + ( + 2, + f"""\ +CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00], + dtype='object', length=2, calendar='{standard_or_gregorian}', freq=None)""", + ), + ( + 4, + f"""\ +CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00, 2000-01-03 00:00:00, + 2000-01-04 00:00:00], + dtype='object', length=4, calendar='{standard_or_gregorian}', freq='D')""", + ), + ( + 101, + f"""\ +CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00, 2000-01-03 00:00:00, + 2000-01-04 00:00:00, 2000-01-05 00:00:00, 2000-01-06 00:00:00, + 2000-01-07 00:00:00, 2000-01-08 00:00:00, 2000-01-09 00:00:00, + 2000-01-10 00:00:00, + ... + 2000-04-01 00:00:00, 2000-04-02 00:00:00, 2000-04-03 00:00:00, + 2000-04-04 00:00:00, 2000-04-05 00:00:00, 2000-04-06 00:00:00, + 2000-04-07 00:00:00, 2000-04-08 00:00:00, 2000-04-09 00:00:00, + 2000-04-10 00:00:00], + dtype='object', length=101, calendar='{standard_or_gregorian}', freq='D')""", + ), + ], +) +def test_cftimeindex_repr_formatting(periods, expected): + """Test that cftimeindex.__repr__ is formatted similar to pd.Index.__repr__.""" + index = xr.cftime_range(start="2000", periods=periods, freq="D") + expected = dedent(expected) + assert expected == repr(index) + + +@requires_cftime +@pytest.mark.parametrize("display_width", [40, 80, 100]) +@pytest.mark.parametrize("periods", [2, 3, 4, 100, 101]) +def test_cftimeindex_repr_formatting_width(periods, display_width): + """Test that cftimeindex is sensitive to OPTIONS['display_width'].""" + index = xr.cftime_range(start="2000", periods=periods) + len_intro_str = len("CFTimeIndex(") + with xr.set_options(display_width=display_width): + repr_str = index.__repr__() + splitted = repr_str.split("\n") + for i, s in enumerate(splitted): + # check that lines not longer than OPTIONS['display_width'] + assert len(s) <= display_width, f"{len(s)} {s} {display_width}" + if i > 0: + # check for initial spaces + assert s[:len_intro_str] == " " * len_intro_str + + +@requires_cftime +@pytest.mark.parametrize("periods", [22, 50, 100]) +def test_cftimeindex_repr_101_shorter(periods): + index_101 = xr.cftime_range(start="2000", periods=101) + index_periods = xr.cftime_range(start="2000", periods=periods) + index_101_repr_str = index_101.__repr__() + index_periods_repr_str = index_periods.__repr__() + assert len(index_101_repr_str) < len(index_periods_repr_str) + + +@requires_cftime +def test_parse_array_of_cftime_strings(): + from cftime import DatetimeNoLeap + + strings = np.array([["2000-01-01", "2000-01-02"], ["2000-01-03", "2000-01-04"]]) + expected = np.array( + [ + [DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], + [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)], + ] + ) + + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + # Test scalar array case + strings = np.array("2000-01-01") + expected = np.array(DatetimeNoLeap(2000, 1, 1)) + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_strftime_of_cftime_array(calendar): + date_format = "%Y%m%d%H%M" + cf_values = xr.cftime_range("2000", periods=5, calendar=calendar) + dt_values = pd.date_range("2000", periods=5) + expected = pd.Index(dt_values.strftime(date_format)) + result = cf_values.strftime(date_format) + assert result.equals(expected) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +@pytest.mark.parametrize("unsafe", [False, True]) +def test_to_datetimeindex(calendar, unsafe): + index = xr.cftime_range("2000", periods=5, calendar=calendar) + expected = pd.date_range("2000", periods=5) + + if calendar in _NON_STANDARD_CALENDARS and not unsafe: + with pytest.warns(RuntimeWarning, match="non-standard"): + result = index.to_datetimeindex() + else: + result = index.to_datetimeindex(unsafe=unsafe) + + assert result.equals(expected) + np.testing.assert_array_equal(result, expected) + assert isinstance(result, pd.DatetimeIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_to_datetimeindex_out_of_range(calendar): + index = xr.cftime_range("0001", periods=5, calendar=calendar) + with pytest.raises(ValueError, match="0001"): + index.to_datetimeindex() + + +@requires_cftime +@pytest.mark.parametrize("calendar", ["all_leap", "360_day"]) +def test_to_datetimeindex_feb_29(calendar): + index = xr.cftime_range("2001-02-28", periods=2, calendar=calendar) + with pytest.raises(ValueError, match="29"): + index.to_datetimeindex() + + +@requires_cftime +def test_multiindex(): + index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") + mindex = pd.MultiIndex.from_arrays([index]) + assert mindex.get_loc("2001-01") == slice(0, 30) + + +@requires_cftime +@pytest.mark.parametrize("freq", ["3663s", "33min", "2h"]) +@pytest.mark.parametrize("method", ["floor", "ceil", "round"]) +def test_rounding_methods_against_datetimeindex(freq, method): + expected = pd.date_range("2000-01-02T01:03:51", periods=10, freq="1777s") + expected = getattr(expected, method)(freq) + result = xr.cftime_range("2000-01-02T01:03:51", periods=10, freq="1777s") + result = getattr(result, method)(freq).to_datetimeindex() + assert result.equals(expected) + + +@requires_cftime +@pytest.mark.parametrize("method", ["floor", "ceil", "round"]) +def test_rounding_methods_empty_cftimindex(method): + index = CFTimeIndex([]) + result = getattr(index, method)("2s") + + expected = CFTimeIndex([]) + + assert result.equals(expected) + assert result is not index + + +@requires_cftime +@pytest.mark.parametrize("method", ["floor", "ceil", "round"]) +def test_rounding_methods_invalid_freq(method): + index = xr.cftime_range("2000-01-02T01:03:51", periods=10, freq="1777s") + with pytest.raises(ValueError, match="fixed"): + getattr(index, method)("MS") + + +@pytest.fixture +def rounding_index(date_type): + return xr.CFTimeIndex( + [ + date_type(1, 1, 1, 1, 59, 59, 999512), + date_type(1, 1, 1, 3, 0, 1, 500001), + date_type(1, 1, 1, 7, 0, 6, 499999), + ] + ) + + +@requires_cftime +def test_ceil(rounding_index, date_type): + result = rounding_index.ceil("s") + expected = xr.CFTimeIndex( + [ + date_type(1, 1, 1, 2, 0, 0, 0), + date_type(1, 1, 1, 3, 0, 2, 0), + date_type(1, 1, 1, 7, 0, 7, 0), + ] + ) + assert result.equals(expected) + + +@requires_cftime +def test_floor(rounding_index, date_type): + result = rounding_index.floor("s") + expected = xr.CFTimeIndex( + [ + date_type(1, 1, 1, 1, 59, 59, 0), + date_type(1, 1, 1, 3, 0, 1, 0), + date_type(1, 1, 1, 7, 0, 6, 0), + ] + ) + assert result.equals(expected) + + +@requires_cftime +def test_round(rounding_index, date_type): + result = rounding_index.round("s") + expected = xr.CFTimeIndex( + [ + date_type(1, 1, 1, 2, 0, 0, 0), + date_type(1, 1, 1, 3, 0, 2, 0), + date_type(1, 1, 1, 7, 0, 6, 0), + ] + ) + assert result.equals(expected) + + +@requires_cftime +def test_asi8(date_type): + index = xr.CFTimeIndex([date_type(1970, 1, 1), date_type(1970, 1, 2)]) + result = index.asi8 + expected = 1000000 * 86400 * np.array([0, 1]) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +def test_asi8_distant_date(): + """Test that asi8 conversion is truly exact.""" + import cftime + + date_type = cftime.DatetimeProlepticGregorian + index = xr.CFTimeIndex([date_type(10731, 4, 22, 3, 25, 45, 123456)]) + result = index.asi8 + expected = np.array([1000000 * 86400 * 400 * 8000 + 12345 * 1000000 + 123456]) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +def test_asi8_empty_cftimeindex(): + index = xr.CFTimeIndex([]) + result = index.asi8 + expected = np.array([], dtype=np.int64) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +def test_infer_freq_valid_types(): + cf_indx = xr.cftime_range("2000-01-01", periods=3, freq="D") + assert xr.infer_freq(cf_indx) == "D" + assert xr.infer_freq(xr.DataArray(cf_indx)) == "D" + + pd_indx = pd.date_range("2000-01-01", periods=3, freq="D") + assert xr.infer_freq(pd_indx) == "D" + assert xr.infer_freq(xr.DataArray(pd_indx)) == "D" + + pd_td_indx = pd.timedelta_range(start="1D", periods=3, freq="D") + assert xr.infer_freq(pd_td_indx) == "D" + assert xr.infer_freq(xr.DataArray(pd_td_indx)) == "D" + + +@requires_cftime +def test_infer_freq_invalid_inputs(): + # Non-datetime DataArray + with pytest.raises(ValueError, match="must contain datetime-like objects"): + xr.infer_freq(xr.DataArray([0, 1, 2])) + + indx = xr.cftime_range("1990-02-03", periods=4, freq="MS") + # 2D DataArray + with pytest.raises(ValueError, match="must be 1D"): + xr.infer_freq(xr.DataArray([indx, indx])) + + # CFTimeIndex too short + with pytest.raises(ValueError, match="Need at least 3 dates to infer frequency"): + xr.infer_freq(indx[:2]) + + # Non-monotonic input + assert xr.infer_freq(indx[np.array([0, 2, 1, 3])]) is None + + # Non-unique input + assert xr.infer_freq(indx[np.array([0, 1, 1, 2])]) is None + + # No unique frequency (here 1st step is MS, second is 2MS) + assert xr.infer_freq(indx[np.array([0, 1, 3])]) is None + + # Same, but for QS + indx = xr.cftime_range("1990-02-03", periods=4, freq="QS") + assert xr.infer_freq(indx[np.array([0, 1, 3])]) is None + + +@requires_cftime +@pytest.mark.parametrize( + "freq", + [ + "300YS-JAN", + "YE-DEC", + "YS-JUL", + "2YS-FEB", + "QE-NOV", + "3QS-DEC", + "MS", + "4ME", + "7D", + "D", + "30h", + "5min", + "40s", + ], +) +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_infer_freq(freq, calendar): + indx = xr.cftime_range("2000-01-01", periods=3, freq=freq, calendar=calendar) + out = xr.infer_freq(indx) + assert out == freq + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_pickle_cftimeindex(calendar): + idx = xr.cftime_range("2000-01-01", periods=3, freq="D", calendar=calendar) + idx_pkl = pickle.loads(pickle.dumps(idx)) + assert (idx == idx_pkl).all() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_cftimeindex_resample.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_cftimeindex_resample.py new file mode 100644 index 0000000..98d4377 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_cftimeindex_resample.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +import datetime +from typing import TypedDict + +import numpy as np +import pandas as pd +import pytest +from packaging.version import Version + +import xarray as xr +from xarray.coding.cftime_offsets import _new_to_legacy_freq +from xarray.core.pdcompat import _convert_base_to_offset +from xarray.core.resample_cftime import CFTimeGrouper + +cftime = pytest.importorskip("cftime") + + +# Create a list of pairs of similar-length initial and resample frequencies +# that cover: +# - Resampling from shorter to longer frequencies +# - Resampling from longer to shorter frequencies +# - Resampling from one initial frequency to another. +# These are used to test the cftime version of resample against pandas +# with a standard calendar. +FREQS = [ + ("8003D", "4001D"), + ("8003D", "16006D"), + ("8003D", "21YS"), + ("6h", "3h"), + ("6h", "12h"), + ("6h", "400min"), + ("3D", "D"), + ("3D", "6D"), + ("11D", "MS"), + ("3MS", "MS"), + ("3MS", "6MS"), + ("3MS", "85D"), + ("7ME", "3ME"), + ("7ME", "14ME"), + ("7ME", "2QS-APR"), + ("43QS-AUG", "21QS-AUG"), + ("43QS-AUG", "86QS-AUG"), + ("43QS-AUG", "11YE-JUN"), + ("11QE-JUN", "5QE-JUN"), + ("11QE-JUN", "22QE-JUN"), + ("11QE-JUN", "51MS"), + ("3YS-MAR", "YS-MAR"), + ("3YS-MAR", "6YS-MAR"), + ("3YS-MAR", "14QE-FEB"), + ("7YE-MAY", "3YE-MAY"), + ("7YE-MAY", "14YE-MAY"), + ("7YE-MAY", "85ME"), +] + + +def compare_against_pandas( + da_datetimeindex, + da_cftimeindex, + freq, + closed=None, + label=None, + base=None, + offset=None, + origin=None, + loffset=None, +) -> None: + if isinstance(origin, tuple): + origin_pandas = pd.Timestamp(datetime.datetime(*origin)) + origin_cftime = cftime.DatetimeGregorian(*origin) + else: + origin_pandas = origin + origin_cftime = origin + + try: + result_datetimeindex = da_datetimeindex.resample( + time=freq, + closed=closed, + label=label, + base=base, + loffset=loffset, + offset=offset, + origin=origin_pandas, + ).mean() + except ValueError: + with pytest.raises(ValueError): + da_cftimeindex.resample( + time=freq, + closed=closed, + label=label, + base=base, + loffset=loffset, + origin=origin_cftime, + offset=offset, + ).mean() + else: + result_cftimeindex = da_cftimeindex.resample( + time=freq, + closed=closed, + label=label, + base=base, + loffset=loffset, + origin=origin_cftime, + offset=offset, + ).mean() + # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass + result_cftimeindex["time"] = ( + result_cftimeindex.xindexes["time"].to_pandas_index().to_datetimeindex() + ) + xr.testing.assert_identical(result_cftimeindex, result_datetimeindex) + + +def da(index) -> xr.DataArray: + return xr.DataArray( + np.arange(100.0, 100.0 + index.size), coords=[index], dims=["time"] + ) + + +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") +@pytest.mark.parametrize("freqs", FREQS, ids=lambda x: "{}->{}".format(*x)) +@pytest.mark.parametrize("closed", [None, "left", "right"]) +@pytest.mark.parametrize("label", [None, "left", "right"]) +@pytest.mark.parametrize( + ("base", "offset"), [(24, None), (31, None), (None, "5s")], ids=lambda x: f"{x}" +) +def test_resample(freqs, closed, label, base, offset) -> None: + initial_freq, resample_freq = freqs + if ( + resample_freq == "4001D" + and closed == "right" + and Version(pd.__version__) < Version("2.2") + ): + pytest.skip( + "Pandas fixed a bug in this test case in version 2.2, which we " + "ported to xarray, so this test no longer produces the same " + "result as pandas for earlier pandas versions." + ) + start = "2000-01-01T12:07:01" + loffset = "12h" + origin = "start" + + datetime_index = pd.date_range( + start=start, periods=5, freq=_new_to_legacy_freq(initial_freq) + ) + cftime_index = xr.cftime_range(start=start, periods=5, freq=initial_freq) + da_datetimeindex = da(datetime_index) + da_cftimeindex = da(cftime_index) + + with pytest.warns(FutureWarning, match="`loffset` parameter"): + compare_against_pandas( + da_datetimeindex, + da_cftimeindex, + resample_freq, + closed=closed, + label=label, + base=base, + offset=offset, + origin=origin, + loffset=loffset, + ) + + +@pytest.mark.parametrize( + ("freq", "expected"), + [ + ("s", "left"), + ("min", "left"), + ("h", "left"), + ("D", "left"), + ("ME", "right"), + ("MS", "left"), + ("QE", "right"), + ("QS", "left"), + ("YE", "right"), + ("YS", "left"), + ], +) +def test_closed_label_defaults(freq, expected) -> None: + assert CFTimeGrouper(freq=freq).closed == expected + assert CFTimeGrouper(freq=freq).label == expected + + +@pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex") +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") +@pytest.mark.parametrize( + "calendar", ["gregorian", "noleap", "all_leap", "360_day", "julian"] +) +def test_calendars(calendar: str) -> None: + # Limited testing for non-standard calendars + freq, closed, label, base = "8001min", None, None, 17 + loffset = datetime.timedelta(hours=12) + xr_index = xr.cftime_range( + start="2004-01-01T12:07:01", periods=7, freq="3D", calendar=calendar + ) + pd_index = pd.date_range(start="2004-01-01T12:07:01", periods=7, freq="3D") + da_cftime = ( + da(xr_index) + .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) + .mean() + ) + da_datetime = ( + da(pd_index) + .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) + .mean() + ) + # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass + da_cftime["time"] = da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() + xr.testing.assert_identical(da_cftime, da_datetime) + + +class DateRangeKwargs(TypedDict): + start: str + periods: int + freq: str + + +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") +@pytest.mark.parametrize("closed", ["left", "right"]) +@pytest.mark.parametrize( + "origin", + ["start_day", "start", "end", "end_day", "epoch", (1970, 1, 1, 3, 2)], + ids=lambda x: f"{x}", +) +def test_origin(closed, origin) -> None: + initial_freq, resample_freq = ("3h", "9h") + start = "1969-12-31T12:07:01" + index_kwargs: DateRangeKwargs = dict(start=start, periods=12, freq=initial_freq) + datetime_index = pd.date_range(**index_kwargs) + cftime_index = xr.cftime_range(**index_kwargs) + da_datetimeindex = da(datetime_index) + da_cftimeindex = da(cftime_index) + + compare_against_pandas( + da_datetimeindex, + da_cftimeindex, + resample_freq, + closed=closed, + origin=origin, + ) + + +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") +def test_base_and_offset_error(): + cftime_index = xr.cftime_range("2000", periods=5) + da_cftime = da(cftime_index) + with pytest.raises(ValueError, match="base and offset cannot"): + da_cftime.resample(time="2D", base=3, offset="5s") + + +@pytest.mark.parametrize("offset", ["foo", "5MS", 10]) +def test_invalid_offset_error(offset) -> None: + cftime_index = xr.cftime_range("2000", periods=5) + da_cftime = da(cftime_index) + with pytest.raises(ValueError, match="offset must be"): + da_cftime.resample(time="2D", offset=offset) + + +def test_timedelta_offset() -> None: + timedelta = datetime.timedelta(seconds=5) + string = "5s" + + cftime_index = xr.cftime_range("2000", periods=5) + da_cftime = da(cftime_index) + + timedelta_result = da_cftime.resample(time="2D", offset=timedelta).mean() + string_result = da_cftime.resample(time="2D", offset=string).mean() + xr.testing.assert_identical(timedelta_result, string_result) + + +@pytest.mark.parametrize("loffset", ["MS", "12h", datetime.timedelta(hours=-12)]) +def test_resample_loffset_cftimeindex(loffset) -> None: + datetimeindex = pd.date_range("2000-01-01", freq="6h", periods=10) + da_datetimeindex = xr.DataArray(np.arange(10), [("time", datetimeindex)]) + + cftimeindex = xr.cftime_range("2000-01-01", freq="6h", periods=10) + da_cftimeindex = xr.DataArray(np.arange(10), [("time", cftimeindex)]) + + with pytest.warns(FutureWarning, match="`loffset` parameter"): + result = da_cftimeindex.resample(time="24h", loffset=loffset).mean() + expected = da_datetimeindex.resample(time="24h", loffset=loffset).mean() + + result["time"] = result.xindexes["time"].to_pandas_index().to_datetimeindex() + xr.testing.assert_identical(result, expected) + + +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") +def test_resample_invalid_loffset_cftimeindex() -> None: + times = xr.cftime_range("2000-01-01", freq="6h", periods=10) + da = xr.DataArray(np.arange(10), [("time", times)]) + + with pytest.raises(ValueError): + da.resample(time="24h", loffset=1) # type: ignore + + +@pytest.mark.parametrize(("base", "freq"), [(1, "10s"), (17, "3h"), (15, "5us")]) +def test__convert_base_to_offset(base, freq): + # Verify that the cftime_offset adapted version of _convert_base_to_offset + # produces the same result as the pandas version. + datetimeindex = pd.date_range("2000", periods=2) + cftimeindex = xr.cftime_range("2000", periods=2) + pandas_result = _convert_base_to_offset(base, freq, datetimeindex) + cftime_result = _convert_base_to_offset(base, freq, cftimeindex) + assert pandas_result.to_pytimedelta() == cftime_result + + +def test__convert_base_to_offset_invalid_index(): + with pytest.raises(ValueError, match="Can only resample"): + _convert_base_to_offset(1, "12h", pd.Index([0])) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_coarsen.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_coarsen.py new file mode 100644 index 0000000..01d5393 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_coarsen.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Dataset, set_options +from xarray.core import duck_array_ops +from xarray.tests import ( + assert_allclose, + assert_equal, + assert_identical, + has_dask, + raise_if_dask_computes, + requires_cftime, +) + + +def test_coarsen_absent_dims_error(ds: Dataset) -> None: + with pytest.raises( + ValueError, + match=r"Window dimensions \('foo',\) not found in Dataset dimensions", + ): + ds.coarsen(foo=2) + + +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) +def test_coarsen_dataset(ds, dask, boundary, side): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() + assert_equal( + actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() + ) + # coordinate should be mean by default + assert_equal( + actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() + ) + + +@pytest.mark.parametrize("dask", [True, False]) +def test_coarsen_coords(ds, dask): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + # check if coord_func works + actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() + assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) + assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) + + # raise if exact + with pytest.raises(ValueError): + ds.coarsen(x=3).mean() + # should be no error + ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() + + # working test with pd.time + da = xr.DataArray( + np.linspace(0, 365, num=364), + dims="time", + coords={"time": pd.date_range("1999-12-15", periods=364)}, + ) + actual = da.coarsen(time=2).mean() + + +@requires_cftime +def test_coarsen_coords_cftime(): + times = xr.cftime_range("2000", periods=6) + da = xr.DataArray(range(6), [("time", times)]) + actual = da.coarsen(time=3).mean() + expected_times = xr.cftime_range("2000-01-02", freq="3D", periods=2) + np.testing.assert_array_equal(actual.time, expected_times) + + +@pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ], +) +def test_coarsen_keep_attrs(funcname, argument) -> None: + global_attrs = {"units": "test", "long_name": "testing"} + da_attrs = {"da_attr": "test"} + attrs_coords = {"attrs_coords": "test"} + da_not_coarsend_attrs = {"da_not_coarsend_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + ds = Dataset( + data_vars={ + "da": ("coord", data, da_attrs), + "da_not_coarsend": ("no_coord", data, da_not_coarsend_attrs), + }, + coords={"coord": ("coord", coords, attrs_coords)}, + attrs=global_attrs, + ) + + # attrs are now kept per default + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_coarsend.attrs == da_not_coarsend_attrs + assert result.coord.attrs == attrs_coords + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + # discard attrs + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_coarsend.attrs == {} + assert result.coord.attrs == {} + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + # test discard attrs using global option + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_coarsend.attrs == {} + assert result.coord.attrs == {} + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + # keyword takes precedence over global option + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_coarsend.attrs == da_not_coarsend_attrs + assert result.coord.attrs == attrs_coords + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_coarsend.attrs == {} + assert result.coord.attrs == {} + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + +@pytest.mark.slow +@pytest.mark.parametrize("ds", (1, 2), indirect=True) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) +def test_coarsen_reduce(ds: Dataset, window, name) -> None: + # Use boundary="trim" to accommodate all window sizes used in tests + coarsen_obj = ds.coarsen(time=window, boundary="trim") + + # add nan prefix to numpy methods to get similar behavior as bottleneck + actual = coarsen_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(coarsen_obj, name)() + assert_allclose(actual, expected) + + # make sure the order of data_var are not changed. + assert list(ds.data_vars.keys()) == list(actual.data_vars.keys()) + + # Make sure the dimension order is restored + for key, src_var in ds.data_vars.items(): + assert src_var.dims == actual[key].dims + + +@pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ], +) +def test_coarsen_da_keep_attrs(funcname, argument) -> None: + attrs_da = {"da_attr": "test"} + attrs_coords = {"attrs_coords": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + da = DataArray( + data, + dims=("coord"), + coords={"coord": ("coord", coords, attrs_coords)}, + attrs=attrs_da, + name="name", + ) + + # attrs are now kept per default + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == attrs_da + da.coord.attrs == attrs_coords + assert result.name == "name" + + # discard attrs + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + da.coord.attrs == {} + assert result.name == "name" + + # test discard attrs using global option + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + assert result.attrs == {} + da.coord.attrs == {} + assert result.name == "name" + + # keyword takes precedence over global option + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + assert result.attrs == attrs_da + da.coord.attrs == {} + assert result.name == "name" + + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + da.coord.attrs == {} + assert result.name == "name" + + +@pytest.mark.parametrize("da", (1, 2), indirect=True) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) +def test_coarsen_da_reduce(da, window, name) -> None: + if da.isnull().sum() > 1 and window == 1: + pytest.skip("These parameters lead to all-NaN slices") + + # Use boundary="trim" to accommodate all window sizes used in tests + coarsen_obj = da.coarsen(time=window, boundary="trim") + + # add nan prefix to numpy methods to get similar # behavior as bottleneck + actual = coarsen_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(coarsen_obj, name)() + assert_allclose(actual, expected) + + +class TestCoarsenConstruct: + @pytest.mark.parametrize("dask", [True, False]) + def test_coarsen_construct(self, dask: bool) -> None: + ds = Dataset( + { + "vart": ("time", np.arange(48), {"a": "b"}), + "varx": ("x", np.arange(10), {"a": "b"}), + "vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}), + "vary": ("y", np.arange(12)), + }, + coords={"time": np.arange(48), "y": np.arange(12)}, + attrs={"foo": "bar"}, + ) + + if dask and has_dask: + ds = ds.chunk({"x": 4, "time": 10}) + + expected = xr.Dataset(attrs={"foo": "bar"}) + expected["vart"] = ( + ("year", "month"), + duck_array_ops.reshape(ds.vart.data, (-1, 12)), + {"a": "b"}, + ) + expected["varx"] = ( + ("x", "x_reshaped"), + duck_array_ops.reshape(ds.varx.data, (-1, 5)), + {"a": "b"}, + ) + expected["vartx"] = ( + ("x", "x_reshaped", "year", "month"), + duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)), + {"a": "b"}, + ) + expected["vary"] = ds.vary + expected.coords["time"] = ( + ("year", "month"), + duck_array_ops.reshape(ds.time.data, (-1, 12)), + ) + + with raise_if_dask_computes(): + actual = ds.coarsen(time=12, x=5).construct( + {"time": ("year", "month"), "x": ("x", "x_reshaped")} + ) + assert_identical(actual, expected) + + with raise_if_dask_computes(): + actual = ds.coarsen(time=12, x=5).construct( + time=("year", "month"), x=("x", "x_reshaped") + ) + assert_identical(actual, expected) + + with raise_if_dask_computes(): + actual = ds.coarsen(time=12, x=5).construct( + {"time": ("year", "month"), "x": ("x", "x_reshaped")}, keep_attrs=False + ) + for var in actual: + assert actual[var].attrs == {} + assert actual.attrs == {} + + with raise_if_dask_computes(): + actual = ds.vartx.coarsen(time=12, x=5).construct( + {"time": ("year", "month"), "x": ("x", "x_reshaped")} + ) + assert_identical(actual, expected["vartx"]) + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct(foo="bar") + + with pytest.raises(ValueError): + ds.coarsen(time=12, x=2).construct(time=("year", "month")) + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct() + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct(time="bar") + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct(time=("bar",)) + + def test_coarsen_construct_keeps_all_coords(self): + da = xr.DataArray(np.arange(24), dims=["time"]) + da = da.assign_coords(day=365 * da) + + result = da.coarsen(time=12).construct(time=("year", "month")) + assert list(da.coords) == list(result.coords) + + ds = da.to_dataset(name="T") + result = ds.coarsen(time=12).construct(time=("year", "month")) + assert list(da.coords) == list(result.coords) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_coding.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_coding.py new file mode 100644 index 0000000..6d81d6f --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_coding.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from contextlib import suppress + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.coding import variables +from xarray.conventions import decode_cf_variable, encode_cf_variable +from xarray.tests import assert_allclose, assert_equal, assert_identical, requires_dask + +with suppress(ImportError): + import dask.array as da + + +def test_CFMaskCoder_decode() -> None: + original = xr.Variable(("x",), [0, -1, 1], {"_FillValue": -1}) + expected = xr.Variable(("x",), [0, np.nan, 1]) + coder = variables.CFMaskCoder() + encoded = coder.decode(original) + assert_identical(expected, encoded) + + +encoding_with_dtype = { + "dtype": np.dtype("float64"), + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +encoding_without_dtype = { + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS = { + "numeric-with-dtype": ([0.0, -1.0, 1.0], encoding_with_dtype), + "numeric-without-dtype": ([0.0, -1.0, 1.0], encoding_without_dtype), + "times-with-dtype": (pd.date_range("2000", periods=3), encoding_with_dtype), +} + + +@pytest.mark.parametrize( + ("data", "encoding"), + CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(), + ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()), +) +def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding) -> None: + original = xr.Variable(("x",), data, encoding=encoding) + encoded = encode_cf_variable(original) + + assert encoded.dtype == encoded.attrs["missing_value"].dtype + assert encoded.dtype == encoded.attrs["_FillValue"].dtype + + roundtripped = decode_cf_variable("foo", encoded) + assert_identical(roundtripped, original) + + +def test_CFMaskCoder_missing_value() -> None: + expected = xr.DataArray( + np.array([[26915, 27755, -9999, 27705], [25595, -9999, 28315, -9999]]), + dims=["npts", "ntimes"], + name="tmpk", + ) + expected.attrs["missing_value"] = -9999 + + decoded = xr.decode_cf(expected.to_dataset()) + encoded, _ = xr.conventions.cf_encoder(decoded.variables, decoded.attrs) + + assert_equal(encoded["tmpk"], expected.variable) + + decoded.tmpk.encoding["_FillValue"] = -9940 + with pytest.raises(ValueError): + encoded, _ = xr.conventions.cf_encoder(decoded.variables, decoded.attrs) + + +@requires_dask +def test_CFMaskCoder_decode_dask() -> None: + original = xr.Variable(("x",), [0, -1, 1], {"_FillValue": -1}).chunk() + expected = xr.Variable(("x",), [0, np.nan, 1]) + coder = variables.CFMaskCoder() + encoded = coder.decode(original) + assert isinstance(encoded.data, da.Array) + assert_identical(expected, encoded) + + +# TODO(shoyer): port other fill-value tests + + +# TODO(shoyer): parameterize when we have more coders +def test_coder_roundtrip() -> None: + original = xr.Variable(("x",), [0.0, np.nan, 1.0]) + coder = variables.CFMaskCoder() + roundtripped = coder.decode(coder.encode(original)) + assert_identical(original, roundtripped) + + +@pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split()) +@pytest.mark.parametrize("dtype2", "f4 f8".split()) +def test_scaling_converts_to_float(dtype: str, dtype2: str) -> None: + dt = np.dtype(dtype2) + original = xr.Variable( + ("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=dt.type(10)) + ) + coder = variables.CFScaleOffsetCoder() + encoded = coder.encode(original) + assert encoded.dtype == dt + roundtripped = coder.decode(encoded) + assert_identical(original, roundtripped) + assert roundtripped.dtype == dt + + +@pytest.mark.parametrize("scale_factor", (10, [10])) +@pytest.mark.parametrize("add_offset", (0.1, [0.1])) +def test_scaling_offset_as_list(scale_factor, add_offset) -> None: + # test for #4631 + encoding = dict(scale_factor=scale_factor, add_offset=add_offset) + original = xr.Variable(("x",), np.arange(10.0), encoding=encoding) + coder = variables.CFScaleOffsetCoder() + encoded = coder.encode(original) + roundtripped = coder.decode(encoded) + assert_allclose(original, roundtripped) + + +@pytest.mark.parametrize("bits", [1, 2, 4, 8]) +def test_decode_unsigned_from_signed(bits) -> None: + unsigned_dtype = np.dtype(f"u{bits}") + signed_dtype = np.dtype(f"i{bits}") + original_values = np.array([np.iinfo(unsigned_dtype).max], dtype=unsigned_dtype) + encoded = xr.Variable( + ("x",), original_values.astype(signed_dtype), attrs={"_Unsigned": "true"} + ) + coder = variables.UnsignedIntegerCoder() + decoded = coder.decode(encoded) + assert decoded.dtype == unsigned_dtype + assert decoded.values == original_values + + +@pytest.mark.parametrize("bits", [1, 2, 4, 8]) +def test_decode_signed_from_unsigned(bits) -> None: + unsigned_dtype = np.dtype(f"u{bits}") + signed_dtype = np.dtype(f"i{bits}") + original_values = np.array([-1], dtype=signed_dtype) + encoded = xr.Variable( + ("x",), original_values.astype(unsigned_dtype), attrs={"_Unsigned": "false"} + ) + coder = variables.UnsignedIntegerCoder() + decoded = coder.decode(encoded) + assert decoded.dtype == signed_dtype + assert decoded.values == original_values diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_coding_strings.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_coding_strings.py new file mode 100644 index 0000000..51f63ea --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_coding_strings.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +from contextlib import suppress + +import numpy as np +import pytest + +from xarray import Variable +from xarray.coding import strings +from xarray.core import indexing +from xarray.tests import ( + IndexerMaker, + assert_array_equal, + assert_identical, + requires_dask, +) + +with suppress(ImportError): + import dask.array as da + + +def test_vlen_dtype() -> None: + dtype = strings.create_vlen_dtype(str) + assert dtype.metadata["element_type"] == str + assert strings.is_unicode_dtype(dtype) + assert not strings.is_bytes_dtype(dtype) + assert strings.check_vlen_dtype(dtype) is str + + dtype = strings.create_vlen_dtype(bytes) + assert dtype.metadata["element_type"] == bytes + assert not strings.is_unicode_dtype(dtype) + assert strings.is_bytes_dtype(dtype) + assert strings.check_vlen_dtype(dtype) is bytes + + # check h5py variant ("vlen") + dtype = np.dtype("O", metadata={"vlen": str}) # type: ignore[call-overload,unused-ignore] + assert strings.check_vlen_dtype(dtype) is str + + assert strings.check_vlen_dtype(np.dtype(object)) is None + + +@pytest.mark.parametrize("numpy_str_type", (np.str_, np.bytes_)) +def test_numpy_subclass_handling(numpy_str_type) -> None: + with pytest.raises(TypeError, match="unsupported type for vlen_dtype"): + strings.create_vlen_dtype(numpy_str_type) + + +def test_EncodedStringCoder_decode() -> None: + coder = strings.EncodedStringCoder() + + raw_data = np.array([b"abc", "ß∂µ∆".encode()]) + raw = Variable(("x",), raw_data, {"_Encoding": "utf-8"}) + actual = coder.decode(raw) + + expected = Variable(("x",), np.array(["abc", "ß∂µ∆"], dtype=object)) + assert_identical(actual, expected) + + assert_identical(coder.decode(actual[0]), expected[0]) + + +@requires_dask +def test_EncodedStringCoder_decode_dask() -> None: + coder = strings.EncodedStringCoder() + + raw_data = np.array([b"abc", "ß∂µ∆".encode()]) + raw = Variable(("x",), raw_data, {"_Encoding": "utf-8"}).chunk() + actual = coder.decode(raw) + assert isinstance(actual.data, da.Array) + + expected = Variable(("x",), np.array(["abc", "ß∂µ∆"], dtype=object)) + assert_identical(actual, expected) + + actual_indexed = coder.decode(actual[0]) + assert isinstance(actual_indexed.data, da.Array) + assert_identical(actual_indexed, expected[0]) + + +def test_EncodedStringCoder_encode() -> None: + dtype = strings.create_vlen_dtype(str) + raw_data = np.array(["abc", "ß∂µ∆"], dtype=dtype) + expected_data = np.array([r.encode("utf-8") for r in raw_data], dtype=object) + + coder = strings.EncodedStringCoder(allows_unicode=True) + raw = Variable(("x",), raw_data, encoding={"dtype": "S1"}) + actual = coder.encode(raw) + expected = Variable(("x",), expected_data, attrs={"_Encoding": "utf-8"}) + assert_identical(actual, expected) + + raw = Variable(("x",), raw_data) + assert_identical(coder.encode(raw), raw) + + coder = strings.EncodedStringCoder(allows_unicode=False) + assert_identical(coder.encode(raw), expected) + + +@pytest.mark.parametrize( + "original", + [ + Variable(("x",), [b"ab", b"cdef"]), + Variable((), b"ab"), + Variable(("x",), [b"a", b"b"]), + Variable((), b"a"), + ], +) +def test_CharacterArrayCoder_roundtrip(original) -> None: + coder = strings.CharacterArrayCoder() + roundtripped = coder.decode(coder.encode(original)) + assert_identical(original, roundtripped) + + +@pytest.mark.parametrize( + "data", + [ + np.array([b"a", b"bc"]), + np.array([b"a", b"bc"], dtype=strings.create_vlen_dtype(bytes)), + ], +) +def test_CharacterArrayCoder_encode(data) -> None: + coder = strings.CharacterArrayCoder() + raw = Variable(("x",), data) + actual = coder.encode(raw) + expected = Variable(("x", "string2"), np.array([[b"a", b""], [b"b", b"c"]])) + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["original", "expected_char_dim_name"], + [ + (Variable(("x",), [b"ab", b"cdef"]), "string4"), + (Variable(("x",), [b"ab", b"cdef"], encoding={"char_dim_name": "foo"}), "foo"), + ], +) +def test_CharacterArrayCoder_char_dim_name(original, expected_char_dim_name) -> None: + coder = strings.CharacterArrayCoder() + encoded = coder.encode(original) + roundtripped = coder.decode(encoded) + assert encoded.dims[-1] == expected_char_dim_name + assert roundtripped.encoding["char_dim_name"] == expected_char_dim_name + assert roundtripped.dims[-1] == original.dims[-1] + + +def test_StackedBytesArray() -> None: + array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") + actual = strings.StackedBytesArray(array) + expected = np.array([b"abc", b"def"], dtype="S") + assert actual.dtype == expected.dtype + assert actual.shape == expected.shape + assert actual.size == expected.size + assert actual.ndim == expected.ndim + assert len(actual) == len(expected) + assert_array_equal(expected, actual) + + B = IndexerMaker(indexing.BasicIndexer) + assert_array_equal(expected[:1], actual[B[:1]]) + with pytest.raises(IndexError): + actual[B[:, :2]] + + +def test_StackedBytesArray_scalar() -> None: + array = np.array([b"a", b"b", b"c"], dtype="S") + actual = strings.StackedBytesArray(array) + + expected = np.array(b"abc") + assert actual.dtype == expected.dtype + assert actual.shape == expected.shape + assert actual.size == expected.size + assert actual.ndim == expected.ndim + with pytest.raises(TypeError): + len(actual) + np.testing.assert_array_equal(expected, actual) + + B = IndexerMaker(indexing.BasicIndexer) + with pytest.raises(IndexError): + actual[B[:2]] + + +def test_StackedBytesArray_vectorized_indexing() -> None: + array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") + stacked = strings.StackedBytesArray(array) + expected = np.array([[b"abc", b"def"], [b"def", b"abc"]]) + + V = IndexerMaker(indexing.VectorizedIndexer) + indexer = V[np.array([[0, 1], [1, 0]])] + actual = stacked.vindex[indexer] + assert_array_equal(actual, expected) + + +def test_char_to_bytes() -> None: + array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]]) + expected = np.array([b"abc", b"def"]) + actual = strings.char_to_bytes(array) + assert_array_equal(actual, expected) + + expected = np.array([b"ad", b"be", b"cf"]) + actual = strings.char_to_bytes(array.T) # non-contiguous + assert_array_equal(actual, expected) + + +def test_char_to_bytes_ndim_zero() -> None: + expected = np.array(b"a") + actual = strings.char_to_bytes(expected) + assert_array_equal(actual, expected) + + +def test_char_to_bytes_size_zero() -> None: + array = np.zeros((3, 0), dtype="S1") + expected = np.array([b"", b"", b""]) + actual = strings.char_to_bytes(array) + assert_array_equal(actual, expected) + + +@requires_dask +def test_char_to_bytes_dask() -> None: + numpy_array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]]) + array = da.from_array(numpy_array, ((2,), (3,))) + expected = np.array([b"abc", b"def"]) + actual = strings.char_to_bytes(array) + assert isinstance(actual, da.Array) + assert actual.chunks == ((2,),) + assert actual.dtype == "S3" + assert_array_equal(np.array(actual), expected) + + with pytest.raises(ValueError, match=r"stacked dask character array"): + strings.char_to_bytes(array.rechunk(1)) + + +def test_bytes_to_char() -> None: + array = np.array([[b"ab", b"cd"], [b"ef", b"gh"]]) + expected = np.array([[[b"a", b"b"], [b"c", b"d"]], [[b"e", b"f"], [b"g", b"h"]]]) + actual = strings.bytes_to_char(array) + assert_array_equal(actual, expected) + + expected = np.array([[[b"a", b"b"], [b"e", b"f"]], [[b"c", b"d"], [b"g", b"h"]]]) + actual = strings.bytes_to_char(array.T) # non-contiguous + assert_array_equal(actual, expected) + + +@requires_dask +def test_bytes_to_char_dask() -> None: + numpy_array = np.array([b"ab", b"cd"]) + array = da.from_array(numpy_array, ((1, 1),)) + expected = np.array([[b"a", b"b"], [b"c", b"d"]]) + actual = strings.bytes_to_char(array) + assert isinstance(actual, da.Array) + assert actual.chunks == ((1, 1), ((2,))) + assert actual.dtype == "S1" + assert_array_equal(np.array(actual), expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_coding_times.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_coding_times.py new file mode 100644 index 0000000..09221d6 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_coding_times.py @@ -0,0 +1,1603 @@ +from __future__ import annotations + +import warnings +from datetime import timedelta +from itertools import product + +import numpy as np +import pandas as pd +import pytest +from pandas.errors import OutOfBoundsDatetime + +from xarray import ( + DataArray, + Dataset, + Variable, + cftime_range, + coding, + conventions, + date_range, + decode_cf, +) +from xarray.coding.times import ( + _encode_datetime_with_cftime, + _numpy_to_netcdf_timeunit, + _should_cftime_be_used, + cftime_to_nptime, + decode_cf_datetime, + decode_cf_timedelta, + encode_cf_datetime, + encode_cf_timedelta, + to_timedelta_unboxed, +) +from xarray.coding.variables import SerializationWarning +from xarray.conventions import _update_bounds_attributes, cf_encoder +from xarray.core.common import contains_cftime_datetimes +from xarray.core.utils import is_duck_dask_array +from xarray.testing import assert_equal, assert_identical +from xarray.tests import ( + FirstElementAccessibleArray, + arm_xfail, + assert_array_equal, + assert_no_warnings, + has_cftime, + requires_cftime, + requires_dask, +) + +_NON_STANDARD_CALENDARS_SET = { + "noleap", + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", +} +_ALL_CALENDARS = sorted( + _NON_STANDARD_CALENDARS_SET.union(coding.times._STANDARD_CALENDARS) +) +_NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) +_STANDARD_CALENDARS = sorted(coding.times._STANDARD_CALENDARS) +_CF_DATETIME_NUM_DATES_UNITS = [ + (np.arange(10), "days since 2000-01-01"), + (np.arange(10).astype("float64"), "days since 2000-01-01"), + (np.arange(10).astype("float32"), "days since 2000-01-01"), + (np.arange(10).reshape(2, 5), "days since 2000-01-01"), + (12300 + np.arange(5), "hours since 1680-01-01 00:00:00"), + # here we add a couple minor formatting errors to test + # the robustness of the parsing algorithm. + (12300 + np.arange(5), "hour since 1680-01-01 00:00:00"), + (12300 + np.arange(5), "Hour since 1680-01-01 00:00:00"), + (12300 + np.arange(5), " Hour since 1680-01-01 00:00:00 "), + (10, "days since 2000-01-01"), + ([10], "daYs since 2000-01-01"), + ([[10]], "days since 2000-01-01"), + ([10, 10], "days since 2000-01-01"), + (np.array(10), "days since 2000-01-01"), + (0, "days since 1000-01-01"), + ([0], "days since 1000-01-01"), + ([[0]], "days since 1000-01-01"), + (np.arange(2), "days since 1000-01-01"), + (np.arange(0, 100000, 20000), "days since 1900-01-01"), + (np.arange(0, 100000, 20000), "days since 1-01-01"), + (17093352.0, "hours since 1-1-1 00:00:0.0"), + ([0.5, 1.5], "hours since 1900-01-01T00:00:00"), + (0, "milliseconds since 2000-01-01T00:00:00"), + (0, "microseconds since 2000-01-01T00:00:00"), + (np.int32(788961600), "seconds since 1981-01-01"), # GH2002 + (12300 + np.arange(5), "hour since 1680-01-01 00:00:00.500000"), + (164375, "days since 1850-01-01 00:00:00"), + (164374.5, "days since 1850-01-01 00:00:00"), + ([164374.5, 168360.5], "days since 1850-01-01 00:00:00"), +] +_CF_DATETIME_TESTS = [ + num_dates_units + (calendar,) + for num_dates_units, calendar in product( + _CF_DATETIME_NUM_DATES_UNITS, _STANDARD_CALENDARS + ) +] + + +def _all_cftime_date_types(): + import cftime + + return { + "noleap": cftime.DatetimeNoLeap, + "365_day": cftime.DatetimeNoLeap, + "360_day": cftime.Datetime360Day, + "julian": cftime.DatetimeJulian, + "all_leap": cftime.DatetimeAllLeap, + "366_day": cftime.DatetimeAllLeap, + "gregorian": cftime.DatetimeGregorian, + "proleptic_gregorian": cftime.DatetimeProlepticGregorian, + } + + +@requires_cftime +@pytest.mark.filterwarnings("ignore:Ambiguous reference date string") +@pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") +@pytest.mark.parametrize(["num_dates", "units", "calendar"], _CF_DATETIME_TESTS) +def test_cf_datetime(num_dates, units, calendar) -> None: + import cftime + + expected = cftime.num2date( + num_dates, units, calendar, only_use_cftime_datetimes=True + ) + min_y = np.ravel(np.atleast_1d(expected))[np.nanargmin(num_dates)].year + max_y = np.ravel(np.atleast_1d(expected))[np.nanargmax(num_dates)].year + if min_y >= 1678 and max_y < 2262: + expected = cftime_to_nptime(expected) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(num_dates, units, calendar) + + abs_diff = np.asarray(abs(actual - expected)).ravel() + abs_diff = pd.to_timedelta(abs_diff.tolist()).to_numpy() + + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, "s")).all() + encoded, _, _ = coding.times.encode_cf_datetime(actual, units, calendar) + + assert_array_equal(num_dates, np.around(encoded, 1)) + if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: + # verify that wrapping with a pandas.Index works + # note that it *does not* currently work to put + # non-datetime64 compatible dates into a pandas.Index + encoded, _, _ = coding.times.encode_cf_datetime( + pd.Index(actual), units, calendar + ) + assert_array_equal(num_dates, np.around(encoded, 1)) + + +@requires_cftime +def test_decode_cf_datetime_overflow() -> None: + # checks for + # https://github.com/pydata/pandas/issues/14068 + # https://github.com/pydata/xarray/issues/975 + from cftime import DatetimeGregorian + + datetime = DatetimeGregorian + units = "days since 2000-01-01 00:00:00" + + # date after 2262 and before 1678 + days = (-117608, 95795) + expected = (datetime(1677, 12, 31), datetime(2262, 4, 12)) + + for i, day in enumerate(days): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + result = coding.times.decode_cf_datetime(day, units) + assert result == expected[i] + + +def test_decode_cf_datetime_non_standard_units() -> None: + expected = pd.date_range(periods=100, start="1970-01-01", freq="h") + # netCDFs from madis.noaa.gov use this format for their time units + # they cannot be parsed by cftime, but pd.Timestamp works + units = "hours since 1-1-1970" + actual = coding.times.decode_cf_datetime(np.arange(100), units) + assert_array_equal(actual, expected) + + +@requires_cftime +def test_decode_cf_datetime_non_iso_strings() -> None: + # datetime strings that are _almost_ ISO compliant but not quite, + # but which cftime.num2date can still parse correctly + expected = pd.date_range(periods=100, start="2000-01-01", freq="h") + cases = [ + (np.arange(100), "hours since 2000-01-01 0"), + (np.arange(100), "hours since 2000-1-1 0"), + (np.arange(100), "hours since 2000-01-01 0:00"), + ] + for num_dates, units in cases: + actual = coding.times.decode_cf_datetime(num_dates, units) + abs_diff = abs(actual - expected.values) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_decode_standard_calendar_inside_timestamp_range(calendar) -> None: + import cftime + + units = "days since 0001-01-01" + times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="h") + time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) + expected = times.values + expected_dtype = np.dtype("M8[ns]") + + actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + assert actual.dtype == expected_dtype + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_non_standard_calendar_inside_timestamp_range(calendar) -> None: + import cftime + + units = "days since 0001-01-01" + times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="h") + non_standard_time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) + + expected = cftime.num2date( + non_standard_time, units, calendar=calendar, only_use_cftime_datetimes=True + ) + expected_dtype = np.dtype("O") + + actual = coding.times.decode_cf_datetime( + non_standard_time, units, calendar=calendar + ) + assert actual.dtype == expected_dtype + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_decode_dates_outside_timestamp_range(calendar) -> None: + from datetime import datetime + + import cftime + + units = "days since 0001-01-01" + times = [datetime(1, 4, 1, h) for h in range(1, 5)] + time = cftime.date2num(times, units, calendar=calendar) + + expected = cftime.num2date( + time, units, calendar=calendar, only_use_cftime_datetimes=True + ) + expected_date_type = type(expected[0]) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + assert all(isinstance(value, expected_date_type) for value in actual) + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_decode_standard_calendar_single_element_inside_timestamp_range( + calendar, +) -> None: + units = "days since 0001-01-01" + for num_time in [735368, [735368], [[735368]]]: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + assert actual.dtype == np.dtype("M8[ns]") + + +@requires_cftime +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_non_standard_calendar_single_element_inside_timestamp_range( + calendar, +) -> None: + units = "days since 0001-01-01" + for num_time in [735368, [735368], [[735368]]]: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + assert actual.dtype == np.dtype("O") + + +@requires_cftime +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_single_element_outside_timestamp_range(calendar) -> None: + import cftime + + units = "days since 0001-01-01" + for days in [1, 1470376]: + for num_time in [days, [days], [[days]]]: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar + ) + + expected = cftime.num2date( + days, units, calendar, only_use_cftime_datetimes=True + ) + assert isinstance(actual.item(), type(expected)) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_decode_standard_calendar_multidim_time_inside_timestamp_range( + calendar, +) -> None: + import cftime + + units = "days since 0001-01-01" + times1 = pd.date_range("2001-04-01", end="2001-04-05", freq="D") + times2 = pd.date_range("2001-05-01", end="2001-05-05", freq="D") + time1 = cftime.date2num(times1.to_pydatetime(), units, calendar=calendar) + time2 = cftime.date2num(times2.to_pydatetime(), units, calendar=calendar) + mdim_time = np.empty((len(time1), 2)) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + expected1 = times1.values + expected2 = times2.values + + actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + assert actual.dtype == np.dtype("M8[ns]") + + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, "s")).all() + assert (abs_diff2 <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( + calendar, +) -> None: + import cftime + + units = "days since 0001-01-01" + times1 = pd.date_range("2001-04-01", end="2001-04-05", freq="D") + times2 = pd.date_range("2001-05-01", end="2001-05-05", freq="D") + time1 = cftime.date2num(times1.to_pydatetime(), units, calendar=calendar) + time2 = cftime.date2num(times2.to_pydatetime(), units, calendar=calendar) + mdim_time = np.empty((len(time1), 2)) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + if cftime.__name__ == "cftime": + expected1 = cftime.num2date( + time1, units, calendar, only_use_cftime_datetimes=True + ) + expected2 = cftime.num2date( + time2, units, calendar, only_use_cftime_datetimes=True + ) + else: + expected1 = cftime.num2date(time1, units, calendar) + expected2 = cftime.num2date(time2, units, calendar) + + expected_dtype = np.dtype("O") + + actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + + assert actual.dtype == expected_dtype + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, "s")).all() + assert (abs_diff2 <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_decode_multidim_time_outside_timestamp_range(calendar) -> None: + from datetime import datetime + + import cftime + + units = "days since 0001-01-01" + times1 = [datetime(1, 4, day) for day in range(1, 6)] + times2 = [datetime(1, 5, day) for day in range(1, 6)] + time1 = cftime.date2num(times1, units, calendar=calendar) + time2 = cftime.date2num(times2, units, calendar=calendar) + mdim_time = np.empty((len(time1), 2)) + mdim_time[:, 0] = time1 + mdim_time[:, 1] = time2 + + expected1 = cftime.num2date(time1, units, calendar, only_use_cftime_datetimes=True) + expected2 = cftime.num2date(time2, units, calendar, only_use_cftime_datetimes=True) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unable to decode time axis") + actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + + assert actual.dtype == np.dtype("O") + + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, "s")).all() + assert (abs_diff2 <= np.timedelta64(1, "s")).all() + + +@requires_cftime +@pytest.mark.parametrize( + ("calendar", "num_time"), + [("360_day", 720058.0), ("all_leap", 732059.0), ("366_day", 732059.0)], +) +def test_decode_non_standard_calendar_single_element(calendar, num_time) -> None: + import cftime + + units = "days since 0001-01-01" + + actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + + expected = np.asarray( + cftime.num2date(num_time, units, calendar, only_use_cftime_datetimes=True) + ) + assert actual.dtype == np.dtype("O") + assert expected == actual + + +@requires_cftime +def test_decode_360_day_calendar() -> None: + import cftime + + calendar = "360_day" + # ensure leap year doesn't matter + for year in [2010, 2011, 2012, 2013, 2014]: + units = f"days since {year}-01-01" + num_times = np.arange(100) + + expected = cftime.num2date( + num_times, units, calendar, only_use_cftime_datetimes=True + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + actual = coding.times.decode_cf_datetime( + num_times, units, calendar=calendar + ) + assert len(w) == 0 + + assert actual.dtype == np.dtype("O") + assert_array_equal(actual, expected) + + +@requires_cftime +def test_decode_abbreviation() -> None: + """Test making sure we properly fall back to cftime on abbreviated units.""" + import cftime + + val = np.array([1586628000000.0]) + units = "msecs since 1970-01-01T00:00:00Z" + actual = coding.times.decode_cf_datetime(val, units) + expected = coding.times.cftime_to_nptime(cftime.num2date(val, units)) + assert_array_equal(actual, expected) + + +@arm_xfail +@requires_cftime +@pytest.mark.parametrize( + ["num_dates", "units", "expected_list"], + [ + ([np.nan], "days since 2000-01-01", ["NaT"]), + ([np.nan, 0], "days since 2000-01-01", ["NaT", "2000-01-01T00:00:00Z"]), + ( + [np.nan, 0, 1], + "days since 2000-01-01", + ["NaT", "2000-01-01T00:00:00Z", "2000-01-02T00:00:00Z"], + ), + ], +) +def test_cf_datetime_nan(num_dates, units, expected_list) -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "All-NaN") + actual = coding.times.decode_cf_datetime(num_dates, units) + # use pandas because numpy will deprecate timezone-aware conversions + expected = pd.to_datetime(expected_list).to_numpy(dtype="datetime64[ns]") + assert_array_equal(expected, actual) + + +@requires_cftime +def test_decoded_cf_datetime_array_2d() -> None: + # regression test for GH1229 + variable = Variable( + ("x", "y"), np.array([[0, 1], [2, 3]]), {"units": "days since 2000-01-01"} + ) + result = coding.times.CFDatetimeCoder().decode(variable) + assert result.dtype == "datetime64[ns]" + expected = pd.date_range("2000-01-01", periods=4).values.reshape(2, 2) + assert_array_equal(np.asarray(result), expected) + + +FREQUENCIES_TO_ENCODING_UNITS = { + "ns": "nanoseconds", + "us": "microseconds", + "ms": "milliseconds", + "s": "seconds", + "min": "minutes", + "h": "hours", + "D": "days", +} + + +@pytest.mark.parametrize(("freq", "units"), FREQUENCIES_TO_ENCODING_UNITS.items()) +def test_infer_datetime_units(freq, units) -> None: + dates = pd.date_range("2000", periods=2, freq=freq) + expected = f"{units} since 2000-01-01 00:00:00" + assert expected == coding.times.infer_datetime_units(dates) + + +@pytest.mark.parametrize( + ["dates", "expected"], + [ + ( + pd.to_datetime(["1900-01-01", "1900-01-02", "NaT"], unit="ns"), + "days since 1900-01-01 00:00:00", + ), + ( + pd.to_datetime(["NaT", "1900-01-01"], unit="ns"), + "days since 1900-01-01 00:00:00", + ), + (pd.to_datetime(["NaT"], unit="ns"), "days since 1970-01-01 00:00:00"), + ], +) +def test_infer_datetime_units_with_NaT(dates, expected) -> None: + assert expected == coding.times.infer_datetime_units(dates) + + +_CFTIME_DATETIME_UNITS_TESTS = [ + ([(1900, 1, 1), (1900, 1, 1)], "days since 1900-01-01 00:00:00.000000"), + ( + [(1900, 1, 1), (1900, 1, 2), (1900, 1, 2, 0, 0, 1)], + "seconds since 1900-01-01 00:00:00.000000", + ), + ( + [(1900, 1, 1), (1900, 1, 8), (1900, 1, 16)], + "days since 1900-01-01 00:00:00.000000", + ), +] + + +@requires_cftime +@pytest.mark.parametrize( + "calendar", _NON_STANDARD_CALENDARS + ["gregorian", "proleptic_gregorian"] +) +@pytest.mark.parametrize(("date_args", "expected"), _CFTIME_DATETIME_UNITS_TESTS) +def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: + date_type = _all_cftime_date_types()[calendar] + dates = [date_type(*args) for args in date_args] + assert expected == coding.times.infer_datetime_units(dates) + + +@pytest.mark.filterwarnings("ignore:Timedeltas can't be serialized faithfully") +@pytest.mark.parametrize( + ["timedeltas", "units", "numbers"], + [ + ("1D", "days", np.int64(1)), + (["1D", "2D", "3D"], "days", np.array([1, 2, 3], "int64")), + ("1h", "hours", np.int64(1)), + ("1ms", "milliseconds", np.int64(1)), + ("1us", "microseconds", np.int64(1)), + ("1ns", "nanoseconds", np.int64(1)), + (["NaT", "0s", "1s"], None, [np.iinfo(np.int64).min, 0, 1]), + (["30m", "60m"], "hours", [0.5, 1.0]), + ("NaT", "days", np.iinfo(np.int64).min), + (["NaT", "NaT"], "days", [np.iinfo(np.int64).min, np.iinfo(np.int64).min]), + ], +) +def test_cf_timedelta(timedeltas, units, numbers) -> None: + if timedeltas == "NaT": + timedeltas = np.timedelta64("NaT", "ns") + else: + timedeltas = to_timedelta_unboxed(timedeltas) + numbers = np.array(numbers) + + expected = numbers + actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + if units is not None: + expected = timedeltas + actual = coding.times.decode_cf_timedelta(numbers, units) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + expected = np.timedelta64("NaT", "ns") + actual = coding.times.decode_cf_timedelta(np.array(np.nan), "days") + assert_array_equal(expected, actual) + + +def test_cf_timedelta_2d() -> None: + units = "days" + numbers = np.atleast_2d([1, 2, 3]) + + timedeltas = np.atleast_2d(to_timedelta_unboxed(["1D", "2D", "3D"])) + expected = timedeltas + + actual = coding.times.decode_cf_timedelta(numbers, units) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +@pytest.mark.parametrize( + ["deltas", "expected"], + [ + (pd.to_timedelta(["1 day", "2 days"]), "days"), + (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), + (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), + (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), + ], +) +def test_infer_timedelta_units(deltas, expected) -> None: + assert expected == coding.times.infer_timedelta_units(deltas) + + +@requires_cftime +@pytest.mark.parametrize( + ["date_args", "expected"], + [ + ((1, 2, 3, 4, 5, 6), "0001-02-03 04:05:06.000000"), + ((10, 2, 3, 4, 5, 6), "0010-02-03 04:05:06.000000"), + ((100, 2, 3, 4, 5, 6), "0100-02-03 04:05:06.000000"), + ((1000, 2, 3, 4, 5, 6), "1000-02-03 04:05:06.000000"), + ], +) +def test_format_cftime_datetime(date_args, expected) -> None: + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + result = coding.times.format_cftime_datetime(date_type(*date_args)) + assert result == expected + + +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_decode_cf(calendar) -> None: + days = [1.0, 2.0, 3.0] + # TODO: GH5690 — do we want to allow this type for `coords`? + da = DataArray(days, coords=[days], dims=["time"], name="test") + ds = da.to_dataset() + + for v in ["test", "time"]: + ds[v].attrs["units"] = "days since 2001-01-01" + ds[v].attrs["calendar"] = calendar + + if not has_cftime and calendar not in _STANDARD_CALENDARS: + with pytest.raises(ValueError): + ds = decode_cf(ds) + else: + ds = decode_cf(ds) + + if calendar not in _STANDARD_CALENDARS: + assert ds.test.dtype == np.dtype("O") + else: + assert ds.test.dtype == np.dtype("M8[ns]") + + +def test_decode_cf_time_bounds() -> None: + da = DataArray( + np.arange(6, dtype="int64").reshape((3, 2)), + coords={"time": [1, 2, 3]}, + dims=("time", "nbnd"), + name="time_bnds", + ) + + attrs = { + "units": "days since 2001-01", + "calendar": "standard", + "bounds": "time_bnds", + } + + ds = da.to_dataset() + ds["time"].attrs.update(attrs) + _update_bounds_attributes(ds.variables) + assert ds.variables["time_bnds"].attrs == { + "units": "days since 2001-01", + "calendar": "standard", + } + dsc = decode_cf(ds) + assert dsc.time_bnds.dtype == np.dtype("M8[ns]") + dsc = decode_cf(ds, decode_times=False) + assert dsc.time_bnds.dtype == np.dtype("int64") + + # Do not overwrite existing attrs + ds = da.to_dataset() + ds["time"].attrs.update(attrs) + bnd_attr = {"units": "hours since 2001-01", "calendar": "noleap"} + ds["time_bnds"].attrs.update(bnd_attr) + _update_bounds_attributes(ds.variables) + assert ds.variables["time_bnds"].attrs == bnd_attr + + # If bounds variable not available do not complain + ds = da.to_dataset() + ds["time"].attrs.update(attrs) + ds["time"].attrs["bounds"] = "fake_var" + _update_bounds_attributes(ds.variables) + + +@requires_cftime +def test_encode_time_bounds() -> None: + time = pd.date_range("2000-01-16", periods=1) + time_bounds = pd.date_range("2000-01-01", periods=2, freq="MS") + ds = Dataset(dict(time=time, time_bounds=time_bounds)) + ds.time.attrs = {"bounds": "time_bounds"} + ds.time.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"} + + expected = {} + # expected['time'] = Variable(data=np.array([15]), dims=['time']) + expected["time_bounds"] = Variable(data=np.array([0, 31]), dims=["time_bounds"]) + + encoded, _ = cf_encoder(ds.variables, ds.attrs) + assert_equal(encoded["time_bounds"], expected["time_bounds"]) + assert "calendar" not in encoded["time_bounds"].attrs + assert "units" not in encoded["time_bounds"].attrs + + # if time_bounds attrs are same as time attrs, it doesn't matter + ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"} + encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs) + assert_equal(encoded["time_bounds"], expected["time_bounds"]) + assert "calendar" not in encoded["time_bounds"].attrs + assert "units" not in encoded["time_bounds"].attrs + + # for CF-noncompliant case of time_bounds attrs being different from + # time attrs; preserve them for faithful roundtrip + ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 1849-01-01"} + encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs) + with pytest.raises(AssertionError): + assert_equal(encoded["time_bounds"], expected["time_bounds"]) + assert "calendar" not in encoded["time_bounds"].attrs + assert encoded["time_bounds"].attrs["units"] == ds.time_bounds.encoding["units"] + + ds.time.encoding = {} + with pytest.warns(UserWarning): + cf_encoder(ds.variables, ds.attrs) + + +@pytest.fixture(params=_ALL_CALENDARS) +def calendar(request): + return request.param + + +@pytest.fixture() +def times(calendar): + import cftime + + return cftime.num2date( + np.arange(4), + units="hours since 2000-01-01", + calendar=calendar, + only_use_cftime_datetimes=True, + ) + + +@pytest.fixture() +def data(times): + data = np.random.rand(2, 2, 4) + lons = np.linspace(0, 11, 2) + lats = np.linspace(0, 20, 2) + return DataArray( + data, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) + + +@pytest.fixture() +def times_3d(times): + lons = np.linspace(0, 11, 2) + lats = np.linspace(0, 20, 2) + times_arr = np.random.choice(times, size=(2, 2, 4)) + return DataArray( + times_arr, coords=[lons, lats, times], dims=["lon", "lat", "time"], name="data" + ) + + +@requires_cftime +def test_contains_cftime_datetimes_1d(data) -> None: + assert contains_cftime_datetimes(data.time.variable) + + +@requires_cftime +@requires_dask +def test_contains_cftime_datetimes_dask_1d(data) -> None: + assert contains_cftime_datetimes(data.time.variable.chunk()) + + +@requires_cftime +def test_contains_cftime_datetimes_3d(times_3d) -> None: + assert contains_cftime_datetimes(times_3d.variable) + + +@requires_cftime +@requires_dask +def test_contains_cftime_datetimes_dask_3d(times_3d) -> None: + assert contains_cftime_datetimes(times_3d.variable.chunk()) + + +@pytest.mark.parametrize("non_cftime_data", [DataArray([]), DataArray([1, 2])]) +def test_contains_cftime_datetimes_non_cftimes(non_cftime_data) -> None: + assert not contains_cftime_datetimes(non_cftime_data.variable) + + +@requires_dask +@pytest.mark.parametrize("non_cftime_data", [DataArray([]), DataArray([1, 2])]) +def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data) -> None: + assert not contains_cftime_datetimes(non_cftime_data.variable.chunk()) + + +@requires_cftime +@pytest.mark.parametrize("shape", [(24,), (8, 3), (2, 4, 3)]) +def test_encode_cf_datetime_overflow(shape) -> None: + # Test for fix to GH 2272 + dates = pd.date_range("2100", periods=24).values.reshape(shape) + units = "days since 1800-01-01" + calendar = "standard" + + num, _, _ = encode_cf_datetime(dates, units, calendar) + roundtrip = decode_cf_datetime(num, units, calendar) + np.testing.assert_array_equal(dates, roundtrip) + + +def test_encode_expected_failures() -> None: + dates = pd.date_range("2000", periods=3) + with pytest.raises(ValueError, match="invalid time units"): + encode_cf_datetime(dates, units="days after 2000-01-01") + with pytest.raises(ValueError, match="invalid reference date"): + encode_cf_datetime(dates, units="days since NO_YEAR") + + +def test_encode_cf_datetime_pandas_min() -> None: + # GH 2623 + dates = pd.date_range("2000", periods=3) + num, units, calendar = encode_cf_datetime(dates) + expected_num = np.array([0.0, 1.0, 2.0]) + expected_units = "days since 2000-01-01 00:00:00" + expected_calendar = "proleptic_gregorian" + np.testing.assert_array_equal(num, expected_num) + assert units == expected_units + assert calendar == expected_calendar + + +@requires_cftime +def test_encode_cf_datetime_invalid_pandas_valid_cftime() -> None: + num, units, calendar = encode_cf_datetime( + pd.date_range("2000", periods=3), + # Pandas fails to parse this unit, but cftime is quite happy with it + "days since 1970-01-01 00:00:00 00", + "standard", + ) + + expected_num = [10957, 10958, 10959] + expected_units = "days since 1970-01-01 00:00:00 00" + expected_calendar = "standard" + assert_array_equal(num, expected_num) + assert units == expected_units + assert calendar == expected_calendar + + +@requires_cftime +def test_time_units_with_timezone_roundtrip(calendar) -> None: + # Regression test for GH 2649 + expected_units = "days since 2000-01-01T00:00:00-05:00" + expected_num_dates = np.array([1, 2, 3]) + dates = decode_cf_datetime(expected_num_dates, expected_units, calendar) + + # Check that dates were decoded to UTC; here the hours should all + # equal 5. + result_hours = DataArray(dates).dt.hour + expected_hours = DataArray([5, 5, 5]) + assert_equal(result_hours, expected_hours) + + # Check that the encoded values are accurately roundtripped. + result_num_dates, result_units, result_calendar = encode_cf_datetime( + dates, expected_units, calendar + ) + + if calendar in _STANDARD_CALENDARS: + np.testing.assert_array_equal(result_num_dates, expected_num_dates) + else: + # cftime datetime arithmetic is not quite exact. + np.testing.assert_allclose(result_num_dates, expected_num_dates) + + assert result_units == expected_units + assert result_calendar == calendar + + +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_use_cftime_default_standard_calendar_in_range(calendar) -> None: + numerical_dates = [0, 1] + units = "days since 2000-01-01" + expected = pd.date_range("2000", periods=2) + + with assert_no_warnings(): + result = decode_cf_datetime(numerical_dates, units, calendar) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2500]) +def test_use_cftime_default_standard_calendar_out_of_range( + calendar, units_year +) -> None: + from cftime import num2date + + numerical_dates = [0, 1] + units = f"days since {units_year}-01-01" + expected = num2date( + numerical_dates, units, calendar, only_use_cftime_datetimes=True + ) + + with pytest.warns(SerializationWarning): + result = decode_cf_datetime(numerical_dates, units, calendar) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) +def test_use_cftime_default_non_standard_calendar(calendar, units_year) -> None: + from cftime import num2date + + numerical_dates = [0, 1] + units = f"days since {units_year}-01-01" + expected = num2date( + numerical_dates, units, calendar, only_use_cftime_datetimes=True + ) + + with assert_no_warnings(): + result = decode_cf_datetime(numerical_dates, units, calendar) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) +def test_use_cftime_true(calendar, units_year) -> None: + from cftime import num2date + + numerical_dates = [0, 1] + units = f"days since {units_year}-01-01" + expected = num2date( + numerical_dates, units, calendar, only_use_cftime_datetimes=True + ) + + with assert_no_warnings(): + result = decode_cf_datetime(numerical_dates, units, calendar, use_cftime=True) + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +def test_use_cftime_false_standard_calendar_in_range(calendar) -> None: + numerical_dates = [0, 1] + units = "days since 2000-01-01" + expected = pd.date_range("2000", periods=2) + + with assert_no_warnings(): + result = decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2500]) +def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year) -> None: + numerical_dates = [0, 1] + units = f"days since {units_year}-01-01" + with pytest.raises(OutOfBoundsDatetime): + decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) + + +@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) +@pytest.mark.parametrize("units_year", [1500, 2000, 2500]) +def test_use_cftime_false_non_standard_calendar(calendar, units_year) -> None: + numerical_dates = [0, 1] + units = f"days since {units_year}-01-01" + with pytest.raises(OutOfBoundsDatetime): + decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_decode_ambiguous_time_warns(calendar) -> None: + # GH 4422, 4506 + from cftime import num2date + + # we don't decode non-standard calendards with + # pandas so expect no warning will be emitted + is_standard_calendar = calendar in coding.times._STANDARD_CALENDARS + + dates = [1, 2, 3] + units = "days since 1-1-1" + expected = num2date(dates, units, calendar=calendar, only_use_cftime_datetimes=True) + + if is_standard_calendar: + with pytest.warns(SerializationWarning) as record: + result = decode_cf_datetime(dates, units, calendar=calendar) + relevant_warnings = [ + r + for r in record.list + if str(r.message).startswith("Ambiguous reference date string: 1-1-1") + ] + assert len(relevant_warnings) == 1 + else: + with assert_no_warnings(): + result = decode_cf_datetime(dates, units, calendar=calendar) + + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") +@pytest.mark.parametrize("encoding_units", FREQUENCIES_TO_ENCODING_UNITS.values()) +@pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) +@pytest.mark.parametrize("date_range", [pd.date_range, cftime_range]) +def test_encode_cf_datetime_defaults_to_correct_dtype( + encoding_units, freq, date_range +) -> None: + if not has_cftime and date_range == cftime_range: + pytest.skip("Test requires cftime") + if (freq == "ns" or encoding_units == "nanoseconds") and date_range == cftime_range: + pytest.skip("Nanosecond frequency is not valid for cftime dates.") + times = date_range("2000", periods=3, freq=freq) + units = f"{encoding_units} since 2000-01-01" + encoded, _units, _ = coding.times.encode_cf_datetime(times, units) + + numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) + encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) + if pd.to_timedelta(1, freq) >= encoding_units_as_timedelta: + assert encoded.dtype == np.int64 + else: + assert encoded.dtype == np.float64 + + +@pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) +def test_encode_decode_roundtrip_datetime64(freq) -> None: + # See GH 4045. Prior to GH 4684 this test would fail for frequencies of + # "s", "ms", "us", and "ns". + initial_time = pd.date_range("1678-01-01", periods=1) + times = initial_time.append(pd.date_range("1968", periods=2, freq=freq)) + variable = Variable(["time"], times) + encoded = conventions.encode_cf_variable(variable) + decoded = conventions.decode_cf_variable("time", encoded) + assert_equal(variable, decoded) + + +@requires_cftime +@pytest.mark.parametrize("freq", ["us", "ms", "s", "min", "h", "D"]) +def test_encode_decode_roundtrip_cftime(freq) -> None: + initial_time = cftime_range("0001", periods=1) + times = initial_time.append( + cftime_range("0001", periods=2, freq=freq) + timedelta(days=291000 * 365) + ) + variable = Variable(["time"], times) + encoded = conventions.encode_cf_variable(variable) + decoded = conventions.decode_cf_variable("time", encoded, use_cftime=True) + assert_equal(variable, decoded) + + +@requires_cftime +def test__encode_datetime_with_cftime() -> None: + # See GH 4870. cftime versions > 1.4.0 required us to adapt the + # way _encode_datetime_with_cftime was written. + import cftime + + calendar = "gregorian" + times = cftime.num2date([0, 1], "hours since 2000-01-01", calendar) + + encoding_units = "days since 2000-01-01" + # Since netCDF files do not support storing float128 values, we ensure that + # float64 values are used by setting longdouble=False in num2date. This try + # except logic can be removed when xarray's minimum version of cftime is at + # least 1.6.2. + try: + expected = cftime.date2num(times, encoding_units, calendar, longdouble=False) + except TypeError: + expected = cftime.date2num(times, encoding_units, calendar) + result = _encode_datetime_with_cftime(times, encoding_units, calendar) + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize("calendar", ["gregorian", "Gregorian", "GREGORIAN"]) +def test_decode_encode_roundtrip_with_non_lowercase_letters(calendar) -> None: + # See GH 5093. + times = [0, 1] + units = "days since 2000-01-01" + attrs = {"calendar": calendar, "units": units} + variable = Variable(["time"], times, attrs) + decoded = conventions.decode_cf_variable("time", variable) + encoded = conventions.encode_cf_variable(decoded) + + # Previously this would erroneously be an array of cftime.datetime + # objects. We check here that it is decoded properly to np.datetime64. + assert np.issubdtype(decoded.dtype, np.datetime64) + + # Use assert_identical to ensure that the calendar attribute maintained its + # original form throughout the roundtripping process, uppercase letters and + # all. + assert_identical(variable, encoded) + + +@requires_cftime +def test_should_cftime_be_used_source_outside_range(): + src = cftime_range("1000-01-01", periods=100, freq="MS", calendar="noleap") + with pytest.raises( + ValueError, match="Source time range is not valid for numpy datetimes." + ): + _should_cftime_be_used(src, "standard", False) + + +@requires_cftime +def test_should_cftime_be_used_target_not_npable(): + src = cftime_range("2000-01-01", periods=100, freq="MS", calendar="noleap") + with pytest.raises( + ValueError, match="Calendar 'noleap' is only valid with cftime." + ): + _should_cftime_be_used(src, "noleap", False) + + +@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_decode_cf_datetime_uint(dtype): + units = "seconds since 2018-08-22T03:23:03Z" + num_dates = dtype(50) + result = decode_cf_datetime(num_dates, units) + expected = np.asarray(np.datetime64("2018-08-22T03:23:53", "ns")) + np.testing.assert_equal(result, expected) + + +@requires_cftime +def test_decode_cf_datetime_uint64_with_cftime(): + units = "days since 1700-01-01" + num_dates = np.uint64(182621) + result = decode_cf_datetime(num_dates, units) + expected = np.asarray(np.datetime64("2200-01-01", "ns")) + np.testing.assert_equal(result, expected) + + +@requires_cftime +def test_decode_cf_datetime_uint64_with_cftime_overflow_error(): + units = "microseconds since 1700-01-01" + calendar = "360_day" + num_dates = np.uint64(1_000_000 * 86_400 * 360 * 500_000) + with pytest.raises(OverflowError): + decode_cf_datetime(num_dates, units, calendar) + + +@pytest.mark.parametrize("use_cftime", [True, False]) +def test_decode_0size_datetime(use_cftime): + # GH1329 + if use_cftime and not has_cftime: + pytest.skip() + + dtype = object if use_cftime else "M8[ns]" + expected = np.array([], dtype=dtype) + actual = decode_cf_datetime( + np.zeros(shape=0, dtype=np.int64), + units="days since 1970-01-01 00:00:00", + calendar="proleptic_gregorian", + use_cftime=use_cftime, + ) + np.testing.assert_equal(expected, actual) + + +@requires_cftime +def test_scalar_unit() -> None: + # test that a scalar units (often NaN when using to_netcdf) does not raise an error + variable = Variable(("x", "y"), np.array([[0, 1], [2, 3]]), {"units": np.nan}) + result = coding.times.CFDatetimeCoder().decode(variable) + assert np.isnan(result.attrs["units"]) + + +@requires_cftime +def test_contains_cftime_lazy() -> None: + import cftime + + from xarray.core.common import _contains_cftime_datetimes + + times = np.array( + [cftime.DatetimeGregorian(1, 1, 2, 0), cftime.DatetimeGregorian(1, 1, 2, 0)], + dtype=object, + ) + array = FirstElementAccessibleArray(times) + assert _contains_cftime_datetimes(array) + + +@pytest.mark.parametrize( + "timestr, timeunit, dtype, fill_value, use_encoding", + [ + ("1677-09-21T00:12:43.145224193", "ns", np.int64, 20, True), + ("1970-09-21T00:12:44.145224808", "ns", np.float64, 1e30, True), + ( + "1677-09-21T00:12:43.145225216", + "ns", + np.float64, + -9.223372036854776e18, + True, + ), + ("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False), + ("1677-09-21T00:12:43.145225", "us", np.int64, None, False), + ("1970-01-01T00:00:01.000001", "us", np.int64, None, False), + ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True), + ], +) +def test_roundtrip_datetime64_nanosecond_precision( + timestr: str, + timeunit: str, + dtype: np.typing.DTypeLike, + fill_value: int | float | None, + use_encoding: bool, +) -> None: + # test for GH7817 + time = np.datetime64(timestr, timeunit) + times = [np.datetime64("1970-01-01T00:00:00", timeunit), np.datetime64("NaT"), time] + + if use_encoding: + encoding = dict(dtype=dtype, _FillValue=fill_value) + else: + encoding = {} + + var = Variable(["time"], times, encoding=encoding) + assert var.dtype == np.dtype(" None: + # test warning if times can't be serialized faithfully + times = [ + np.datetime64("1970-01-01T00:01:00", "ns"), + np.datetime64("NaT"), + np.datetime64("1970-01-02T00:01:00", "ns"), + ] + units = "days since 1970-01-10T01:01:00" + needed_units = "hours" + new_units = f"{needed_units} since 1970-01-10T01:01:00" + + encoding = dict(dtype=None, _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns( + UserWarning, match=f"Serializing with units {new_units!r} instead." + ): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="float64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=new_units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + +@pytest.mark.parametrize( + "dtype, fill_value", + [(np.int64, 20), (np.int64, np.iinfo(np.int64).min), (np.float64, 1e30)], +) +def test_roundtrip_timedelta64_nanosecond_precision( + dtype: np.typing.DTypeLike, fill_value: int | float +) -> None: + # test for GH7942 + one_day = np.timedelta64(1, "ns") + nat = np.timedelta64("nat", "ns") + timedelta_values = (np.arange(5) * one_day).astype("timedelta64[ns]") + timedelta_values[2] = nat + timedelta_values[4] = nat + + encoding = dict(dtype=dtype, _FillValue=fill_value) + var = Variable(["time"], timedelta_values, encoding=encoding) + + encoded_var = conventions.encode_cf_variable(var) + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + + assert_identical(var, decoded_var) + + +def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: + # test warning if timedeltas can't be serialized faithfully + one_day = np.timedelta64(1, "D") + nat = np.timedelta64("nat", "ns") + timedelta_values = (np.arange(5) * one_day).astype("timedelta64[ns]") + timedelta_values[2] = nat + timedelta_values[4] = np.timedelta64(12, "h").astype("timedelta64[ns]") + + units = "days" + needed_units = "hours" + wmsg = ( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead." + ) + encoding = dict(dtype=np.int64, _FillValue=20, units=units) + var = Variable(["time"], timedelta_values, encoding=encoding) + with pytest.warns(UserWarning, match=wmsg): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == needed_units + assert encoded_var.attrs["_FillValue"] == 20 + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["dtype"] == np.int64 + + +def test_roundtrip_float_times() -> None: + # Regression test for GitHub issue #8271 + fill_value = 20.0 + times = [ + np.datetime64("1970-01-01 00:00:00", "ns"), + np.datetime64("1970-01-01 06:00:00", "ns"), + np.datetime64("NaT", "ns"), + ] + + units = "days since 1960-01-01" + var = Variable( + ["time"], + times, + encoding=dict(dtype=np.float64, _FillValue=fill_value, units=units), + ) + + encoded_var = conventions.encode_cf_variable(var) + np.testing.assert_array_equal(encoded_var, np.array([3653, 3653.25, 20.0])) + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == fill_value + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["units"] == units + assert decoded_var.encoding["_FillValue"] == fill_value + + +_ENCODE_DATETIME64_VIA_DASK_TESTS = { + "pandas-encoding-with-prescribed-units-and-dtype": ( + "D", + "days since 1700-01-01", + np.dtype("int32"), + ), + "mixed-cftime-pandas-encoding-with-prescribed-units-and-dtype": ( + "250YS", + "days since 1700-01-01", + np.dtype("int32"), + ), + "pandas-encoding-with-default-units-and-dtype": ("250YS", None, None), +} + + +@requires_dask +@pytest.mark.parametrize( + ("freq", "units", "dtype"), + _ENCODE_DATETIME64_VIA_DASK_TESTS.values(), + ids=_ENCODE_DATETIME64_VIA_DASK_TESTS.keys(), +) +def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype) -> None: + import dask.array + + times = pd.date_range(start="1700", freq=freq, periods=3) + times = dask.array.from_array(times, chunks=1) + encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( + times, units, None, dtype + ) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "nanoseconds since 1970-01-01" + assert encoded_times.dtype == np.dtype("int64") + + assert encoding_calendar == "proleptic_gregorian" + + decoded_times = decode_cf_datetime(encoded_times, encoding_units, encoding_calendar) + np.testing.assert_equal(decoded_times, times) + + +@requires_dask +@pytest.mark.parametrize( + ("range_function", "start", "units", "dtype"), + [ + (pd.date_range, "2000", None, np.dtype("int32")), + (pd.date_range, "2000", "days since 2000-01-01", None), + (pd.timedelta_range, "0D", None, np.dtype("int32")), + (pd.timedelta_range, "0D", "days", None), + ], +) +def test_encode_via_dask_cannot_infer_error( + range_function, start, units, dtype +) -> None: + values = range_function(start=start, freq="D", periods=3) + encoding = dict(units=units, dtype=dtype) + variable = Variable(["time"], values, encoding=encoding).chunk({"time": 1}) + with pytest.raises(ValueError, match="When encoding chunked arrays"): + conventions.encode_cf_variable(variable) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize( + ("units", "dtype"), [("days since 1700-01-01", np.dtype("int32")), (None, None)] +) +def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype) -> None: + import dask.array + + calendar = "standard" + times = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) + times = dask.array.from_array(times, chunks=1) + encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( + times, units, None, dtype + ) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "microseconds since 1970-01-01" + assert encoded_times.dtype == np.int64 + + assert encoding_calendar == calendar + + decoded_times = decode_cf_datetime( + encoded_times, encoding_units, encoding_calendar, use_cftime=True + ) + np.testing.assert_equal(decoded_times, times) + + +@pytest.mark.parametrize( + "use_cftime", [False, pytest.param(True, marks=requires_cftime)] +) +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +def test_encode_cf_datetime_casting_value_error(use_cftime, use_dask) -> None: + times = date_range(start="2000", freq="12h", periods=3, use_cftime=use_cftime) + encoding = dict(units="days since 2000-01-01", dtype=np.dtype("int64")) + variable = Variable(["time"], times, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + if not use_cftime and not use_dask: + # In this particular case we automatically modify the encoding units to + # continue encoding with integer values. For all other cases we raise. + with pytest.warns(UserWarning, match="Times can't be serialized"): + encoded = conventions.encode_cf_variable(variable) + assert encoded.attrs["units"] == "hours since 2000-01-01" + decoded = conventions.decode_cf_variable("name", encoded) + assert_equal(variable, decoded) + else: + with pytest.raises(ValueError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@pytest.mark.parametrize( + "use_cftime", [False, pytest.param(True, marks=requires_cftime)] +) +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")]) +def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype) -> None: + # Regression test for GitHub issue #8542 + times = date_range(start="2018", freq="5h", periods=3, use_cftime=use_cftime) + encoding = dict(units="microseconds since 2018-01-01", dtype=dtype) + variable = Variable(["time"], times, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + with pytest.raises(OverflowError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@requires_dask +@pytest.mark.parametrize( + ("units", "dtype"), [("days", np.dtype("int32")), (None, None)] +) +def test_encode_cf_timedelta_via_dask(units, dtype) -> None: + import dask.array + + times = pd.timedelta_range(start="0D", freq="D", periods=3) + times = dask.array.from_array(times, chunks=1) + encoded_times, encoding_units = encode_cf_timedelta(times, units, dtype) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "nanoseconds" + assert encoded_times.dtype == np.dtype("int64") + + decoded_times = decode_cf_timedelta(encoded_times, encoding_units) + np.testing.assert_equal(decoded_times, times) + + +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +def test_encode_cf_timedelta_casting_value_error(use_dask) -> None: + timedeltas = pd.timedelta_range(start="0h", freq="12h", periods=3) + encoding = dict(units="days", dtype=np.dtype("int64")) + variable = Variable(["time"], timedeltas, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + if not use_dask: + # In this particular case we automatically modify the encoding units to + # continue encoding with integer values. + with pytest.warns(UserWarning, match="Timedeltas can't be serialized"): + encoded = conventions.encode_cf_variable(variable) + assert encoded.attrs["units"] == "hours" + decoded = conventions.decode_cf_variable("name", encoded) + assert_equal(variable, decoded) + else: + with pytest.raises(ValueError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")]) +def test_encode_cf_timedelta_casting_overflow_error(use_dask, dtype) -> None: + timedeltas = pd.timedelta_range(start="0h", freq="5h", periods=3) + encoding = dict(units="microseconds", dtype=dtype) + variable = Variable(["time"], timedeltas, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + with pytest.raises(OverflowError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_combine.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_combine.py new file mode 100644 index 0000000..aad7103 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_combine.py @@ -0,0 +1,1165 @@ +from __future__ import annotations + +from itertools import product + +import numpy as np +import pytest + +from xarray import ( + DataArray, + Dataset, + MergeError, + combine_by_coords, + combine_nested, + concat, + merge, +) +from xarray.core import dtypes +from xarray.core.combine import ( + _check_shape_tile_ids, + _combine_all_along_first_dim, + _combine_nd, + _infer_concat_order_from_coords, + _infer_concat_order_from_positions, + _new_tile_id, +) +from xarray.tests import assert_equal, assert_identical, requires_cftime +from xarray.tests.test_dataset import create_test_data + + +def assert_combined_tile_ids_equal(dict1, dict2): + assert len(dict1) == len(dict2) + for k, v in dict1.items(): + assert k in dict2.keys() + assert_equal(dict1[k], dict2[k]) + + +class TestTileIDsFromNestedList: + def test_1d(self): + ds = create_test_data + input = [ds(0), ds(1)] + + expected = {(0,): ds(0), (1,): ds(1)} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_2d(self): + ds = create_test_data + input = [[ds(0), ds(1)], [ds(2), ds(3)], [ds(4), ds(5)]] + + expected = { + (0, 0): ds(0), + (0, 1): ds(1), + (1, 0): ds(2), + (1, 1): ds(3), + (2, 0): ds(4), + (2, 1): ds(5), + } + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_3d(self): + ds = create_test_data + input = [ + [[ds(0), ds(1)], [ds(2), ds(3)], [ds(4), ds(5)]], + [[ds(6), ds(7)], [ds(8), ds(9)], [ds(10), ds(11)]], + ] + + expected = { + (0, 0, 0): ds(0), + (0, 0, 1): ds(1), + (0, 1, 0): ds(2), + (0, 1, 1): ds(3), + (0, 2, 0): ds(4), + (0, 2, 1): ds(5), + (1, 0, 0): ds(6), + (1, 0, 1): ds(7), + (1, 1, 0): ds(8), + (1, 1, 1): ds(9), + (1, 2, 0): ds(10), + (1, 2, 1): ds(11), + } + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_single_dataset(self): + ds = create_test_data(0) + input = [ds] + + expected = {(0,): ds} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_redundant_nesting(self): + ds = create_test_data + input = [[ds(0)], [ds(1)]] + + expected = {(0, 0): ds(0), (1, 0): ds(1)} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_ignore_empty_list(self): + ds = create_test_data(0) + input = [ds, []] + expected = {(0,): ds} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_uneven_depth_input(self): + # Auto_combine won't work on ragged input + # but this is just to increase test coverage + ds = create_test_data + input = [ds(0), [ds(1), ds(2)]] + + expected = {(0,): ds(0), (1, 0): ds(1), (1, 1): ds(2)} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_uneven_length_input(self): + # Auto_combine won't work on ragged input + # but this is just to increase test coverage + ds = create_test_data + input = [[ds(0)], [ds(1), ds(2)]] + + expected = {(0, 0): ds(0), (1, 0): ds(1), (1, 1): ds(2)} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + def test_infer_from_datasets(self): + ds = create_test_data + input = [ds(0), ds(1)] + + expected = {(0,): ds(0), (1,): ds(1)} + actual = _infer_concat_order_from_positions(input) + assert_combined_tile_ids_equal(expected, actual) + + +class TestTileIDsFromCoords: + def test_1d(self): + ds0 = Dataset({"x": [0, 1]}) + ds1 = Dataset({"x": [2, 3]}) + + expected = {(0,): ds0, (1,): ds1} + actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["x"] + + def test_2d(self): + ds0 = Dataset({"x": [0, 1], "y": [10, 20, 30]}) + ds1 = Dataset({"x": [2, 3], "y": [10, 20, 30]}) + ds2 = Dataset({"x": [0, 1], "y": [40, 50, 60]}) + ds3 = Dataset({"x": [2, 3], "y": [40, 50, 60]}) + ds4 = Dataset({"x": [0, 1], "y": [70, 80, 90]}) + ds5 = Dataset({"x": [2, 3], "y": [70, 80, 90]}) + + expected = { + (0, 0): ds0, + (1, 0): ds1, + (0, 1): ds2, + (1, 1): ds3, + (0, 2): ds4, + (1, 2): ds5, + } + actual, concat_dims = _infer_concat_order_from_coords( + [ds1, ds0, ds3, ds5, ds2, ds4] + ) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["x", "y"] + + def test_no_dimension_coords(self): + ds0 = Dataset({"foo": ("x", [0, 1])}) + ds1 = Dataset({"foo": ("x", [2, 3])}) + with pytest.raises(ValueError, match=r"Could not find any dimension"): + _infer_concat_order_from_coords([ds1, ds0]) + + def test_coord_not_monotonic(self): + ds0 = Dataset({"x": [0, 1]}) + ds1 = Dataset({"x": [3, 2]}) + with pytest.raises( + ValueError, + match=r"Coordinate variable x is neither monotonically increasing nor", + ): + _infer_concat_order_from_coords([ds1, ds0]) + + def test_coord_monotonically_decreasing(self): + ds0 = Dataset({"x": [3, 2]}) + ds1 = Dataset({"x": [1, 0]}) + + expected = {(0,): ds0, (1,): ds1} + actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["x"] + + def test_no_concatenation_needed(self): + ds = Dataset({"foo": ("x", [0, 1])}) + expected = {(): ds} + actual, concat_dims = _infer_concat_order_from_coords([ds]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == [] + + def test_2d_plus_bystander_dim(self): + ds0 = Dataset({"x": [0, 1], "y": [10, 20, 30], "t": [0.1, 0.2]}) + ds1 = Dataset({"x": [2, 3], "y": [10, 20, 30], "t": [0.1, 0.2]}) + ds2 = Dataset({"x": [0, 1], "y": [40, 50, 60], "t": [0.1, 0.2]}) + ds3 = Dataset({"x": [2, 3], "y": [40, 50, 60], "t": [0.1, 0.2]}) + + expected = {(0, 0): ds0, (1, 0): ds1, (0, 1): ds2, (1, 1): ds3} + actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0, ds3, ds2]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["x", "y"] + + def test_string_coords(self): + ds0 = Dataset({"person": ["Alice", "Bob"]}) + ds1 = Dataset({"person": ["Caroline", "Daniel"]}) + + expected = {(0,): ds0, (1,): ds1} + actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["person"] + + # Decided against natural sorting of string coords GH #2616 + def test_lexicographic_sort_string_coords(self): + ds0 = Dataset({"simulation": ["run8", "run9"]}) + ds1 = Dataset({"simulation": ["run10", "run11"]}) + + expected = {(0,): ds1, (1,): ds0} + actual, concat_dims = _infer_concat_order_from_coords([ds1, ds0]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["simulation"] + + def test_datetime_coords(self): + ds0 = Dataset( + {"time": np.array(["2000-03-06", "2000-03-07"], dtype="datetime64[ns]")} + ) + ds1 = Dataset( + {"time": np.array(["1999-01-01", "1999-02-04"], dtype="datetime64[ns]")} + ) + + expected = {(0,): ds1, (1,): ds0} + actual, concat_dims = _infer_concat_order_from_coords([ds0, ds1]) + assert_combined_tile_ids_equal(expected, actual) + assert concat_dims == ["time"] + + +@pytest.fixture(scope="module") +def create_combined_ids(): + return _create_combined_ids + + +def _create_combined_ids(shape): + tile_ids = _create_tile_ids(shape) + nums = range(len(tile_ids)) + return {tile_id: create_test_data(num) for tile_id, num in zip(tile_ids, nums)} + + +def _create_tile_ids(shape): + tile_ids = product(*(range(i) for i in shape)) + return list(tile_ids) + + +class TestNewTileIDs: + @pytest.mark.parametrize( + "old_id, new_id", + [((3, 0, 1), (0, 1)), ((0, 0), (0,)), ((1,), ()), ((0,), ()), ((1, 0), (0,))], + ) + def test_new_tile_id(self, old_id, new_id): + ds = create_test_data + assert _new_tile_id((old_id, ds)) == new_id + + def test_get_new_tile_ids(self, create_combined_ids): + shape = (1, 2, 3) + combined_ids = create_combined_ids(shape) + + expected_tile_ids = sorted(combined_ids.keys()) + actual_tile_ids = _create_tile_ids(shape) + assert expected_tile_ids == actual_tile_ids + + +class TestCombineND: + @pytest.mark.parametrize("concat_dim", ["dim1", "new_dim"]) + def test_concat_once(self, create_combined_ids, concat_dim): + shape = (2,) + combined_ids = create_combined_ids(shape) + ds = create_test_data + result = _combine_all_along_first_dim( + combined_ids, + dim=concat_dim, + data_vars="all", + coords="different", + compat="no_conflicts", + ) + + expected_ds = concat([ds(0), ds(1)], dim=concat_dim) + assert_combined_tile_ids_equal(result, {(): expected_ds}) + + def test_concat_only_first_dim(self, create_combined_ids): + shape = (2, 3) + combined_ids = create_combined_ids(shape) + result = _combine_all_along_first_dim( + combined_ids, + dim="dim1", + data_vars="all", + coords="different", + compat="no_conflicts", + ) + + ds = create_test_data + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") + expected_datasets = [partway1, partway2, partway3] + expected = {(i,): ds for i, ds in enumerate(expected_datasets)} + + assert_combined_tile_ids_equal(result, expected) + + @pytest.mark.parametrize("concat_dim", ["dim1", "new_dim"]) + def test_concat_twice(self, create_combined_ids, concat_dim): + shape = (2, 3) + combined_ids = create_combined_ids(shape) + result = _combine_nd(combined_ids, concat_dims=["dim1", concat_dim]) + + ds = create_test_data + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") + expected = concat([partway1, partway2, partway3], dim=concat_dim) + + assert_equal(result, expected) + + +class TestCheckShapeTileIDs: + def test_check_depths(self): + ds = create_test_data(0) + combined_tile_ids = {(0,): ds, (0, 1): ds} + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent depths" + ): + _check_shape_tile_ids(combined_tile_ids) + + def test_check_lengths(self): + ds = create_test_data(0) + combined_tile_ids = {(0, 0): ds, (0, 1): ds, (0, 2): ds, (1, 0): ds, (1, 1): ds} + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent lengths" + ): + _check_shape_tile_ids(combined_tile_ids) + + +class TestNestedCombine: + def test_nested_concat(self): + objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] + expected = Dataset({"x": [0, 1]}) + actual = combine_nested(objs, concat_dim="x") + assert_identical(expected, actual) + actual = combine_nested(objs, concat_dim=["x"]) + assert_identical(expected, actual) + + actual = combine_nested([actual], concat_dim=None) + assert_identical(expected, actual) + + actual = combine_nested([actual], concat_dim="x") + assert_identical(expected, actual) + + objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2]})] + actual = combine_nested(objs, concat_dim="x") + expected = Dataset({"x": [0, 1, 2]}) + assert_identical(expected, actual) + + # ensure combine_nested handles non-sorted variables + objs = [ + Dataset({"x": ("a", [0]), "y": ("a", [0])}), + Dataset({"y": ("a", [1]), "x": ("a", [1])}), + ] + actual = combine_nested(objs, concat_dim="a") + expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1])}) + assert_identical(expected, actual) + + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1]})] + actual = combine_nested(objs, concat_dim="x") + expected = Dataset({"x": [0, 1], "y": [0]}) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "join, expected", + [ + ("outer", Dataset({"x": [0, 1], "y": [0, 1]})), + ("inner", Dataset({"x": [0, 1], "y": []})), + ("left", Dataset({"x": [0, 1], "y": [0]})), + ("right", Dataset({"x": [0, 1], "y": [1]})), + ], + ) + def test_combine_nested_join(self, join, expected): + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + actual = combine_nested(objs, concat_dim="x", join=join) + assert_identical(expected, actual) + + def test_combine_nested_join_exact(self): + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + with pytest.raises(ValueError, match=r"cannot align.*join.*exact"): + combine_nested(objs, concat_dim="x", join="exact") + + def test_empty_input(self): + assert_identical(Dataset(), combine_nested([], concat_dim="x")) + + # Fails because of concat's weird treatment of dimension coords, see #2975 + @pytest.mark.xfail + def test_nested_concat_too_many_dims_at_once(self): + objs = [Dataset({"x": [0], "y": [1]}), Dataset({"y": [0], "x": [1]})] + with pytest.raises(ValueError, match="not equal across datasets"): + combine_nested(objs, concat_dim="x", coords="minimal") + + def test_nested_concat_along_new_dim(self): + objs = [ + Dataset({"a": ("x", [10]), "x": [0]}), + Dataset({"a": ("x", [20]), "x": [0]}), + ] + expected = Dataset({"a": (("t", "x"), [[10], [20]]), "x": [0]}) + actual = combine_nested(objs, concat_dim="t") + assert_identical(expected, actual) + + # Same but with a DataArray as new dim, see GH #1988 and #2647 + dim = DataArray([100, 150], name="baz", dims="baz") + expected = Dataset( + {"a": (("baz", "x"), [[10], [20]]), "x": [0], "baz": [100, 150]} + ) + actual = combine_nested(objs, concat_dim=dim) + assert_identical(expected, actual) + + def test_nested_merge(self): + data = Dataset({"x": 0}) + actual = combine_nested([data, data, data], concat_dim=None) + assert_identical(data, actual) + + ds1 = Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = Dataset({"a": ("x", [2, 3]), "x": [1, 2]}) + expected = Dataset({"a": ("x", [1, 2, 3]), "x": [0, 1, 2]}) + actual = combine_nested([ds1, ds2], concat_dim=None) + assert_identical(expected, actual) + actual = combine_nested([ds1, ds2], concat_dim=[None]) + assert_identical(expected, actual) + + tmp1 = Dataset({"x": 0}) + tmp2 = Dataset({"x": np.nan}) + actual = combine_nested([tmp1, tmp2], concat_dim=None) + assert_identical(tmp1, actual) + actual = combine_nested([tmp1, tmp2], concat_dim=[None]) + assert_identical(tmp1, actual) + + # Single object, with a concat_dim explicitly provided + # Test the issue reported in GH #1988 + objs = [Dataset({"x": 0, "y": 1})] + dim = DataArray([100], name="baz", dims="baz") + actual = combine_nested(objs, concat_dim=[dim]) + expected = Dataset({"x": ("baz", [0]), "y": ("baz", [1])}, {"baz": [100]}) + assert_identical(expected, actual) + + # Just making sure that auto_combine is doing what is + # expected for non-scalar values, too. + objs = [Dataset({"x": ("z", [0, 1]), "y": ("z", [1, 2])})] + dim = DataArray([100], name="baz", dims="baz") + actual = combine_nested(objs, concat_dim=[dim]) + expected = Dataset( + {"x": (("baz", "z"), [[0, 1]]), "y": (("baz", "z"), [[1, 2]])}, + {"baz": [100]}, + ) + assert_identical(expected, actual) + + def test_concat_multiple_dims(self): + objs = [ + [Dataset({"a": (("x", "y"), [[0]])}), Dataset({"a": (("x", "y"), [[1]])})], + [Dataset({"a": (("x", "y"), [[2]])}), Dataset({"a": (("x", "y"), [[3]])})], + ] + actual = combine_nested(objs, concat_dim=["x", "y"]) + expected = Dataset({"a": (("x", "y"), [[0, 1], [2, 3]])}) + assert_identical(expected, actual) + + def test_concat_name_symmetry(self): + """Inspired by the discussion on GH issue #2777""" + + da1 = DataArray(name="a", data=[[0]], dims=["x", "y"]) + da2 = DataArray(name="b", data=[[1]], dims=["x", "y"]) + da3 = DataArray(name="a", data=[[2]], dims=["x", "y"]) + da4 = DataArray(name="b", data=[[3]], dims=["x", "y"]) + + x_first = combine_nested([[da1, da2], [da3, da4]], concat_dim=["x", "y"]) + y_first = combine_nested([[da1, da3], [da2, da4]], concat_dim=["y", "x"]) + + assert_identical(x_first, y_first) + + def test_concat_one_dim_merge_another(self): + data = create_test_data(add_attrs=False) + + data1 = data.copy(deep=True) + data2 = data.copy(deep=True) + + objs = [ + [data1.var1.isel(dim2=slice(4)), data2.var1.isel(dim2=slice(4, 9))], + [data1.var2.isel(dim2=slice(4)), data2.var2.isel(dim2=slice(4, 9))], + ] + + expected = data[["var1", "var2"]] + actual = combine_nested(objs, concat_dim=[None, "dim2"]) + assert_identical(expected, actual) + + def test_auto_combine_2d(self): + ds = create_test_data + + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") + expected = concat([partway1, partway2, partway3], dim="dim2") + + datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]] + result = combine_nested(datasets, concat_dim=["dim1", "dim2"]) + assert_equal(result, expected) + + def test_auto_combine_2d_combine_attrs_kwarg(self): + ds = lambda x: create_test_data(x, add_attrs=False) + + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") + expected = concat([partway1, partway2, partway3], dim="dim2") + + expected_dict = {} + expected_dict["drop"] = expected.copy(deep=True) + expected_dict["drop"].attrs = {} + expected_dict["no_conflicts"] = expected.copy(deep=True) + expected_dict["no_conflicts"].attrs = { + "a": 1, + "b": 2, + "c": 3, + "d": 4, + "e": 5, + "f": 6, + } + expected_dict["override"] = expected.copy(deep=True) + expected_dict["override"].attrs = {"a": 1} + f = lambda attrs, context: attrs[0] + expected_dict[f] = expected.copy(deep=True) + expected_dict[f].attrs = f([{"a": 1}], None) + + datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]] + + datasets[0][0].attrs = {"a": 1} + datasets[0][1].attrs = {"a": 1, "b": 2} + datasets[0][2].attrs = {"a": 1, "c": 3} + datasets[1][0].attrs = {"a": 1, "d": 4} + datasets[1][1].attrs = {"a": 1, "e": 5} + datasets[1][2].attrs = {"a": 1, "f": 6} + + with pytest.raises(ValueError, match=r"combine_attrs='identical'"): + result = combine_nested( + datasets, concat_dim=["dim1", "dim2"], combine_attrs="identical" + ) + + for combine_attrs in expected_dict: + result = combine_nested( + datasets, concat_dim=["dim1", "dim2"], combine_attrs=combine_attrs + ) + assert_identical(result, expected_dict[combine_attrs]) + + def test_combine_nested_missing_data_new_dim(self): + # Your data includes "time" and "station" dimensions, and each year's + # data has a different set of stations. + datasets = [ + Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + ] + expected = Dataset( + {"a": (("t", "x"), [[np.nan, 2, 3], [1, 2, np.nan]])}, {"x": [0, 1, 2]} + ) + actual = combine_nested(datasets, concat_dim="t") + assert_identical(expected, actual) + + def test_invalid_hypercube_input(self): + ds = create_test_data + + datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4)]] + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent lengths" + ): + combine_nested(datasets, concat_dim=["dim1", "dim2"]) + + datasets = [[ds(0), ds(1)], [[ds(3), ds(4)]]] + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent depths" + ): + combine_nested(datasets, concat_dim=["dim1", "dim2"]) + + datasets = [[ds(0), ds(1)], [ds(3), ds(4)]] + with pytest.raises(ValueError, match=r"concat_dims has length"): + combine_nested(datasets, concat_dim=["dim1"]) + + def test_merge_one_dim_concat_another(self): + objs = [ + [Dataset({"foo": ("x", [0, 1])}), Dataset({"bar": ("x", [10, 20])})], + [Dataset({"foo": ("x", [2, 3])}), Dataset({"bar": ("x", [30, 40])})], + ] + expected = Dataset({"foo": ("x", [0, 1, 2, 3]), "bar": ("x", [10, 20, 30, 40])}) + + actual = combine_nested(objs, concat_dim=["x", None], compat="equals") + assert_identical(expected, actual) + + # Proving it works symmetrically + objs = [ + [Dataset({"foo": ("x", [0, 1])}), Dataset({"foo": ("x", [2, 3])})], + [Dataset({"bar": ("x", [10, 20])}), Dataset({"bar": ("x", [30, 40])})], + ] + actual = combine_nested(objs, concat_dim=[None, "x"], compat="equals") + assert_identical(expected, actual) + + def test_combine_concat_over_redundant_nesting(self): + objs = [[Dataset({"x": [0]}), Dataset({"x": [1]})]] + actual = combine_nested(objs, concat_dim=[None, "x"]) + expected = Dataset({"x": [0, 1]}) + assert_identical(expected, actual) + + objs = [[Dataset({"x": [0]})], [Dataset({"x": [1]})]] + actual = combine_nested(objs, concat_dim=["x", None]) + expected = Dataset({"x": [0, 1]}) + assert_identical(expected, actual) + + objs = [[Dataset({"x": [0]})]] + actual = combine_nested(objs, concat_dim=[None, None]) + expected = Dataset({"x": [0]}) + assert_identical(expected, actual) + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}]) + def test_combine_nested_fill_value(self, fill_value): + datasets = [ + Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}), + ] + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_a = fill_value_b = np.nan + elif isinstance(fill_value, dict): + fill_value_a = fill_value["a"] + fill_value_b = fill_value["b"] + else: + fill_value_a = fill_value_b = fill_value + expected = Dataset( + { + "a": (("t", "x"), [[fill_value_a, 2, 3], [1, 2, fill_value_a]]), + "b": (("t", "x"), [[fill_value_b, -2, 1], [3, -1, fill_value_b]]), + }, + {"x": [0, 1, 2]}, + ) + actual = combine_nested(datasets, concat_dim="t", fill_value=fill_value) + assert_identical(expected, actual) + + def test_combine_nested_unnamed_data_arrays(self): + unnamed_array = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + + actual = combine_nested([unnamed_array], concat_dim="x") + expected = unnamed_array + assert_identical(expected, actual) + + unnamed_array1 = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + unnamed_array2 = DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + actual = combine_nested([unnamed_array1, unnamed_array2], concat_dim="x") + expected = DataArray( + data=[1.0, 2.0, 3.0, 4.0], coords={"x": [0, 1, 2, 3]}, dims="x" + ) + assert_identical(expected, actual) + + da1 = DataArray(data=[[0.0]], coords={"x": [0], "y": [0]}, dims=["x", "y"]) + da2 = DataArray(data=[[1.0]], coords={"x": [0], "y": [1]}, dims=["x", "y"]) + da3 = DataArray(data=[[2.0]], coords={"x": [1], "y": [0]}, dims=["x", "y"]) + da4 = DataArray(data=[[3.0]], coords={"x": [1], "y": [1]}, dims=["x", "y"]) + objs = [[da1, da2], [da3, da4]] + + expected = DataArray( + data=[[0.0, 1.0], [2.0, 3.0]], + coords={"x": [0, 1], "y": [0, 1]}, + dims=["x", "y"], + ) + actual = combine_nested(objs, concat_dim=["x", "y"]) + assert_identical(expected, actual) + + # TODO aijams - Determine if this test is appropriate. + def test_nested_combine_mixed_datasets_arrays(self): + objs = [ + DataArray([0, 1], dims=("x"), coords=({"x": [0, 1]})), + Dataset({"x": [2, 3]}), + ] + with pytest.raises( + ValueError, match=r"Can't combine datasets with unnamed arrays." + ): + combine_nested(objs, "x") + + +class TestCombineDatasetsbyCoords: + def test_combine_by_coords(self): + objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] + actual = combine_by_coords(objs) + expected = Dataset({"x": [0, 1]}) + assert_identical(expected, actual) + + actual = combine_by_coords([actual]) + assert_identical(expected, actual) + + objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2]})] + actual = combine_by_coords(objs) + expected = Dataset({"x": [0, 1, 2]}) + assert_identical(expected, actual) + + # ensure auto_combine handles non-sorted variables + objs = [ + Dataset({"x": ("a", [0]), "y": ("a", [0]), "a": [0]}), + Dataset({"x": ("a", [1]), "y": ("a", [1]), "a": [1]}), + ] + actual = combine_by_coords(objs) + expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1]), "a": [0, 1]}) + assert_identical(expected, actual) + + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"y": [1], "x": [1]})] + actual = combine_by_coords(objs) + expected = Dataset({"x": [0, 1], "y": [0, 1]}) + assert_equal(actual, expected) + + objs = [Dataset({"x": 0}), Dataset({"x": 1})] + with pytest.raises( + ValueError, match=r"Could not find any dimension coordinates" + ): + combine_by_coords(objs) + + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] + with pytest.raises(ValueError, match=r"Every dimension needs a coordinate"): + combine_by_coords(objs) + + def test_empty_input(self): + assert_identical(Dataset(), combine_by_coords([])) + + @pytest.mark.parametrize( + "join, expected", + [ + ("outer", Dataset({"x": [0, 1], "y": [0, 1]})), + ("inner", Dataset({"x": [0, 1], "y": []})), + ("left", Dataset({"x": [0, 1], "y": [0]})), + ("right", Dataset({"x": [0, 1], "y": [1]})), + ], + ) + def test_combine_coords_join(self, join, expected): + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + actual = combine_nested(objs, concat_dim="x", join=join) + assert_identical(expected, actual) + + def test_combine_coords_join_exact(self): + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): + combine_nested(objs, concat_dim="x", join="exact") + + @pytest.mark.parametrize( + "combine_attrs, expected", + [ + ("drop", Dataset({"x": [0, 1], "y": [0, 1]}, attrs={})), + ( + "no_conflicts", + Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}), + ), + ("override", Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1})), + ( + lambda attrs, context: attrs[1], + Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}), + ), + ], + ) + def test_combine_coords_combine_attrs(self, combine_attrs, expected): + objs = [ + Dataset({"x": [0], "y": [0]}, attrs={"a": 1}), + Dataset({"x": [1], "y": [1]}, attrs={"a": 1, "b": 2}), + ] + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs=combine_attrs + ) + assert_identical(expected, actual) + + if combine_attrs == "no_conflicts": + objs[1].attrs["a"] = 2 + with pytest.raises(ValueError, match=r"combine_attrs='no_conflicts'"): + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs=combine_attrs + ) + + def test_combine_coords_combine_attrs_identical(self): + objs = [ + Dataset({"x": [0], "y": [0]}, attrs={"a": 1}), + Dataset({"x": [1], "y": [1]}, attrs={"a": 1}), + ] + expected = Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1}) + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs="identical" + ) + assert_identical(expected, actual) + + objs[1].attrs["b"] = 2 + + with pytest.raises(ValueError, match=r"combine_attrs='identical'"): + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs="identical" + ) + + def test_combine_nested_combine_attrs_drop_conflicts(self): + objs = [ + Dataset({"x": [0], "y": [0]}, attrs={"a": 1, "b": 2, "c": 3}), + Dataset({"x": [1], "y": [1]}, attrs={"a": 1, "b": 0, "d": 3}), + ] + expected = Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "c": 3, "d": 3}) + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs="drop_conflicts" + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ], + ) + def test_combine_nested_combine_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = Dataset( + { + "a": ("x", [1, 2], attrs1), + "b": ("x", [3, -1], attrs1), + "x": ("x", [0, 1], attrs1), + } + ) + data2 = Dataset( + { + "a": ("x", [2, 3], attrs2), + "b": ("x", [-2, 1], attrs2), + "x": ("x", [2, 3], attrs2), + } + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + combine_by_coords([data1, data2], combine_attrs=combine_attrs) + else: + actual = combine_by_coords([data1, data2], combine_attrs=combine_attrs) + expected = Dataset( + { + "a": ("x", [1, 2, 2, 3], expected_attrs), + "b": ("x", [3, -1, -2, 1], expected_attrs), + }, + {"x": ("x", [0, 1, 2, 3], expected_attrs)}, + ) + + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ], + ) + def test_combine_by_coords_combine_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = Dataset( + {"x": ("a", [0], attrs1), "y": ("a", [0], attrs1), "a": ("a", [0], attrs1)} + ) + data2 = Dataset( + {"x": ("a", [1], attrs2), "y": ("a", [1], attrs2), "a": ("a", [1], attrs2)} + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + combine_by_coords([data1, data2], combine_attrs=combine_attrs) + else: + actual = combine_by_coords([data1, data2], combine_attrs=combine_attrs) + expected = Dataset( + { + "x": ("a", [0, 1], expected_attrs), + "y": ("a", [0, 1], expected_attrs), + "a": ("a", [0, 1], expected_attrs), + } + ) + + assert_identical(actual, expected) + + def test_infer_order_from_coords(self): + data = create_test_data() + objs = [data.isel(dim2=slice(4, 9)), data.isel(dim2=slice(4))] + actual = combine_by_coords(objs) + expected = data + assert expected.broadcast_equals(actual) + + def test_combine_leaving_bystander_dimensions(self): + # Check non-monotonic bystander dimension coord doesn't raise + # ValueError on combine (https://github.com/pydata/xarray/issues/3150) + ycoord = ["a", "c", "b"] + + data = np.random.rand(7, 3) + + ds1 = Dataset( + data_vars=dict(data=(["x", "y"], data[:3, :])), + coords=dict(x=[1, 2, 3], y=ycoord), + ) + + ds2 = Dataset( + data_vars=dict(data=(["x", "y"], data[3:, :])), + coords=dict(x=[4, 5, 6, 7], y=ycoord), + ) + + expected = Dataset( + data_vars=dict(data=(["x", "y"], data)), + coords=dict(x=[1, 2, 3, 4, 5, 6, 7], y=ycoord), + ) + + actual = combine_by_coords((ds1, ds2)) + assert_identical(expected, actual) + + def test_combine_by_coords_previously_failed(self): + # In the above scenario, one file is missing, containing the data for + # one year's data for one variable. + datasets = [ + Dataset({"a": ("x", [0]), "x": [0]}), + Dataset({"b": ("x", [0]), "x": [0]}), + Dataset({"a": ("x", [1]), "x": [1]}), + ] + expected = Dataset({"a": ("x", [0, 1]), "b": ("x", [0, np.nan])}, {"x": [0, 1]}) + actual = combine_by_coords(datasets) + assert_identical(expected, actual) + + def test_combine_by_coords_still_fails(self): + # concat can't handle new variables (yet): + # https://github.com/pydata/xarray/issues/508 + datasets = [Dataset({"x": 0}, {"y": 0}), Dataset({"x": 1}, {"y": 1, "z": 1})] + with pytest.raises(ValueError): + combine_by_coords(datasets, "y") + + def test_combine_by_coords_no_concat(self): + objs = [Dataset({"x": 0}), Dataset({"y": 1})] + actual = combine_by_coords(objs) + expected = Dataset({"x": 0, "y": 1}) + assert_identical(expected, actual) + + objs = [Dataset({"x": 0, "y": 1}), Dataset({"y": np.nan, "z": 2})] + actual = combine_by_coords(objs) + expected = Dataset({"x": 0, "y": 1, "z": 2}) + assert_identical(expected, actual) + + def test_check_for_impossible_ordering(self): + ds0 = Dataset({"x": [0, 1, 5]}) + ds1 = Dataset({"x": [2, 3]}) + with pytest.raises( + ValueError, + match=r"does not have monotonic global indexes along dimension x", + ): + combine_by_coords([ds1, ds0]) + + def test_combine_by_coords_incomplete_hypercube(self): + # test that this succeeds with default fill_value + x1 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [0], "x": [0]}) + x2 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [1], "x": [0]}) + x3 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [0], "x": [1]}) + actual = combine_by_coords([x1, x2, x3]) + expected = Dataset( + {"a": (("y", "x"), [[1, 1], [1, np.nan]])}, + coords={"y": [0, 1], "x": [0, 1]}, + ) + assert_identical(expected, actual) + + # test that this fails if fill_value is None + with pytest.raises(ValueError): + combine_by_coords([x1, x2, x3], fill_value=None) + + +class TestCombineMixedObjectsbyCoords: + def test_combine_by_coords_mixed_unnamed_dataarrays(self): + named_da = DataArray(name="a", data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + unnamed_da = DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + with pytest.raises( + ValueError, match="Can't automatically combine unnamed DataArrays with" + ): + combine_by_coords([named_da, unnamed_da]) + + da = DataArray([0, 1], dims="x", coords=({"x": [0, 1]})) + ds = Dataset({"x": [2, 3]}) + with pytest.raises( + ValueError, + match="Can't automatically combine unnamed DataArrays with", + ): + combine_by_coords([da, ds]) + + def test_combine_coords_mixed_datasets_named_dataarrays(self): + da = DataArray(name="a", data=[4, 5], dims="x", coords=({"x": [0, 1]})) + ds = Dataset({"b": ("x", [2, 3])}) + actual = combine_by_coords([da, ds]) + expected = Dataset( + {"a": ("x", [4, 5]), "b": ("x", [2, 3])}, coords={"x": ("x", [0, 1])} + ) + assert_identical(expected, actual) + + def test_combine_by_coords_all_unnamed_dataarrays(self): + unnamed_array = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + + actual = combine_by_coords([unnamed_array]) + expected = unnamed_array + assert_identical(expected, actual) + + unnamed_array1 = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + unnamed_array2 = DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + actual = combine_by_coords([unnamed_array1, unnamed_array2]) + expected = DataArray( + data=[1.0, 2.0, 3.0, 4.0], coords={"x": [0, 1, 2, 3]}, dims="x" + ) + assert_identical(expected, actual) + + def test_combine_by_coords_all_named_dataarrays(self): + named_da = DataArray(name="a", data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + + actual = combine_by_coords([named_da]) + expected = named_da.to_dataset() + assert_identical(expected, actual) + + named_da1 = DataArray(name="a", data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + named_da2 = DataArray(name="b", data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + actual = combine_by_coords([named_da1, named_da2]) + expected = Dataset( + { + "a": DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x"), + "b": DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x"), + } + ) + assert_identical(expected, actual) + + def test_combine_by_coords_all_dataarrays_with_the_same_name(self): + named_da1 = DataArray(name="a", data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + named_da2 = DataArray(name="a", data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + actual = combine_by_coords([named_da1, named_da2]) + expected = merge([named_da1, named_da2]) + assert_identical(expected, actual) + + +@requires_cftime +def test_combine_by_coords_distant_cftime_dates(): + # Regression test for https://github.com/pydata/xarray/issues/3535 + import cftime + + time_1 = [cftime.DatetimeGregorian(4500, 12, 31)] + time_2 = [cftime.DatetimeGregorian(4600, 12, 31)] + time_3 = [cftime.DatetimeGregorian(5100, 12, 31)] + + da_1 = DataArray([0], dims=["time"], coords=[time_1], name="a").to_dataset() + da_2 = DataArray([1], dims=["time"], coords=[time_2], name="a").to_dataset() + da_3 = DataArray([2], dims=["time"], coords=[time_3], name="a").to_dataset() + + result = combine_by_coords([da_1, da_2, da_3]) + + expected_time = np.concatenate([time_1, time_2, time_3]) + expected = DataArray( + [0, 1, 2], dims=["time"], coords=[expected_time], name="a" + ).to_dataset() + assert_identical(result, expected) + + +@requires_cftime +def test_combine_by_coords_raises_for_differing_calendars(): + # previously failed with uninformative StopIteration instead of TypeError + # https://github.com/pydata/xarray/issues/4495 + + import cftime + + time_1 = [cftime.DatetimeGregorian(2000, 1, 1)] + time_2 = [cftime.DatetimeProlepticGregorian(2001, 1, 1)] + + da_1 = DataArray([0], dims=["time"], coords=[time_1], name="a").to_dataset() + da_2 = DataArray([1], dims=["time"], coords=[time_2], name="a").to_dataset() + + error_msg = ( + "Cannot combine along dimension 'time' with mixed types." + " Found:.*" + " If importing data directly from a file then setting" + " `use_cftime=True` may fix this issue." + ) + with pytest.raises(TypeError, match=error_msg): + combine_by_coords([da_1, da_2]) + + +def test_combine_by_coords_raises_for_differing_types(): + # str and byte cannot be compared + da_1 = DataArray([0], dims=["time"], coords=[["a"]], name="a").to_dataset() + da_2 = DataArray([1], dims=["time"], coords=[[b"b"]], name="a").to_dataset() + + with pytest.raises( + TypeError, match=r"Cannot combine along dimension 'time' with mixed types." + ): + combine_by_coords([da_1, da_2]) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_computation.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_computation.py new file mode 100644 index 0000000..080447c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_computation.py @@ -0,0 +1,2600 @@ +from __future__ import annotations + +import functools +import operator +import pickle + +import numpy as np +import pandas as pd +import pytest +from numpy.testing import assert_allclose, assert_array_equal + +import xarray as xr +from xarray.core.alignment import broadcast +from xarray.core.computation import ( + _UFuncSignature, + apply_ufunc, + broadcast_compat_data, + collect_dict_values, + join_dict_keys, + ordered_set_intersection, + ordered_set_union, + result_name, + unified_dim_sizes, +) +from xarray.tests import ( + has_dask, + raise_if_dask_computes, + requires_cftime, + requires_dask, +) + + +def assert_identical(a, b): + """A version of this function which accepts numpy arrays""" + __tracebackhide__ = True + from xarray.testing import assert_identical as assert_identical_ + + if hasattr(a, "identical"): + assert_identical_(a, b) + else: + assert_array_equal(a, b) + + +def test_signature_properties() -> None: + sig = _UFuncSignature([["x"], ["x", "y"]], [["z"]]) + assert sig.input_core_dims == (("x",), ("x", "y")) + assert sig.output_core_dims == (("z",),) + assert sig.all_input_core_dims == frozenset(["x", "y"]) + assert sig.all_output_core_dims == frozenset(["z"]) + assert sig.num_inputs == 2 + assert sig.num_outputs == 1 + assert str(sig) == "(x),(x,y)->(z)" + assert sig.to_gufunc_string() == "(dim0),(dim0,dim1)->(dim2)" + assert ( + sig.to_gufunc_string(exclude_dims=set("x")) == "(dim0_0),(dim0_1,dim1)->(dim2)" + ) + # dimension names matter + assert _UFuncSignature([["x"]]) != _UFuncSignature([["y"]]) + + +def test_result_name() -> None: + class Named: + def __init__(self, name=None): + self.name = name + + assert result_name([1, 2]) is None + assert result_name([Named()]) is None + assert result_name([Named("foo"), 2]) == "foo" + assert result_name([Named("foo"), Named("bar")]) is None + assert result_name([Named("foo"), Named()]) is None + + +def test_ordered_set_union() -> None: + assert list(ordered_set_union([[1, 2]])) == [1, 2] + assert list(ordered_set_union([[1, 2], [2, 1]])) == [1, 2] + assert list(ordered_set_union([[0], [1, 2], [1, 3]])) == [0, 1, 2, 3] + + +def test_ordered_set_intersection() -> None: + assert list(ordered_set_intersection([[1, 2]])) == [1, 2] + assert list(ordered_set_intersection([[1, 2], [2, 1]])) == [1, 2] + assert list(ordered_set_intersection([[1, 2], [1, 3]])) == [1] + assert list(ordered_set_intersection([[1, 2], [2]])) == [2] + + +def test_join_dict_keys() -> None: + dicts = [dict.fromkeys(keys) for keys in [["x", "y"], ["y", "z"]]] + assert list(join_dict_keys(dicts, "left")) == ["x", "y"] + assert list(join_dict_keys(dicts, "right")) == ["y", "z"] + assert list(join_dict_keys(dicts, "inner")) == ["y"] + assert list(join_dict_keys(dicts, "outer")) == ["x", "y", "z"] + with pytest.raises(ValueError): + join_dict_keys(dicts, "exact") + with pytest.raises(KeyError): + join_dict_keys(dicts, "foobar") + + +def test_collect_dict_values() -> None: + dicts = [{"x": 1, "y": 2, "z": 3}, {"z": 4}, 5] + expected = [[1, 0, 5], [2, 0, 5], [3, 4, 5]] + collected = collect_dict_values(dicts, ["x", "y", "z"], fill_value=0) + assert collected == expected + + +def identity(x): + return x + + +def test_apply_identity() -> None: + array = np.arange(10) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) + + apply_identity = functools.partial(apply_ufunc, identity) + + assert_identical(array, apply_identity(array)) + assert_identical(variable, apply_identity(variable)) + assert_identical(data_array, apply_identity(data_array)) + assert_identical(data_array, apply_identity(data_array.groupby("x"))) + assert_identical(data_array, apply_identity(data_array.groupby("x", squeeze=False))) + assert_identical(dataset, apply_identity(dataset)) + assert_identical(dataset, apply_identity(dataset.groupby("x"))) + assert_identical(dataset, apply_identity(dataset.groupby("x", squeeze=False))) + + +def add(a, b): + return apply_ufunc(operator.add, a, b) + + +def test_apply_two_inputs() -> None: + array = np.array([1, 2, 3]) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) + + zero_array = np.zeros_like(array) + zero_variable = xr.Variable("x", zero_array) + zero_data_array = xr.DataArray(zero_variable, [("x", -array)]) + zero_dataset = xr.Dataset({"y": zero_variable}, {"x": -array}) + + assert_identical(array, add(array, zero_array)) + assert_identical(array, add(zero_array, array)) + + assert_identical(variable, add(variable, zero_array)) + assert_identical(variable, add(variable, zero_variable)) + assert_identical(variable, add(zero_array, variable)) + assert_identical(variable, add(zero_variable, variable)) + + assert_identical(data_array, add(data_array, zero_array)) + assert_identical(data_array, add(data_array, zero_variable)) + assert_identical(data_array, add(data_array, zero_data_array)) + assert_identical(data_array, add(zero_array, data_array)) + assert_identical(data_array, add(zero_variable, data_array)) + assert_identical(data_array, add(zero_data_array, data_array)) + + assert_identical(dataset, add(dataset, zero_array)) + assert_identical(dataset, add(dataset, zero_variable)) + assert_identical(dataset, add(dataset, zero_data_array)) + assert_identical(dataset, add(dataset, zero_dataset)) + assert_identical(dataset, add(zero_array, dataset)) + assert_identical(dataset, add(zero_variable, dataset)) + assert_identical(dataset, add(zero_data_array, dataset)) + assert_identical(dataset, add(zero_dataset, dataset)) + + assert_identical(data_array, add(data_array.groupby("x"), zero_data_array)) + assert_identical(data_array, add(zero_data_array, data_array.groupby("x"))) + + assert_identical(dataset, add(data_array.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_dataset, data_array.groupby("x"))) + + assert_identical(dataset, add(dataset.groupby("x"), zero_data_array)) + assert_identical(dataset, add(dataset.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_data_array, dataset.groupby("x"))) + assert_identical(dataset, add(zero_dataset, dataset.groupby("x"))) + + +def test_apply_1d_and_0d() -> None: + array = np.array([1, 2, 3]) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) + + zero_array = 0 + zero_variable = xr.Variable((), zero_array) + zero_data_array = xr.DataArray(zero_variable) + zero_dataset = xr.Dataset({"y": zero_variable}) + + assert_identical(array, add(array, zero_array)) + assert_identical(array, add(zero_array, array)) + + assert_identical(variable, add(variable, zero_array)) + assert_identical(variable, add(variable, zero_variable)) + assert_identical(variable, add(zero_array, variable)) + assert_identical(variable, add(zero_variable, variable)) + + assert_identical(data_array, add(data_array, zero_array)) + assert_identical(data_array, add(data_array, zero_variable)) + assert_identical(data_array, add(data_array, zero_data_array)) + assert_identical(data_array, add(zero_array, data_array)) + assert_identical(data_array, add(zero_variable, data_array)) + assert_identical(data_array, add(zero_data_array, data_array)) + + assert_identical(dataset, add(dataset, zero_array)) + assert_identical(dataset, add(dataset, zero_variable)) + assert_identical(dataset, add(dataset, zero_data_array)) + assert_identical(dataset, add(dataset, zero_dataset)) + assert_identical(dataset, add(zero_array, dataset)) + assert_identical(dataset, add(zero_variable, dataset)) + assert_identical(dataset, add(zero_data_array, dataset)) + assert_identical(dataset, add(zero_dataset, dataset)) + + assert_identical(data_array, add(data_array.groupby("x"), zero_data_array)) + assert_identical(data_array, add(zero_data_array, data_array.groupby("x"))) + + assert_identical(dataset, add(data_array.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_dataset, data_array.groupby("x"))) + + assert_identical(dataset, add(dataset.groupby("x"), zero_data_array)) + assert_identical(dataset, add(dataset.groupby("x"), zero_dataset)) + assert_identical(dataset, add(zero_data_array, dataset.groupby("x"))) + assert_identical(dataset, add(zero_dataset, dataset.groupby("x"))) + + +def test_apply_two_outputs() -> None: + array = np.arange(5) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) + + def twice(obj): + def func(x): + return (x, x) + + return apply_ufunc(func, obj, output_core_dims=[[], []]) + + out0, out1 = twice(array) + assert_identical(out0, array) + assert_identical(out1, array) + + out0, out1 = twice(variable) + assert_identical(out0, variable) + assert_identical(out1, variable) + + out0, out1 = twice(data_array) + assert_identical(out0, data_array) + assert_identical(out1, data_array) + + out0, out1 = twice(dataset) + assert_identical(out0, dataset) + assert_identical(out1, dataset) + + out0, out1 = twice(data_array.groupby("x")) + assert_identical(out0, data_array) + assert_identical(out1, data_array) + + out0, out1 = twice(dataset.groupby("x")) + assert_identical(out0, dataset) + assert_identical(out1, dataset) + + +def test_apply_missing_dims() -> None: + ## Single arg + + def add_one(a, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda x: x + 1, + a, + input_core_dims=core_dims, + output_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + array = np.arange(6).reshape(2, 3) + variable = xr.Variable(["x", "y"], array) + variable_no_y = xr.Variable(["x", "z"], array) + + ds = xr.Dataset({"x_y": variable, "x_z": variable_no_y}) + + # Check the standard stuff works OK + assert_identical( + add_one(ds[["x_y"]], core_dims=[["y"]], on_missing_core_dim="raise"), + ds[["x_y"]] + 1, + ) + + # `raise` — should raise on a missing dim + with pytest.raises(ValueError): + add_one(ds, core_dims=[["y"]], on_missing_core_dim="raise") + + # `drop` — should drop the var with the missing dim + assert_identical( + add_one(ds, core_dims=[["y"]], on_missing_core_dim="drop"), + (ds + 1).drop_vars("x_z"), + ) + + # `copy` — should not add one to the missing with `copy` + copy_result = add_one(ds, core_dims=[["y"]], on_missing_core_dim="copy") + assert_identical(copy_result["x_y"], (ds + 1)["x_y"]) + assert_identical(copy_result["x_z"], ds["x_z"]) + + ## Multiple args + + def sum_add(a, b, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda a, b, axis=None: a.sum(axis) + b.sum(axis), + a, + b, + input_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + # Check the standard stuff works OK + assert_identical( + sum_add( + ds[["x_y"]], + ds[["x_y"]], + core_dims=[["x", "y"], ["x", "y"]], + on_missing_core_dim="raise", + ), + ds[["x_y"]].sum() * 2, + ) + + # `raise` — should raise on a missing dim + with pytest.raises( + ValueError, + match=r".*Missing core dims \{'y'\} from arg number 1 on a variable named `x_z`:\n.* None: + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + + def twice(obj): + def func(x): + return (x, x) + + return apply_ufunc(func, obj, output_core_dims=[[], []], dask="parallelized") + + out0, out1 = twice(data_array.chunk({"x": 1})) + assert_identical(data_array, out0) + assert_identical(data_array, out1) + + +def test_apply_input_core_dimension() -> None: + def first_element(obj, dim): + def func(x): + return x[..., 0] + + return apply_ufunc(func, obj, input_core_dims=[[dim]]) + + array = np.array([[1, 2], [3, 4]]) + variable = xr.Variable(["x", "y"], array) + data_array = xr.DataArray(variable, {"x": ["a", "b"], "y": [-1, -2]}) + dataset = xr.Dataset({"data": data_array}) + + expected_variable_x = xr.Variable(["y"], [1, 2]) + expected_data_array_x = xr.DataArray(expected_variable_x, {"y": [-1, -2]}) + expected_dataset_x = xr.Dataset({"data": expected_data_array_x}) + + expected_variable_y = xr.Variable(["x"], [1, 3]) + expected_data_array_y = xr.DataArray(expected_variable_y, {"x": ["a", "b"]}) + expected_dataset_y = xr.Dataset({"data": expected_data_array_y}) + + assert_identical(expected_variable_x, first_element(variable, "x")) + assert_identical(expected_variable_y, first_element(variable, "y")) + + assert_identical(expected_data_array_x, first_element(data_array, "x")) + assert_identical(expected_data_array_y, first_element(data_array, "y")) + + assert_identical(expected_dataset_x, first_element(dataset, "x")) + assert_identical(expected_dataset_y, first_element(dataset, "y")) + + assert_identical(expected_data_array_x, first_element(data_array.groupby("y"), "x")) + assert_identical(expected_dataset_x, first_element(dataset.groupby("y"), "x")) + + def multiply(*args): + val = args[0] + for arg in args[1:]: + val = val * arg + return val + + # regression test for GH:2341 + with pytest.raises(ValueError): + apply_ufunc( + multiply, + data_array, + data_array["y"].values, + input_core_dims=[["y"]], + output_core_dims=[["y"]], + ) + expected = xr.DataArray( + multiply(data_array, data_array["y"]), dims=["x", "y"], coords=data_array.coords + ) + actual = apply_ufunc( + multiply, + data_array, + data_array["y"].values, + input_core_dims=[["y"], []], + output_core_dims=[["y"]], + ) + assert_identical(expected, actual) + + +def test_apply_output_core_dimension() -> None: + def stack_negative(obj): + def func(x): + return np.stack([x, -x], axis=-1) + + result = apply_ufunc(func, obj, output_core_dims=[["sign"]]) + if isinstance(result, (xr.Dataset, xr.DataArray)): + result.coords["sign"] = [1, -1] + return result + + array = np.array([[1, 2], [3, 4]]) + variable = xr.Variable(["x", "y"], array) + data_array = xr.DataArray(variable, {"x": ["a", "b"], "y": [-1, -2]}) + dataset = xr.Dataset({"data": data_array}) + + stacked_array = np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]) + stacked_variable = xr.Variable(["x", "y", "sign"], stacked_array) + stacked_coords = {"x": ["a", "b"], "y": [-1, -2], "sign": [1, -1]} + stacked_data_array = xr.DataArray(stacked_variable, stacked_coords) + stacked_dataset = xr.Dataset({"data": stacked_data_array}) + + assert_identical(stacked_array, stack_negative(array)) + assert_identical(stacked_variable, stack_negative(variable)) + assert_identical(stacked_data_array, stack_negative(data_array)) + assert_identical(stacked_dataset, stack_negative(dataset)) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(stacked_data_array, stack_negative(data_array.groupby("x"))) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(stacked_dataset, stack_negative(dataset.groupby("x"))) + + def original_and_stack_negative(obj): + def func(x): + return (x, np.stack([x, -x], axis=-1)) + + result = apply_ufunc(func, obj, output_core_dims=[[], ["sign"]]) + if isinstance(result[1], (xr.Dataset, xr.DataArray)): + result[1].coords["sign"] = [1, -1] + return result + + out0, out1 = original_and_stack_negative(array) + assert_identical(array, out0) + assert_identical(stacked_array, out1) + + out0, out1 = original_and_stack_negative(variable) + assert_identical(variable, out0) + assert_identical(stacked_variable, out1) + + out0, out1 = original_and_stack_negative(data_array) + assert_identical(data_array, out0) + assert_identical(stacked_data_array, out1) + + out0, out1 = original_and_stack_negative(dataset) + assert_identical(dataset, out0) + assert_identical(stacked_dataset, out1) + + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + out0, out1 = original_and_stack_negative(data_array.groupby("x")) + assert_identical(data_array, out0) + assert_identical(stacked_data_array, out1) + + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + out0, out1 = original_and_stack_negative(dataset.groupby("x")) + assert_identical(dataset, out0) + assert_identical(stacked_dataset, out1) + + +def test_apply_exclude() -> None: + def concatenate(objects, dim="x"): + def func(*x): + return np.concatenate(x, axis=-1) + + result = apply_ufunc( + func, + *objects, + input_core_dims=[[dim]] * len(objects), + output_core_dims=[[dim]], + exclude_dims={dim}, + ) + if isinstance(result, (xr.Dataset, xr.DataArray)): + # note: this will fail if dim is not a coordinate on any input + new_coord = np.concatenate([obj.coords[dim] for obj in objects]) + result.coords[dim] = new_coord + return result + + arrays = [np.array([1]), np.array([2, 3])] + variables = [xr.Variable("x", a) for a in arrays] + data_arrays = [ + xr.DataArray(v, {"x": c, "y": ("x", range(len(c)))}) + for v, c in zip(variables, [["a"], ["b", "c"]]) + ] + datasets = [xr.Dataset({"data": data_array}) for data_array in data_arrays] + + expected_array = np.array([1, 2, 3]) + expected_variable = xr.Variable("x", expected_array) + expected_data_array = xr.DataArray(expected_variable, [("x", list("abc"))]) + expected_dataset = xr.Dataset({"data": expected_data_array}) + + assert_identical(expected_array, concatenate(arrays)) + assert_identical(expected_variable, concatenate(variables)) + assert_identical(expected_data_array, concatenate(data_arrays)) + assert_identical(expected_dataset, concatenate(datasets)) + + # must also be a core dimension + with pytest.raises(ValueError): + apply_ufunc(identity, variables[0], exclude_dims={"x"}) + + +def test_apply_groupby_add() -> None: + array = np.arange(5) + variable = xr.Variable("x", array) + coords = {"x": -array, "y": ("x", [0, 0, 1, 1, 2])} + data_array = xr.DataArray(variable, coords, dims="x") + dataset = xr.Dataset({"z": variable}, coords) + + other_variable = xr.Variable("y", [0, 10]) + other_data_array = xr.DataArray(other_variable, dims="y") + other_dataset = xr.Dataset({"z": other_variable}) + + expected_variable = xr.Variable("x", [0, 1, 12, 13, np.nan]) + expected_data_array = xr.DataArray(expected_variable, coords, dims="x") + expected_dataset = xr.Dataset({"z": expected_variable}, coords) + + assert_identical( + expected_data_array, add(data_array.groupby("y"), other_data_array) + ) + assert_identical(expected_dataset, add(data_array.groupby("y"), other_dataset)) + assert_identical(expected_dataset, add(dataset.groupby("y"), other_data_array)) + assert_identical(expected_dataset, add(dataset.groupby("y"), other_dataset)) + + # cannot be performed with xarray.Variable objects that share a dimension + with pytest.raises(ValueError): + add(data_array.groupby("y"), other_variable) + + # if they are all grouped the same way + with pytest.raises(ValueError): + add(data_array.groupby("y"), data_array[:4].groupby("y")) + with pytest.raises(ValueError): + add(data_array.groupby("y"), data_array[1:].groupby("y")) + with pytest.raises(ValueError): + add(data_array.groupby("y"), other_data_array.groupby("y")) + with pytest.raises(ValueError): + add(data_array.groupby("y"), data_array.groupby("x")) + + +def test_unified_dim_sizes() -> None: + assert unified_dim_sizes([xr.Variable((), 0)]) == {} + assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1])]) == {"x": 1} + assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("y", [1, 2])]) == { + "x": 1, + "y": 2, + } + assert unified_dim_sizes( + [xr.Variable(("x", "z"), [[1]]), xr.Variable(("y", "z"), [[1, 2], [3, 4]])], + exclude_dims={"z"}, + ) == {"x": 1, "y": 2} + + # duplicate dimensions + with pytest.raises(ValueError): + unified_dim_sizes([xr.Variable(("x", "x"), [[1]])]) + + # mismatched lengths + with pytest.raises(ValueError): + unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1, 2])]) + + +def test_broadcast_compat_data_1d() -> None: + data = np.arange(5) + var = xr.Variable("x", data) + + assert_identical(data, broadcast_compat_data(var, ("x",), ())) + assert_identical(data, broadcast_compat_data(var, (), ("x",))) + assert_identical(data[:], broadcast_compat_data(var, ("w",), ("x",))) + assert_identical(data[:, None], broadcast_compat_data(var, ("w", "x", "y"), ())) + + with pytest.raises(ValueError): + broadcast_compat_data(var, ("x",), ("w",)) + + with pytest.raises(ValueError): + broadcast_compat_data(var, (), ()) + + +def test_broadcast_compat_data_2d() -> None: + data = np.arange(12).reshape(3, 4) + var = xr.Variable(["x", "y"], data) + + assert_identical(data, broadcast_compat_data(var, ("x", "y"), ())) + assert_identical(data, broadcast_compat_data(var, ("x",), ("y",))) + assert_identical(data, broadcast_compat_data(var, (), ("x", "y"))) + assert_identical(data.T, broadcast_compat_data(var, ("y", "x"), ())) + assert_identical(data.T, broadcast_compat_data(var, ("y",), ("x",))) + assert_identical(data, broadcast_compat_data(var, ("w", "x"), ("y",))) + assert_identical(data, broadcast_compat_data(var, ("w",), ("x", "y"))) + assert_identical(data.T, broadcast_compat_data(var, ("w",), ("y", "x"))) + assert_identical( + data[:, :, None], broadcast_compat_data(var, ("w", "x", "y", "z"), ()) + ) + assert_identical( + data[None, :, :].T, broadcast_compat_data(var, ("w", "y", "x", "z"), ()) + ) + + +def test_keep_attrs() -> None: + def add(a, b, keep_attrs): + if keep_attrs: + return apply_ufunc(operator.add, a, b, keep_attrs=keep_attrs) + else: + return apply_ufunc(operator.add, a, b) + + a = xr.DataArray([0, 1], [("x", [0, 1])]) + a.attrs["attr"] = "da" + a["x"].attrs["attr"] = "da_coord" + b = xr.DataArray([1, 2], [("x", [0, 1])]) + + actual = add(a, b, keep_attrs=False) + assert not actual.attrs + actual = add(a, b, keep_attrs=True) + assert_identical(actual.attrs, a.attrs) + assert_identical(actual["x"].attrs, a["x"].attrs) + + actual = add(a.variable, b.variable, keep_attrs=False) + assert not actual.attrs + actual = add(a.variable, b.variable, keep_attrs=True) + assert_identical(actual.attrs, a.attrs) + + ds_a = xr.Dataset({"x": [0, 1]}) + ds_a.attrs["attr"] = "ds" + ds_a.x.attrs["attr"] = "da" + ds_b = xr.Dataset({"x": [0, 1]}) + + actual = add(ds_a, ds_b, keep_attrs=False) + assert not actual.attrs + actual = add(ds_a, ds_b, keep_attrs=True) + assert_identical(actual.attrs, ds_a.attrs) + assert_identical(actual.x.attrs, ds_a.x.attrs) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_variable(strategy, attrs, expected, error) -> None: + a = xr.Variable("x", [0, 1], attrs=attrs[0]) + b = xr.Variable("x", [0, 1], attrs=attrs[1]) + c = xr.Variable("x", [0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Variable("x", [0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray(strategy, attrs, expected, error) -> None: + a = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[0]) + b = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[1]) + c = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.DataArray(dims="x", data=[0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "dim": lambda attrs, default: (attrs, default), + "coord": lambda attrs, default: (default, attrs), + }.get(variant) + + dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[0]), "u": ("x", [0, 1], coord_attrs[0])}, + ) + b = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[1]), "u": ("x", [0, 1], coord_attrs[1])}, + ) + c = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[2]), "u": ("x", [0, 1], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.DataArray( + dims="x", + data=[0, 3], + coords={"x": ("x", [0, 1], dim_attrs), "u": ("x", [0, 1], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset(strategy, attrs, expected, error) -> None: + a = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[0]) + b = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[1]) + c = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Dataset({"a": ("x", [0, 3])}, attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("data", "dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "data": lambda attrs, default: (attrs, default, default), + "dim": lambda attrs, default: (default, attrs, default), + "coord": lambda attrs, default: (default, default, attrs), + }.get(variant) + data_attrs, dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.Dataset( + {"a": ("x", [], data_attrs[0])}, + coords={"x": ("x", [], dim_attrs[0]), "u": ("x", [], coord_attrs[0])}, + ) + b = xr.Dataset( + {"a": ("x", [], data_attrs[1])}, + coords={"x": ("x", [], dim_attrs[1]), "u": ("x", [], coord_attrs[1])}, + ) + c = xr.Dataset( + {"a": ("x", [], data_attrs[2])}, + coords={"x": ("x", [], dim_attrs[2]), "u": ("x", [], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + data_attrs, dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.Dataset( + {"a": ("x", [], data_attrs)}, + coords={"x": ("x", [], dim_attrs), "u": ("x", [], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +def test_dataset_join() -> None: + ds0 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) + + # by default, cannot have different labels + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): + apply_ufunc(operator.add, ds0, ds1) + with pytest.raises(TypeError, match=r"must supply"): + apply_ufunc(operator.add, ds0, ds1, dataset_join="outer") + + def add(a, b, join, dataset_join): + return apply_ufunc( + operator.add, + a, + b, + join=join, + dataset_join=dataset_join, + dataset_fill_value=np.nan, + ) + + actual = add(ds0, ds1, "outer", "inner") + expected = xr.Dataset({"a": ("x", [np.nan, 101, np.nan]), "x": [0, 1, 2]}) + assert_identical(actual, expected) + + actual = add(ds0, ds1, "outer", "outer") + assert_identical(actual, expected) + + with pytest.raises(ValueError, match=r"data variable names"): + apply_ufunc(operator.add, ds0, xr.Dataset({"b": 1})) + + ds2 = xr.Dataset({"b": ("x", [99, 3]), "x": [1, 2]}) + actual = add(ds0, ds2, "outer", "inner") + expected = xr.Dataset({"x": [0, 1, 2]}) + assert_identical(actual, expected) + + # we used np.nan as the fill_value in add() above + actual = add(ds0, ds2, "outer", "outer") + expected = xr.Dataset( + { + "a": ("x", [np.nan, np.nan, np.nan]), + "b": ("x", [np.nan, np.nan, np.nan]), + "x": [0, 1, 2], + } + ) + assert_identical(actual, expected) + + +@requires_dask +def test_apply_dask() -> None: + import dask.array as da + + array = da.ones((2,), chunks=2) + variable = xr.Variable("x", array) + coords = xr.DataArray(variable).coords.variables + data_array = xr.DataArray(variable, dims=["x"], coords=coords) + dataset = xr.Dataset({"y": variable}) + + # encountered dask array, but did not set dask='allowed' + with pytest.raises(ValueError): + apply_ufunc(identity, array) + with pytest.raises(ValueError): + apply_ufunc(identity, variable) + with pytest.raises(ValueError): + apply_ufunc(identity, data_array) + with pytest.raises(ValueError): + apply_ufunc(identity, dataset) + + # unknown setting for dask array handling + with pytest.raises(ValueError): + apply_ufunc(identity, array, dask="unknown") # type: ignore + + def dask_safe_identity(x): + return apply_ufunc(identity, x, dask="allowed") + + assert array is dask_safe_identity(array) + + actual = dask_safe_identity(variable) + assert isinstance(actual.data, da.Array) + assert_identical(variable, actual) + + actual = dask_safe_identity(data_array) + assert isinstance(actual.data, da.Array) + assert_identical(data_array, actual) + + actual = dask_safe_identity(dataset) + assert isinstance(actual["y"].data, da.Array) + assert_identical(dataset, actual) + + +@requires_dask +def test_apply_dask_parallelized_one_arg() -> None: + import dask.array as da + + array = da.ones((2, 2), chunks=(1, 1)) + data_array = xr.DataArray(array, dims=("x", "y")) + + def parallel_identity(x): + return apply_ufunc(identity, x, dask="parallelized", output_dtypes=[x.dtype]) + + actual = parallel_identity(data_array) + assert isinstance(actual.data, da.Array) + assert actual.data.chunks == array.chunks + assert_identical(data_array, actual) + + computed = data_array.compute() + actual = parallel_identity(computed) + assert_identical(computed, actual) + + +@requires_dask +def test_apply_dask_parallelized_two_args() -> None: + import dask.array as da + + array = da.ones((2, 2), chunks=(1, 1), dtype=np.int64) + data_array = xr.DataArray(array, dims=("x", "y")) + data_array.name = None + + def parallel_add(x, y): + return apply_ufunc( + operator.add, x, y, dask="parallelized", output_dtypes=[np.int64] + ) + + def check(x, y): + actual = parallel_add(x, y) + assert isinstance(actual.data, da.Array) + assert actual.data.chunks == array.chunks + assert_identical(data_array, actual) + + check(data_array, 0) + check(0, data_array) + check(data_array, xr.DataArray(0)) + check(data_array, 0 * data_array) + check(data_array, 0 * data_array[0]) + check(data_array[:, 0], 0 * data_array[0]) + check(data_array, 0 * data_array.compute()) + + +@requires_dask +def test_apply_dask_parallelized_errors() -> None: + import dask.array as da + + array = da.ones((2, 2), chunks=(1, 1)) + data_array = xr.DataArray(array, dims=("x", "y")) + + # from apply_array_ufunc + with pytest.raises(ValueError, match=r"at least one input is an xarray object"): + apply_ufunc(identity, array, dask="parallelized") + + # formerly from _apply_blockwise, now from apply_variable_ufunc + with pytest.raises(ValueError, match=r"consists of multiple chunks"): + apply_ufunc( + identity, + data_array, + dask="parallelized", + output_dtypes=[float], + input_core_dims=[("y",)], + output_core_dims=[("y",)], + ) + + +# it's currently impossible to silence these warnings from inside dask.array: +# https://github.com/dask/dask/issues/3245 +@requires_dask +@pytest.mark.filterwarnings("ignore:Mean of empty slice") +def test_apply_dask_multiple_inputs() -> None: + import dask.array as da + + def covariance(x, y): + return ( + (x - x.mean(axis=-1, keepdims=True)) * (y - y.mean(axis=-1, keepdims=True)) + ).mean(axis=-1) + + rs = np.random.RandomState(42) + array1 = da.from_array(rs.randn(4, 4), chunks=(2, 4)) + array2 = da.from_array(rs.randn(4, 4), chunks=(2, 4)) + data_array_1 = xr.DataArray(array1, dims=("x", "z")) + data_array_2 = xr.DataArray(array2, dims=("y", "z")) + + expected = apply_ufunc( + covariance, + data_array_1.compute(), + data_array_2.compute(), + input_core_dims=[["z"], ["z"]], + ) + allowed = apply_ufunc( + covariance, + data_array_1, + data_array_2, + input_core_dims=[["z"], ["z"]], + dask="allowed", + ) + assert isinstance(allowed.data, da.Array) + xr.testing.assert_allclose(expected, allowed.compute()) + + parallelized = apply_ufunc( + covariance, + data_array_1, + data_array_2, + input_core_dims=[["z"], ["z"]], + dask="parallelized", + output_dtypes=[float], + ) + assert isinstance(parallelized.data, da.Array) + xr.testing.assert_allclose(expected, parallelized.compute()) + + +@requires_dask +def test_apply_dask_new_output_dimension() -> None: + import dask.array as da + + array = da.ones((2, 2), chunks=(1, 1)) + data_array = xr.DataArray(array, dims=("x", "y")) + + def stack_negative(obj): + def func(x): + return np.stack([x, -x], axis=-1) + + return apply_ufunc( + func, + obj, + output_core_dims=[["sign"]], + dask="parallelized", + output_dtypes=[obj.dtype], + dask_gufunc_kwargs=dict(output_sizes={"sign": 2}), + ) + + expected = stack_negative(data_array.compute()) + + actual = stack_negative(data_array) + assert actual.dims == ("x", "y", "sign") + assert actual.shape == (2, 2, 2) + assert isinstance(actual.data, da.Array) + assert_identical(expected, actual) + + +@requires_dask +def test_apply_dask_new_output_sizes() -> None: + ds = xr.Dataset({"foo": (["lon", "lat"], np.arange(10 * 10).reshape((10, 10)))}) + ds["bar"] = ds["foo"] + newdims = {"lon_new": 3, "lat_new": 6} + + def extract(obj): + def func(da): + return da[1:4, 1:7] + + return apply_ufunc( + func, + obj, + dask="parallelized", + input_core_dims=[["lon", "lat"]], + output_core_dims=[["lon_new", "lat_new"]], + dask_gufunc_kwargs=dict(output_sizes=newdims), + ) + + expected = extract(ds) + + actual = extract(ds.chunk()) + assert actual.sizes == {"lon_new": 3, "lat_new": 6} + assert_identical(expected.chunk(), actual) + + +@requires_dask +def test_apply_dask_new_output_sizes_not_supplied_same_dim_names() -> None: + # test for missing output_sizes kwarg sneaking through + # see GH discussion 7503 + + data = np.random.randn(4, 4, 3, 2) + da = xr.DataArray(data=data, dims=("x", "y", "i", "j")).chunk(x=1, y=1) + + with pytest.raises(ValueError, match="output_sizes"): + xr.apply_ufunc( + np.linalg.pinv, + da, + input_core_dims=[["i", "j"]], + output_core_dims=[["i", "j"]], + exclude_dims=set(("i", "j")), + dask="parallelized", + ) + + +def pandas_median(x): + return pd.Series(x).median() + + +def test_vectorize() -> None: + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + actual = apply_ufunc( + pandas_median, data_array, input_core_dims=[["y"]], vectorize=True + ) + assert_identical(expected, actual) + + +@requires_dask +def test_vectorize_dask() -> None: + # run vectorization in dask.array.gufunc by using `dask='parallelized'` + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + actual = apply_ufunc( + pandas_median, + data_array.chunk({"x": 1}), + input_core_dims=[["y"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) + assert_identical(expected, actual) + + +@requires_dask +def test_vectorize_dask_dtype() -> None: + # ensure output_dtypes is preserved with vectorize=True + # GH4015 + + # integer + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + actual = apply_ufunc( + pandas_median, + data_array.chunk({"x": 1}), + input_core_dims=[["y"]], + vectorize=True, + dask="parallelized", + output_dtypes=[int], + ) + assert_identical(expected, actual) + assert expected.dtype == actual.dtype + + # complex + data_array = xr.DataArray([[0 + 0j, 1 + 2j, 2 + 1j]], dims=("x", "y")) + expected = data_array.copy() + actual = apply_ufunc( + identity, + data_array.chunk({"x": 1}), + vectorize=True, + dask="parallelized", + output_dtypes=[complex], + ) + assert_identical(expected, actual) + assert expected.dtype == actual.dtype + + +@requires_dask +@pytest.mark.parametrize( + "data_array", + [ + xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")), + xr.DataArray([[0 + 0j, 1 + 2j, 2 + 1j]], dims=("x", "y")), + ], +) +def test_vectorize_dask_dtype_without_output_dtypes(data_array) -> None: + # ensure output_dtypes is preserved with vectorize=True + # GH4015 + + expected = data_array.copy() + actual = apply_ufunc( + identity, + data_array.chunk({"x": 1}), + vectorize=True, + dask="parallelized", + ) + + assert_identical(expected, actual) + assert expected.dtype == actual.dtype + + +@requires_dask +def test_vectorize_dask_dtype_meta() -> None: + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + + actual = apply_ufunc( + pandas_median, + data_array.chunk({"x": 1}), + input_core_dims=[["y"]], + vectorize=True, + dask="parallelized", + dask_gufunc_kwargs=dict(meta=np.ndarray((0, 0), dtype=float)), + ) + + assert_identical(expected, actual) + assert float == actual.dtype + + +def pandas_median_add(x, y): + # function which can consume input of unequal length + return pd.Series(x).median() + pd.Series(y).median() + + +def test_vectorize_exclude_dims() -> None: + # GH 3890 + data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y")) + + expected = xr.DataArray([3, 5], dims=["x"]) + actual = apply_ufunc( + pandas_median_add, + data_array_a, + data_array_b, + input_core_dims=[["y"], ["y"]], + vectorize=True, + exclude_dims=set("y"), + ) + assert_identical(expected, actual) + + +@requires_dask +def test_vectorize_exclude_dims_dask() -> None: + # GH 3890 + data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y")) + + expected = xr.DataArray([3, 5], dims=["x"]) + actual = apply_ufunc( + pandas_median_add, + data_array_a.chunk({"x": 1}), + data_array_b.chunk({"x": 1}), + input_core_dims=[["y"], ["y"]], + exclude_dims=set("y"), + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) + assert_identical(expected, actual) + + +def test_corr_only_dataarray() -> None: + with pytest.raises(TypeError, match="Only xr.DataArray is supported"): + xr.corr(xr.Dataset(), xr.Dataset()) # type: ignore[type-var] + + +@pytest.fixture(scope="module") +def arrays(): + da = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + + return [ + da.isel(time=range(0, 18)), + da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(), + xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]), + xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]), + xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]), + ] + + +@pytest.fixture(scope="module") +def array_tuples(arrays): + return [ + (arrays[0], arrays[0]), + (arrays[0], arrays[1]), + (arrays[1], arrays[1]), + (arrays[2], arrays[2]), + (arrays[2], arrays[3]), + (arrays[2], arrays[4]), + (arrays[4], arrays[2]), + (arrays[3], arrays[3]), + (arrays[4], arrays[4]), + ] + + +@pytest.mark.parametrize("ddof", [0, 1]) +@pytest.mark.parametrize("n", [3, 4, 5, 6, 7, 8]) +@pytest.mark.parametrize("dim", [None, "x", "time"]) +@requires_dask +def test_lazy_corrcov( + n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray] +) -> None: + # GH 5284 + from dask import is_dask_collection + + da_a, da_b = array_tuples[n] + + with raise_if_dask_computes(): + cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof) + assert is_dask_collection(cov) + + corr = xr.corr(da_a.chunk(), da_b.chunk(), dim=dim) + assert is_dask_collection(corr) + + +@pytest.mark.parametrize("ddof", [0, 1]) +@pytest.mark.parametrize("n", [0, 1, 2]) +@pytest.mark.parametrize("dim", [None, "time"]) +def test_cov( + n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray] +) -> None: + da_a, da_b = array_tuples[n] + + if dim is not None: + + def np_cov_ind(ts1, ts2, a, x): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + # While dropping isn't ideal here, numpy will return nan + # if any segment contains a NaN. + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.cov( + np.ma.masked_invalid(ts1.sel(a=a, x=x).data.flatten()), + np.ma.masked_invalid(ts2.sel(a=a, x=x).data.flatten()), + ddof=ddof, + )[0, 1] + + expected = np.zeros((3, 4)) + for a in [0, 1, 2]: + for x in [0, 1, 2, 3]: + expected[a, x] = np_cov_ind(da_a, da_b, a=a, x=x) + actual = xr.cov(da_a, da_b, dim=dim, ddof=ddof) + assert_allclose(actual, expected) + + else: + + def np_cov(ts1, ts2): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.cov( + np.ma.masked_invalid(ts1.data.flatten()), + np.ma.masked_invalid(ts2.data.flatten()), + ddof=ddof, + )[0, 1] + + expected = np_cov(da_a, da_b) + actual = xr.cov(da_a, da_b, dim=dim, ddof=ddof) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("n", [0, 1, 2]) +@pytest.mark.parametrize("dim", [None, "time"]) +def test_corr( + n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray] +) -> None: + da_a, da_b = array_tuples[n] + + if dim is not None: + + def np_corr_ind(ts1, ts2, a, x): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.corrcoef( + np.ma.masked_invalid(ts1.sel(a=a, x=x).data.flatten()), + np.ma.masked_invalid(ts2.sel(a=a, x=x).data.flatten()), + )[0, 1] + + expected = np.zeros((3, 4)) + for a in [0, 1, 2]: + for x in [0, 1, 2, 3]: + expected[a, x] = np_corr_ind(da_a, da_b, a=a, x=x) + actual = xr.corr(da_a, da_b, dim) + assert_allclose(actual, expected) + + else: + + def np_corr(ts1, ts2): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.corrcoef( + np.ma.masked_invalid(ts1.data.flatten()), + np.ma.masked_invalid(ts2.data.flatten()), + )[0, 1] + + expected = np_corr(da_a, da_b) + actual = xr.corr(da_a, da_b, dim) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("n", range(9)) +@pytest.mark.parametrize("dim", [None, "time", "x"]) +def test_covcorr_consistency( + n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray] +) -> None: + da_a, da_b = array_tuples[n] + # Testing that xr.corr and xr.cov are consistent with each other + # 1. Broadcast the two arrays + da_a, da_b = broadcast(da_a, da_b) + # 2. Ignore the nans + valid_values = da_a.notnull() & da_b.notnull() + da_a = da_a.where(valid_values) + da_b = da_b.where(valid_values) + + expected = xr.cov(da_a, da_b, dim=dim, ddof=0) / ( + da_a.std(dim=dim) * da_b.std(dim=dim) + ) + actual = xr.corr(da_a, da_b, dim=dim) + assert_allclose(actual, expected) + + +@requires_dask +@pytest.mark.parametrize("n", range(9)) +@pytest.mark.parametrize("dim", [None, "time", "x"]) +@pytest.mark.filterwarnings("ignore:invalid value encountered in .*divide") +def test_corr_lazycorr_consistency( + n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray] +) -> None: + da_a, da_b = array_tuples[n] + da_al = da_a.chunk() + da_bl = da_b.chunk() + c_abl = xr.corr(da_al, da_bl, dim=dim) + c_ab = xr.corr(da_a, da_b, dim=dim) + c_ab_mixed = xr.corr(da_a, da_bl, dim=dim) + assert_allclose(c_ab, c_abl) + assert_allclose(c_ab, c_ab_mixed) + + +@requires_dask +def test_corr_dtype_error(): + da_a = xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]) + da_b = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]) + + xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a.chunk(), da_b.chunk())) + xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk())) + + +@pytest.mark.parametrize("n", range(5)) +@pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]]) +def test_autocov(n: int, dim: str | None, arrays) -> None: + da = arrays[n] + + # Testing that the autocovariance*(N-1) is ~=~ to the variance matrix + # 1. Ignore the nans + valid_values = da.notnull() + # Because we're using ddof=1, this requires > 1 value in each sample + da = da.where(valid_values.sum(dim=dim) > 1) + expected = ((da - da.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1) + actual = xr.cov(da, da, dim=dim) * (valid_values.sum(dim) - 1) + assert_allclose(actual, expected) + + +def test_complex_cov() -> None: + da = xr.DataArray([1j, -1j]) + actual = xr.cov(da, da) + assert abs(actual.item()) == 2 + + +@pytest.mark.parametrize("weighted", [True, False]) +def test_bilinear_cov_corr(weighted: bool) -> None: + # Test the bilinear properties of covariance and correlation + da = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + db = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + dc = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + if weighted: + weights = xr.DataArray( + np.abs(np.random.random(4)), + dims=("x"), + ) + else: + weights = None + k = np.random.random(1)[0] + + # Test covariance properties + assert_allclose( + xr.cov(da + k, db, weights=weights), xr.cov(da, db, weights=weights) + ) + assert_allclose( + xr.cov(da, db + k, weights=weights), xr.cov(da, db, weights=weights) + ) + assert_allclose( + xr.cov(da + dc, db, weights=weights), + xr.cov(da, db, weights=weights) + xr.cov(dc, db, weights=weights), + ) + assert_allclose( + xr.cov(da, db + dc, weights=weights), + xr.cov(da, db, weights=weights) + xr.cov(da, dc, weights=weights), + ) + assert_allclose( + xr.cov(k * da, db, weights=weights), k * xr.cov(da, db, weights=weights) + ) + assert_allclose( + xr.cov(da, k * db, weights=weights), k * xr.cov(da, db, weights=weights) + ) + + # Test correlation properties + assert_allclose( + xr.corr(da + k, db, weights=weights), xr.corr(da, db, weights=weights) + ) + assert_allclose( + xr.corr(da, db + k, weights=weights), xr.corr(da, db, weights=weights) + ) + assert_allclose( + xr.corr(k * da, db, weights=weights), xr.corr(da, db, weights=weights) + ) + assert_allclose( + xr.corr(da, k * db, weights=weights), xr.corr(da, db, weights=weights) + ) + + +def test_equally_weighted_cov_corr() -> None: + # Test that equal weights for all values produces same results as weights=None + da = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + db = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + # + assert_allclose( + xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(1)) + ) + assert_allclose( + xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(2)) + ) + assert_allclose( + xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(1)) + ) + assert_allclose( + xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(2)) + ) + + +@requires_dask +def test_vectorize_dask_new_output_dims() -> None: + # regression test for GH3574 + # run vectorization in dask.array.gufunc by using `dask='parallelized'` + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + func = lambda x: x[np.newaxis, ...] + expected = data_array.expand_dims("z") + actual = apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + dask_gufunc_kwargs=dict(output_sizes={"z": 1}), + ).transpose(*expected.dims) + assert_identical(expected, actual) + + with pytest.raises( + ValueError, match=r"dimension 'z1' in 'output_sizes' must correspond" + ): + apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + dask_gufunc_kwargs=dict(output_sizes={"z1": 1}), + ) + + with pytest.raises( + ValueError, match=r"dimension 'z' in 'output_core_dims' needs corresponding" + ): + apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) + + +def test_output_wrong_number() -> None: + variable = xr.Variable("x", np.arange(10)) + + def identity(x): + return x + + def tuple3x(x): + return (x, x, x) + + with pytest.raises( + ValueError, + match=r"number of outputs.* Received a with 10 elements. Expected a tuple of 2 elements:\n\narray\(\[0", + ): + apply_ufunc(identity, variable, output_core_dims=[(), ()]) + + with pytest.raises(ValueError, match=r"number of outputs"): + apply_ufunc(tuple3x, variable, output_core_dims=[(), ()]) + + +def test_output_wrong_dims() -> None: + variable = xr.Variable("x", np.arange(10)) + + def add_dim(x): + return x[..., np.newaxis] + + def remove_dim(x): + return x[..., 0] + + with pytest.raises( + ValueError, + match=r"unexpected number of dimensions.*from:\n\n.*array\(\[\[0", + ): + apply_ufunc(add_dim, variable, output_core_dims=[("y", "z")]) + + with pytest.raises(ValueError, match=r"unexpected number of dimensions"): + apply_ufunc(add_dim, variable) + + with pytest.raises(ValueError, match=r"unexpected number of dimensions"): + apply_ufunc(remove_dim, variable) + + +def test_output_wrong_dim_size() -> None: + array = np.arange(10) + variable = xr.Variable("x", array) + data_array = xr.DataArray(variable, [("x", -array)]) + dataset = xr.Dataset({"y": variable}, {"x": -array}) + + def truncate(array): + return array[:5] + + def apply_truncate_broadcast_invalid(obj): + return apply_ufunc(truncate, obj) + + with pytest.raises(ValueError, match=r"size of dimension"): + apply_truncate_broadcast_invalid(variable) + with pytest.raises(ValueError, match=r"size of dimension"): + apply_truncate_broadcast_invalid(data_array) + with pytest.raises(ValueError, match=r"size of dimension"): + apply_truncate_broadcast_invalid(dataset) + + def apply_truncate_x_x_invalid(obj): + return apply_ufunc( + truncate, obj, input_core_dims=[["x"]], output_core_dims=[["x"]] + ) + + with pytest.raises(ValueError, match=r"size of dimension"): + apply_truncate_x_x_invalid(variable) + with pytest.raises(ValueError, match=r"size of dimension"): + apply_truncate_x_x_invalid(data_array) + with pytest.raises(ValueError, match=r"size of dimension"): + apply_truncate_x_x_invalid(dataset) + + def apply_truncate_x_z(obj): + return apply_ufunc( + truncate, obj, input_core_dims=[["x"]], output_core_dims=[["z"]] + ) + + assert_identical(xr.Variable("z", array[:5]), apply_truncate_x_z(variable)) + assert_identical( + xr.DataArray(array[:5], dims=["z"]), apply_truncate_x_z(data_array) + ) + assert_identical(xr.Dataset({"y": ("z", array[:5])}), apply_truncate_x_z(dataset)) + + def apply_truncate_x_x_valid(obj): + return apply_ufunc( + truncate, + obj, + input_core_dims=[["x"]], + output_core_dims=[["x"]], + exclude_dims={"x"}, + ) + + assert_identical(xr.Variable("x", array[:5]), apply_truncate_x_x_valid(variable)) + assert_identical( + xr.DataArray(array[:5], dims=["x"]), apply_truncate_x_x_valid(data_array) + ) + assert_identical( + xr.Dataset({"y": ("x", array[:5])}), apply_truncate_x_x_valid(dataset) + ) + + +@pytest.mark.parametrize("use_dask", [True, False]) +def test_dot(use_dask: bool) -> None: + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + c = np.arange(5 * 60).reshape(5, 60) + da_a = xr.DataArray(a, dims=["a", "b"], coords={"a": np.linspace(0, 1, 30)}) + da_b = xr.DataArray(b, dims=["a", "b", "c"], coords={"a": np.linspace(0, 1, 30)}) + da_c = xr.DataArray(c, dims=["c", "e"]) + if use_dask: + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + da_c = da_c.chunk({"c": 3}) + actual = xr.dot(da_a, da_b, dim=["a", "b"]) + assert actual.dims == ("c",) + assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) + + actual = xr.dot(da_a, da_b) + assert actual.dims == ("c",) + assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) + + # for only a single array is passed without dims argument, just return + # as is + actual = xr.dot(da_a) + assert_identical(da_a, actual) + + # test for variable + actual = xr.dot(da_a.variable, da_b.variable) + assert actual.dims == ("c",) + assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() + assert isinstance(actual.data, type(da_a.variable.data)) + + if use_dask: + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + actual = xr.dot(da_a, da_b, dim=["b"]) + assert actual.dims == ("a", "c") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() + assert isinstance(actual.variable.data, type(da_a.variable.data)) + + actual = xr.dot(da_a, da_b, dim=["b"]) + assert actual.dims == ("a", "c") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() + + actual = xr.dot(da_a, da_b, dim="b") + assert actual.dims == ("a", "c") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() + + actual = xr.dot(da_a, da_b, dim="a") + assert actual.dims == ("b", "c") + assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all() + + actual = xr.dot(da_a, da_b, dim="c") + assert actual.dims == ("a", "b") + assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() + + actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"]) + assert actual.dims == ("c", "e") + assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all() + + # should work with tuple + actual = xr.dot(da_a, da_b, dim=("c",)) + assert actual.dims == ("a", "b") + assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() + + # default dims + actual = xr.dot(da_a, da_b, da_c) + assert actual.dims == ("e",) + assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all() + + # 1 array summation + actual = xr.dot(da_a, dim="a") + assert actual.dims == ("b",) + assert (actual.data == np.einsum("ij->j ", a)).all() + + # empty dim + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a") + assert actual.dims == ("b",) + assert (actual.data == np.zeros(actual.shape)).all() + + # Ellipsis (...) sums over all dimensions + actual = xr.dot(da_a, da_b, dim=...) + assert actual.dims == () + assert (actual.data == np.einsum("ij,ijk->", a, b)).all() + + actual = xr.dot(da_a, da_b, da_c, dim=...) + assert actual.dims == () + assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all() + + actual = xr.dot(da_a, dim=...) + assert actual.dims == () + assert (actual.data == np.einsum("ij-> ", a)).all() + + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...) + assert actual.dims == () + assert (actual.data == np.zeros(actual.shape)).all() + + # Invalid cases + if not use_dask: + with pytest.raises(TypeError): + xr.dot(da_a, dim="a", invalid=None) + with pytest.raises(TypeError): + xr.dot(da_a.to_dataset(name="da"), dim="a") + with pytest.raises(TypeError): + xr.dot(dim="a") + + # einsum parameters + actual = xr.dot(da_a, da_b, dim=["b"], order="C") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() + assert actual.values.flags["C_CONTIGUOUS"] + assert not actual.values.flags["F_CONTIGUOUS"] + actual = xr.dot(da_a, da_b, dim=["b"], order="F") + assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() + # dask converts Fortran arrays to C order when merging the final array + if not use_dask: + assert not actual.values.flags["C_CONTIGUOUS"] + assert actual.values.flags["F_CONTIGUOUS"] + + # einsum has a constant string as of the first parameter, which makes + # it hard to pass to xarray.apply_ufunc. + # make sure dot() uses functools.partial(einsum, subscripts), which + # can be pickled, and not a lambda, which can't. + pickle.loads(pickle.dumps(xr.dot(da_a))) + + +@pytest.mark.parametrize("use_dask", [True, False]) +def test_dot_align_coords(use_dask: bool) -> None: + # GH 3694 + + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + + # use partially overlapping coords + coords_a = {"a": np.arange(30), "b": np.arange(4)} + coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)} + + da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a) + da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b) + + if use_dask: + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + + # join="inner" is the default + actual = xr.dot(da_a, da_b) + # `dot` sums over the common dimensions of the arguments + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + actual = xr.dot(da_a, da_b, dim=...) + expected = (da_a * da_b).sum() + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="exact"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): + xr.dot(da_a, da_b) + + # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all + # join method (except "exact") + with xr.set_options(arithmetic_join="left"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="right"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="outer"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + +def test_where() -> None: + cond = xr.DataArray([True, False], dims="x") + actual = xr.where(cond, 1, 0) + expected = xr.DataArray([1, 0], dims="x") + assert_identical(expected, actual) + + +def test_where_attrs() -> None: + cond = xr.DataArray([True, False], coords={"a": [0, 1]}, attrs={"attr": "cond_da"}) + cond["a"].attrs = {"attr": "cond_coord"} + x = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + x["a"].attrs = {"attr": "x_coord"} + y = xr.DataArray([0, 0], coords={"a": [0, 1]}, attrs={"attr": "y_da"}) + y["a"].attrs = {"attr": "y_coord"} + + # 3 DataArrays, takes attrs from x + actual = xr.where(cond, x, y, keep_attrs=True) + expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + expected["a"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + + # x as a scalar, takes no attrs + actual = xr.where(cond, 0, y, keep_attrs=True) + expected = xr.DataArray([0, 0], coords={"a": [0, 1]}) + assert_identical(expected, actual) + + # y as a scalar, takes attrs from x + actual = xr.where(cond, x, 0, keep_attrs=True) + expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + expected["a"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + + # x and y as a scalar, takes no attrs + actual = xr.where(cond, 1, 0, keep_attrs=True) + expected = xr.DataArray([1, 0], coords={"a": [0, 1]}) + assert_identical(expected, actual) + + # cond and y as a scalar, takes attrs from x + actual = xr.where(True, x, y, keep_attrs=True) + expected = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + expected["a"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + + # no xarray objects, handle no attrs + actual_np = xr.where(True, 0, 1, keep_attrs=True) + expected_np = np.array(0) + assert_identical(expected_np, actual_np) + + # DataArray and 2 Datasets, takes attrs from x + ds_x = xr.Dataset(data_vars={"x": x}, attrs={"attr": "x_ds"}) + ds_y = xr.Dataset(data_vars={"x": y}, attrs={"attr": "y_ds"}) + ds_actual = xr.where(cond, ds_x, ds_y, keep_attrs=True) + ds_expected = xr.Dataset( + data_vars={ + "x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + }, + attrs={"attr": "x_ds"}, + ) + ds_expected["a"].attrs = {"attr": "x_coord"} + assert_identical(ds_expected, ds_actual) + + # 2 DataArrays and 1 Dataset, takes attrs from x + ds_actual = xr.where(cond, x.rename("x"), ds_y, keep_attrs=True) + ds_expected = xr.Dataset( + data_vars={ + "x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + }, + ) + ds_expected["a"].attrs = {"attr": "x_coord"} + assert_identical(ds_expected, ds_actual) + + +@pytest.mark.parametrize( + "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] +) +@pytest.mark.parametrize( + ["x", "coeffs", "expected"], + [ + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}), + xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"), + id="simple", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [[0, 1], [0, 1]], dims=("y", "degree"), coords={"degree": [0, 1]} + ), + xr.DataArray([[1, 1], [2, 2], [3, 3]], dims=("x", "y")), + id="broadcast-x", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [[0, 1], [1, 0], [1, 1]], + dims=("x", "degree"), + coords={"degree": [0, 1]}, + ), + xr.DataArray([1, 1, 1 + 3], dims="x"), + id="shared-dim", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([1, 0, 0], dims="degree", coords={"degree": [2, 1, 0]}), + xr.DataArray([1, 2**2, 3**2], dims="x"), + id="reordered-index", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([5], dims="degree", coords={"degree": [3]}), + xr.DataArray([5, 5 * 2**3, 5 * 3**3], dims="x"), + id="sparse-index", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.Dataset( + {"a": ("degree", [0, 1]), "b": ("degree", [1, 0])}, + coords={"degree": [0, 1]}, + ), + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [1, 1, 1])}), + id="array-dataset", + ), + pytest.param( + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}), + xr.DataArray([1, 1], dims="degree", coords={"degree": [0, 1]}), + xr.Dataset({"a": ("x", [2, 3, 4]), "b": ("x", [3, 4, 5])}), + id="dataset-array", + ), + pytest.param( + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [2, 3, 4])}), + xr.Dataset( + {"a": ("degree", [0, 1]), "b": ("degree", [1, 1])}, + coords={"degree": [0, 1]}, + ), + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [3, 4, 5])}), + id="dataset-dataset", + ), + pytest.param( + xr.DataArray(pd.date_range("1970-01-01", freq="s", periods=3), dims="x"), + xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}), + xr.DataArray( + [0, 1e9, 2e9], + dims="x", + coords={"x": pd.date_range("1970-01-01", freq="s", periods=3)}, + ), + id="datetime", + ), + pytest.param( + xr.DataArray( + np.array([1000, 2000, 3000], dtype="timedelta64[ns]"), dims="x" + ), + xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}), + xr.DataArray([1000.0, 2000.0, 3000.0], dims="x"), + id="timedelta", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [2, 3, 4], + dims="degree", + coords={"degree": np.array([0, 1, 2], dtype=np.int64)}, + ), + xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"), + id="int64-degree", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [2, 3, 4], + dims="degree", + coords={"degree": np.array([0, 1, 2], dtype=np.int32)}, + ), + xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"), + id="int32-degree", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [2, 3, 4], + dims="degree", + coords={"degree": np.array([0, 1, 2], dtype=np.uint8)}, + ), + xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"), + id="uint8-degree", + ), + ], +) +def test_polyval( + use_dask: bool, + x: xr.DataArray | xr.Dataset, + coeffs: xr.DataArray | xr.Dataset, + expected: xr.DataArray | xr.Dataset, +) -> None: + if use_dask: + if not has_dask: + pytest.skip("requires dask") + coeffs = coeffs.chunk({"degree": 2}) + x = x.chunk({"x": 2}) + + with raise_if_dask_computes(): + actual = xr.polyval(coord=x, coeffs=coeffs) + + xr.testing.assert_allclose(actual, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] +) +@pytest.mark.parametrize("date", ["1970-01-01", "0753-04-21"]) +def test_polyval_cftime(use_dask: bool, date: str) -> None: + import cftime + + x = xr.DataArray( + xr.date_range(date, freq="1s", periods=3, use_cftime=True), + dims="x", + ) + coeffs = xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}) + + if use_dask: + if not has_dask: + pytest.skip("requires dask") + coeffs = coeffs.chunk({"degree": 2}) + x = x.chunk({"x": 2}) + + with raise_if_dask_computes(max_computes=1): + actual = xr.polyval(coord=x, coeffs=coeffs) + + t0 = xr.date_range(date, periods=1)[0] + offset = (t0 - cftime.DatetimeGregorian(1970, 1, 1)).total_seconds() * 1e9 + expected = ( + xr.DataArray( + [0, 1e9, 2e9], + dims="x", + coords={"x": xr.date_range(date, freq="1s", periods=3, use_cftime=True)}, + ) + + offset + ) + xr.testing.assert_allclose(actual, expected) + + +def test_polyval_degree_dim_checks() -> None: + x = xr.DataArray([1, 2, 3], dims="x") + coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}) + with pytest.raises(ValueError): + xr.polyval(x, coeffs.drop_vars("degree")) + with pytest.raises(ValueError): + xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float))) + + +@pytest.mark.parametrize( + "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] +) +@pytest.mark.parametrize( + "x", + [ + pytest.param(xr.DataArray([0, 1, 2], dims="x"), id="simple"), + pytest.param( + xr.DataArray(pd.date_range("1970-01-01", freq="ns", periods=3), dims="x"), + id="datetime", + ), + pytest.param( + xr.DataArray(np.array([0, 1, 2], dtype="timedelta64[ns]"), dims="x"), + id="timedelta", + ), + ], +) +@pytest.mark.parametrize( + "y", + [ + pytest.param(xr.DataArray([1, 6, 17], dims="x"), id="1D"), + pytest.param( + xr.DataArray([[1, 6, 17], [34, 57, 86]], dims=("y", "x")), id="2D" + ), + ], +) +def test_polyfit_polyval_integration( + use_dask: bool, x: xr.DataArray, y: xr.DataArray +) -> None: + y.coords["x"] = x + if use_dask: + if not has_dask: + pytest.skip("requires dask") + y = y.chunk({"x": 2}) + + fit = y.polyfit(dim="x", deg=2) + evaluated = xr.polyval(y.x, fit.polyfit_coefficients) + expected = y.transpose(*evaluated.dims) + xr.testing.assert_allclose(evaluated.variable, expected.variable) + + +@pytest.mark.parametrize("use_dask", [False, True]) +@pytest.mark.parametrize( + "a, b, ae, be, dim, axis", + [ + [ + xr.DataArray([1, 2, 3]), + xr.DataArray([4, 5, 6]), + np.array([1, 2, 3]), + np.array([4, 5, 6]), + "dim_0", + -1, + ], + [ + xr.DataArray([1, 2]), + xr.DataArray([4, 5, 6]), + np.array([1, 2, 0]), + np.array([4, 5, 6]), + "dim_0", + -1, + ], + [ + xr.Variable(dims=["dim_0"], data=[1, 2, 3]), + xr.Variable(dims=["dim_0"], data=[4, 5, 6]), + np.array([1, 2, 3]), + np.array([4, 5, 6]), + "dim_0", + -1, + ], + [ + xr.Variable(dims=["dim_0"], data=[1, 2]), + xr.Variable(dims=["dim_0"], data=[4, 5, 6]), + np.array([1, 2, 0]), + np.array([4, 5, 6]), + "dim_0", + -1, + ], + [ # Test dim in the middle: + xr.DataArray( + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), + dims=["time", "cartesian", "var"], + coords=dict( + time=(["time"], np.arange(0, 5)), + cartesian=(["cartesian"], ["x", "y", "z"]), + var=(["var"], [1, 1.5, 2, 2.5]), + ), + ), + xr.DataArray( + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1, + dims=["time", "cartesian", "var"], + coords=dict( + time=(["time"], np.arange(0, 5)), + cartesian=(["cartesian"], ["x", "y", "z"]), + var=(["var"], [1, 1.5, 2, 2.5]), + ), + ), + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1, + "cartesian", + 1, + ], + [ # Test 1 sized arrays with coords: + xr.DataArray( + np.array([1]), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["z"])), + ), + xr.DataArray( + np.array([4, 5, 6]), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ), + np.array([0, 0, 1]), + np.array([4, 5, 6]), + "cartesian", + -1, + ], + [ # Test filling in between with coords: + xr.DataArray( + [1, 2], + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["x", "z"])), + ), + xr.DataArray( + [4, 5, 6], + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ), + np.array([1, 0, 2]), + np.array([4, 5, 6]), + "cartesian", + -1, + ], + ], +) +def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None: + expected = np.cross(ae, be, axis=axis) + + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + a = a.chunk() + b = b.chunk() + + actual = xr.cross(a, b, dim=dim) + xr.testing.assert_duckarray_allclose(expected, actual) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_concat.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_concat.py new file mode 100644 index 0000000..0c570de --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_concat.py @@ -0,0 +1,1372 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable + +import numpy as np +import pandas as pd +import pytest + +from xarray import DataArray, Dataset, Variable, concat +from xarray.core import dtypes, merge +from xarray.core.coordinates import Coordinates +from xarray.core.indexes import PandasIndex +from xarray.tests import ( + ConcatenatableArray, + InaccessibleArray, + UnexpectedDataAccess, + assert_array_equal, + assert_equal, + assert_identical, + requires_dask, +) +from xarray.tests.test_dataset import create_test_data + +if TYPE_CHECKING: + from xarray.core.types import CombineAttrsOptions, JoinOptions + + +# helper method to create multiple tests datasets to concat +def create_concat_datasets( + num_datasets: int = 2, seed: int | None = None, include_day: bool = True +) -> list[Dataset]: + rng = np.random.default_rng(seed) + lat = rng.standard_normal(size=(1, 4)) + lon = rng.standard_normal(size=(1, 4)) + result = [] + variables = ["temperature", "pressure", "humidity", "precipitation", "cloud_cover"] + for i in range(num_datasets): + if include_day: + data_tuple = ( + ["x", "y", "day"], + rng.standard_normal(size=(1, 4, 2)), + ) + data_vars = {v: data_tuple for v in variables} + result.append( + Dataset( + data_vars=data_vars, + coords={ + "lat": (["x", "y"], lat), + "lon": (["x", "y"], lon), + "day": ["day" + str(i * 2 + 1), "day" + str(i * 2 + 2)], + }, + ) + ) + else: + data_tuple = ( + ["x", "y"], + rng.standard_normal(size=(1, 4)), + ) + data_vars = {v: data_tuple for v in variables} + result.append( + Dataset( + data_vars=data_vars, + coords={"lat": (["x", "y"], lat), "lon": (["x", "y"], lon)}, + ) + ) + + return result + + +# helper method to create multiple tests datasets to concat with specific types +def create_typed_datasets( + num_datasets: int = 2, seed: int | None = None +) -> list[Dataset]: + var_strings = ["a", "b", "c", "d", "e", "f", "g", "h"] + result = [] + rng = np.random.default_rng(seed) + lat = rng.standard_normal(size=(1, 4)) + lon = rng.standard_normal(size=(1, 4)) + for i in range(num_datasets): + result.append( + Dataset( + data_vars={ + "float": (["x", "y", "day"], rng.standard_normal(size=(1, 4, 2))), + "float2": (["x", "y", "day"], rng.standard_normal(size=(1, 4, 2))), + "string": ( + ["x", "y", "day"], + rng.choice(var_strings, size=(1, 4, 2)), + ), + "int": (["x", "y", "day"], rng.integers(0, 10, size=(1, 4, 2))), + "datetime64": ( + ["x", "y", "day"], + np.arange( + np.datetime64("2017-01-01"), np.datetime64("2017-01-09") + ).reshape(1, 4, 2), + ), + "timedelta64": ( + ["x", "y", "day"], + np.reshape([pd.Timedelta(days=i) for i in range(8)], [1, 4, 2]), + ), + }, + coords={ + "lat": (["x", "y"], lat), + "lon": (["x", "y"], lon), + "day": ["day" + str(i * 2 + 1), "day" + str(i * 2 + 2)], + }, + ) + ) + return result + + +def test_concat_compat() -> None: + ds1 = Dataset( + { + "has_x_y": (("y", "x"), [[1, 2]]), + "has_x": ("x", [1, 2]), + "no_x_y": ("z", [1, 2]), + }, + coords={"x": [0, 1], "y": [0], "z": [-1, -2]}, + ) + ds2 = Dataset( + { + "has_x_y": (("y", "x"), [[3, 4]]), + "has_x": ("x", [1, 2]), + "no_x_y": (("q", "z"), [[1, 2]]), + }, + coords={"x": [0, 1], "y": [1], "z": [-1, -2], "q": [0]}, + ) + + result = concat([ds1, ds2], dim="y", data_vars="minimal", compat="broadcast_equals") + assert_equal(ds2.no_x_y, result.no_x_y.transpose()) + + for var in ["has_x", "no_x_y"]: + assert "y" not in result[var].dims and "y" not in result[var].coords + with pytest.raises(ValueError, match=r"'q' not present in all datasets"): + concat([ds1, ds2], dim="q") + with pytest.raises(ValueError, match=r"'q' not present in all datasets"): + concat([ds2, ds1], dim="q") + + +def test_concat_missing_var() -> None: + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["humidity", "precipitation", "cloud_cover"] + + expected = expected.drop_vars(vars_to_drop) + expected["pressure"][..., 2:] = np.nan + + datasets[0] = datasets[0].drop_vars(vars_to_drop) + datasets[1] = datasets[1].drop_vars(vars_to_drop + ["pressure"]) + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == ["temperature", "pressure"] + assert_identical(actual, expected) + + +def test_concat_categorical() -> None: + data1 = create_test_data(use_extension_array=True) + data2 = create_test_data(use_extension_array=True) + concatenated = concat([data1, data2], dim="dim1") + assert ( + concatenated["var4"] + == type(data2["var4"].variable.data.array)._concat_same_type( + [ + data1["var4"].variable.data.array, + data2["var4"].variable.data.array, + ] + ) + ).all() + + +def test_concat_missing_multiple_consecutive_var() -> None: + datasets = create_concat_datasets(3, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["humidity", "pressure"] + + expected["pressure"][..., :4] = np.nan + expected["humidity"][..., :4] = np.nan + + datasets[0] = datasets[0].drop_vars(vars_to_drop) + datasets[1] = datasets[1].drop_vars(vars_to_drop) + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "precipitation", + "cloud_cover", + "pressure", + "humidity", + ] + assert_identical(actual, expected) + + +def test_concat_all_empty() -> None: + ds1 = Dataset() + ds2 = Dataset() + expected = Dataset() + actual = concat([ds1, ds2], dim="new_dim") + + assert_identical(actual, expected) + + +def test_concat_second_empty() -> None: + ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) + ds2 = Dataset(coords={"x": 0.1}) + + expected = Dataset(data_vars={"a": ("y", [0.1, np.nan])}, coords={"x": 0.1}) + actual = concat([ds1, ds2], dim="y") + assert_identical(actual, expected) + + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan])}, coords={"x": ("y", [0.1, 0.1])} + ) + actual = concat([ds1, ds2], dim="y", coords="all") + assert_identical(actual, expected) + + # Check concatenating scalar data_var only present in ds1 + ds1["b"] = 0.1 + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan]), "b": ("y", [0.1, np.nan])}, + coords={"x": ("y", [0.1, 0.1])}, + ) + actual = concat([ds1, ds2], dim="y", coords="all", data_vars="all") + assert_identical(actual, expected) + + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan]), "b": 0.1}, coords={"x": 0.1} + ) + actual = concat([ds1, ds2], dim="y", coords="different", data_vars="different") + assert_identical(actual, expected) + + +def test_concat_multiple_missing_variables() -> None: + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["pressure", "cloud_cover"] + + expected["pressure"][..., 2:] = np.nan + expected["cloud_cover"][..., 2:] = np.nan + + datasets[1] = datasets[1].drop_vars(vars_to_drop) + actual = concat(datasets, dim="day") + + # check the variables orders are the same + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("include_day", [True, False]) +def test_concat_multiple_datasets_missing_vars(include_day: bool) -> None: + vars_to_drop = [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + + datasets = create_concat_datasets( + len(vars_to_drop), seed=123, include_day=include_day + ) + expected = concat(datasets, dim="day") + + for i, name in enumerate(vars_to_drop): + if include_day: + expected[name][..., i * 2 : (i + 1) * 2] = np.nan + else: + expected[name][i : i + 1, ...] = np.nan + + # set up the test data + datasets = [ds.drop_vars(varname) for ds, varname in zip(datasets, vars_to_drop)] + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "pressure", + "humidity", + "precipitation", + "cloud_cover", + "temperature", + ] + assert_identical(actual, expected) + + +def test_concat_multiple_datasets_with_multiple_missing_variables() -> None: + vars_to_drop_in_first = ["temperature", "pressure"] + vars_to_drop_in_second = ["humidity", "precipitation", "cloud_cover"] + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + for name in vars_to_drop_in_first: + expected[name][..., :2] = np.nan + for name in vars_to_drop_in_second: + expected[name][..., 2:] = np.nan + + # set up the test data + datasets[0] = datasets[0].drop_vars(vars_to_drop_in_first) + datasets[1] = datasets[1].drop_vars(vars_to_drop_in_second) + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "humidity", + "precipitation", + "cloud_cover", + "temperature", + "pressure", + ] + assert_identical(actual, expected) + + +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") +def test_concat_type_of_missing_fill() -> None: + datasets = create_typed_datasets(2, seed=123) + expected1 = concat(datasets, dim="day", fill_value=dtypes.NA) + expected2 = concat(datasets[::-1], dim="day", fill_value=dtypes.NA) + vars = ["float", "float2", "string", "int", "datetime64", "timedelta64"] + expected = [expected2, expected1] + for i, exp in enumerate(expected): + sl = slice(i * 2, (i + 1) * 2) + exp["float2"][..., sl] = np.nan + exp["datetime64"][..., sl] = np.nan + exp["timedelta64"][..., sl] = np.nan + var = exp["int"] * 1.0 + var[..., sl] = np.nan + exp["int"] = var + var = exp["string"].astype(object) + var[..., sl] = np.nan + exp["string"] = var + + # set up the test data + datasets[1] = datasets[1].drop_vars(vars[1:]) + + actual = concat(datasets, dim="day", fill_value=dtypes.NA) + + assert_identical(actual, expected[1]) + + # reversed + actual = concat(datasets[::-1], dim="day", fill_value=dtypes.NA) + + assert_identical(actual, expected[0]) + + +def test_concat_order_when_filling_missing() -> None: + vars_to_drop_in_first: list[str] = [] + # drop middle + vars_to_drop_in_second = ["humidity"] + datasets = create_concat_datasets(2, seed=123) + expected1 = concat(datasets, dim="day") + for name in vars_to_drop_in_second: + expected1[name][..., 2:] = np.nan + expected2 = concat(datasets[::-1], dim="day") + for name in vars_to_drop_in_second: + expected2[name][..., :2] = np.nan + + # set up the test data + datasets[0] = datasets[0].drop_vars(vars_to_drop_in_first) + datasets[1] = datasets[1].drop_vars(vars_to_drop_in_second) + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + assert_identical(actual, expected1) + + actual = concat(datasets[::-1], dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "precipitation", + "cloud_cover", + "humidity", + ] + assert_identical(actual, expected2) + + +@pytest.fixture +def concat_var_names() -> Callable: + # create var names list with one missing value + def get_varnames(var_cnt: int = 10, list_cnt: int = 10) -> list[list[str]]: + orig = [f"d{i:02d}" for i in range(var_cnt)] + var_names = [] + for i in range(0, list_cnt): + l1 = orig.copy() + var_names.append(l1) + return var_names + + return get_varnames + + +@pytest.fixture +def create_concat_ds() -> Callable: + def create_ds( + var_names: list[list[str]], + dim: bool = False, + coord: bool = False, + drop_idx: list[int] | None = None, + ) -> list[Dataset]: + out_ds = [] + ds = Dataset() + ds = ds.assign_coords({"x": np.arange(2)}) + ds = ds.assign_coords({"y": np.arange(3)}) + ds = ds.assign_coords({"z": np.arange(4)}) + for i, dsl in enumerate(var_names): + vlist = dsl.copy() + if drop_idx is not None: + vlist.pop(drop_idx[i]) + foo_data = np.arange(48, dtype=float).reshape(2, 2, 3, 4) + dsi = ds.copy() + if coord: + dsi = ds.assign({"time": (["time"], [i * 2, i * 2 + 1])}) + for k in vlist: + dsi = dsi.assign({k: (["time", "x", "y", "z"], foo_data.copy())}) + if not dim: + dsi = dsi.isel(time=0) + out_ds.append(dsi) + return out_ds + + return create_ds + + +@pytest.mark.parametrize("dim", [True, False]) +@pytest.mark.parametrize("coord", [True, False]) +def test_concat_fill_missing_variables( + concat_var_names, create_concat_ds, dim: bool, coord: bool +) -> None: + var_names = concat_var_names() + drop_idx = [0, 7, 6, 4, 4, 8, 0, 6, 2, 0] + + expected = concat( + create_concat_ds(var_names, dim=dim, coord=coord), dim="time", data_vars="all" + ) + for i, idx in enumerate(drop_idx): + if dim: + expected[var_names[0][idx]][i * 2 : i * 2 + 2] = np.nan + else: + expected[var_names[0][idx]][i] = np.nan + + concat_ds = create_concat_ds(var_names, dim=dim, coord=coord, drop_idx=drop_idx) + actual = concat(concat_ds, dim="time", data_vars="all") + + assert list(actual.data_vars.keys()) == [ + "d01", + "d02", + "d03", + "d04", + "d05", + "d06", + "d07", + "d08", + "d09", + "d00", + ] + assert_identical(actual, expected) + + +class TestConcatDataset: + @pytest.fixture + def data(self, request) -> Dataset: + use_extension_array = request.param if hasattr(request, "param") else False + return create_test_data(use_extension_array=use_extension_array).drop_dims( + "dim3" + ) + + def rectify_dim_order(self, data, dataset) -> Dataset: + # return a new dataset with all variable dimensions transposed into + # the order in which they are found in `data` + return Dataset( + {k: v.transpose(*data[k].dims) for k, v in dataset.data_vars.items()}, + dataset.coords, + attrs=dataset.attrs, + ) + + @pytest.mark.parametrize("coords", ["different", "minimal"]) + @pytest.mark.parametrize( + "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] + ) + def test_concat_simple(self, data, dim, coords) -> None: + datasets = [g for _, g in data.groupby(dim, squeeze=False)] + assert_identical(data, concat(datasets, dim, coords=coords)) + + def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: + # coordinates present in some datasets but not others + ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) + ds2 = Dataset(data_vars={"a": ("y", [0.2])}, coords={"z": 0.2}) + actual = concat([ds1, ds2], dim="y", coords="minimal") + expected = Dataset({"a": ("y", [0.1, 0.2])}, coords={"x": 0.1, "z": 0.2}) + assert_identical(expected, actual) + + # data variables present in some datasets but not others + split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] + data0, data1 = deepcopy(split_data) + data1["foo"] = ("bar", np.random.randn(10)) + actual = concat([data0, data1], "dim1", data_vars="minimal") + expected = data.copy().assign(foo=data1.foo) + assert_identical(expected, actual) + + # expand foo + actual = concat([data0, data1], "dim1") + foo = np.ones((8, 10), dtype=data1.foo.dtype) * np.nan + foo[3:] = data1.foo.values[None, ...] + expected = data.copy().assign(foo=(["dim1", "bar"], foo)) + assert_identical(expected, actual) + + @pytest.mark.parametrize("data", [False], indirect=["data"]) + def test_concat_2(self, data) -> None: + dim = "dim2" + datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] + concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] + actual = concat(datasets, data[dim], coords=concat_over) + assert_identical(data, self.rectify_dim_order(data, actual)) + + @pytest.mark.parametrize("coords", ["different", "minimal", "all"]) + @pytest.mark.parametrize("dim", ["dim1", "dim2"]) + def test_concat_coords_kwarg(self, data, dim, coords) -> None: + data = data.copy(deep=True) + # make sure the coords argument behaves as expected + data.coords["extra"] = ("dim4", np.arange(3)) + datasets = [g.squeeze() for _, g in data.groupby(dim, squeeze=False)] + + actual = concat(datasets, data[dim], coords=coords) + if coords == "all": + expected = np.array([data["extra"].values for _ in range(data.sizes[dim])]) + assert_array_equal(actual["extra"].values, expected) + + else: + assert_equal(data["extra"], actual["extra"]) + + def test_concat(self, data) -> None: + split_data = [ + data.isel(dim1=slice(3)), + data.isel(dim1=3), + data.isel(dim1=slice(4, None)), + ] + assert_identical(data, concat(split_data, "dim1")) + + def test_concat_dim_precedence(self, data) -> None: + # verify that the dim argument takes precedence over + # concatenating dataset variables of the same name + dim = (2 * data["dim1"]).rename("dim1") + datasets = [g for _, g in data.groupby("dim1", squeeze=False)] + expected = data.copy() + expected["dim1"] = dim + assert_identical(expected, concat(datasets, dim)) + + def test_concat_data_vars_typing(self) -> None: + # Testing typing, can be removed if the next function works with annotations. + data = Dataset({"foo": ("x", np.random.randn(10))}) + objs: list[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] + actual = concat(objs, dim="x", data_vars="minimal") + assert_identical(data, actual) + + def test_concat_data_vars(self) -> None: + data = Dataset({"foo": ("x", np.random.randn(10))}) + objs: list[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] + for data_vars in ["minimal", "different", "all", [], ["foo"]]: + actual = concat(objs, dim="x", data_vars=data_vars) + assert_identical(data, actual) + + def test_concat_coords(self): + # TODO: annotating this func fails + data = Dataset({"foo": ("x", np.random.randn(10))}) + expected = data.assign_coords(c=("x", [0] * 5 + [1] * 5)) + objs = [ + data.isel(x=slice(5)).assign_coords(c=0), + data.isel(x=slice(5, None)).assign_coords(c=1), + ] + for coords in ["different", "all", ["c"]]: + actual = concat(objs, dim="x", coords=coords) + assert_identical(expected, actual) + for coords in ["minimal", []]: + with pytest.raises(merge.MergeError, match="conflicting values"): + concat(objs, dim="x", coords=coords) + + def test_concat_constant_index(self): + # TODO: annotating this func fails + # GH425 + ds1 = Dataset({"foo": 1.5}, {"y": 1}) + ds2 = Dataset({"foo": 2.5}, {"y": 1}) + expected = Dataset({"foo": ("y", [1.5, 2.5]), "y": [1, 1]}) + for mode in ["different", "all", ["foo"]]: + actual = concat([ds1, ds2], "y", data_vars=mode) + assert_identical(expected, actual) + with pytest.raises(merge.MergeError, match="conflicting values"): + # previously dim="y", and raised error which makes no sense. + # "foo" has dimension "y" so minimal should concatenate it? + concat([ds1, ds2], "new_dim", data_vars="minimal") + + def test_concat_size0(self) -> None: + data = create_test_data() + split_data = [data.isel(dim1=slice(0, 0)), data] + actual = concat(split_data, "dim1") + assert_identical(data, actual) + + actual = concat(split_data[::-1], "dim1") + assert_identical(data, actual) + + def test_concat_autoalign(self) -> None: + ds1 = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 2])])}) + ds2 = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 3])])}) + actual = concat([ds1, ds2], "y") + expected = Dataset( + { + "foo": DataArray( + [[1, 2, np.nan], [1, np.nan, 2]], + dims=["y", "x"], + coords={"x": [1, 2, 3]}, + ) + } + ) + assert_identical(expected, actual) + + def test_concat_errors(self): + # TODO: annotating this func fails + data = create_test_data() + split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] + + with pytest.raises(ValueError, match=r"must supply at least one"): + concat([], "dim1") + + with pytest.raises(ValueError, match=r"Cannot specify both .*='different'"): + concat( + [data, data], dim="concat_dim", data_vars="different", compat="override" + ) + + with pytest.raises(ValueError, match=r"must supply at least one"): + concat([], "dim1") + + with pytest.raises(ValueError, match=r"are not found in the coordinates"): + concat([data, data], "new_dim", coords=["not_found"]) + + with pytest.raises(ValueError, match=r"are not found in the data variables"): + concat([data, data], "new_dim", data_vars=["not_found"]) + + with pytest.raises(ValueError, match=r"global attributes not"): + # call deepcopy separately to get unique attrs + data0 = deepcopy(split_data[0]) + data1 = deepcopy(split_data[1]) + data1.attrs["foo"] = "bar" + concat([data0, data1], "dim1", compat="identical") + assert_identical(data, concat([data0, data1], "dim1", compat="equals")) + + with pytest.raises(ValueError, match=r"compat.* invalid"): + concat(split_data, "dim1", compat="foobar") + + with pytest.raises(ValueError, match=r"unexpected value for"): + concat([data, data], "new_dim", coords="foobar") + + with pytest.raises( + ValueError, match=r"coordinate in some datasets but not others" + ): + concat([Dataset({"x": 0}), Dataset({"x": [1]})], dim="z") + + with pytest.raises( + ValueError, match=r"coordinate in some datasets but not others" + ): + concat([Dataset({"x": 0}), Dataset({}, {"x": 1})], dim="z") + + def test_concat_join_kwarg(self) -> None: + ds1 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]}) + ds2 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]}) + + expected: dict[JoinOptions, Any] = {} + expected["outer"] = Dataset( + {"a": (("x", "y"), [[0, np.nan], [np.nan, 0]])}, + {"x": [0, 1], "y": [0, 0.0001]}, + ) + expected["inner"] = Dataset( + {"a": (("x", "y"), [[], []])}, {"x": [0, 1], "y": []} + ) + expected["left"] = Dataset( + {"a": (("x", "y"), np.array([0, np.nan], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0]}, + ) + expected["right"] = Dataset( + {"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0.0001]}, + ) + expected["override"] = Dataset( + {"a": (("x", "y"), np.array([0, 0], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0]}, + ) + + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): + actual = concat([ds1, ds2], join="exact", dim="x") + + for join in expected: + actual = concat([ds1, ds2], join=join, dim="x") + assert_equal(actual, expected[join]) + + # regression test for #3681 + actual = concat( + [ds1.drop_vars("x"), ds2.drop_vars("x")], join="override", dim="y" + ) + expected2 = Dataset( + {"a": (("x", "y"), np.array([0, 0], ndmin=2))}, coords={"y": [0, 0.0001]} + ) + assert_identical(actual, expected2) + + @pytest.mark.parametrize( + "combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": 41, "c": 43, "d": 44}, + False, + ), + ( + lambda attrs, context: {"a": -1, "b": 0, "c": 1} if any(attrs) else {}, + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": -1, "b": 0, "c": 1}, + False, + ), + ], + ) + def test_concat_combine_attrs_kwarg( + self, combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception + ): + ds1 = Dataset({"a": ("x", [0])}, coords={"x": [0]}, attrs=var1_attrs) + ds2 = Dataset({"a": ("x", [0])}, coords={"x": [1]}, attrs=var2_attrs) + + if expect_exception: + with pytest.raises(ValueError, match=f"combine_attrs='{combine_attrs}'"): + concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + else: + actual = concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + expected = Dataset( + {"a": ("x", [0, 0])}, {"x": [0, 1]}, attrs=expected_attrs + ) + + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": 41, "c": 43, "d": 44}, + False, + ), + ( + lambda attrs, context: {"a": -1, "b": 0, "c": 1} if any(attrs) else {}, + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": -1, "b": 0, "c": 1}, + False, + ), + ], + ) + def test_concat_combine_attrs_kwarg_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + ds1 = Dataset({"a": ("x", [0], attrs1)}, coords={"x": ("x", [0], attrs1)}) + ds2 = Dataset({"a": ("x", [0], attrs2)}, coords={"x": ("x", [1], attrs2)}) + + if expect_exception: + with pytest.raises(ValueError, match=f"combine_attrs='{combine_attrs}'"): + concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + else: + actual = concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + expected = Dataset( + {"a": ("x", [0, 0], expected_attrs)}, + {"x": ("x", [0, 1], expected_attrs)}, + ) + + assert_identical(actual, expected) + + def test_concat_promote_shape(self) -> None: + # mixed dims within variables + objs = [Dataset({}, {"x": 0}), Dataset({"x": [1]})] + actual = concat(objs, "x") + expected = Dataset({"x": [0, 1]}) + assert_identical(actual, expected) + + objs = [Dataset({"x": [0]}), Dataset({}, {"x": 1})] + actual = concat(objs, "x") + assert_identical(actual, expected) + + # mixed dims between variables + objs = [Dataset({"x": [2], "y": 3}), Dataset({"x": [4], "y": 5})] + actual = concat(objs, "x") + expected = Dataset({"x": [2, 4], "y": ("x", [3, 5])}) + assert_identical(actual, expected) + + # mixed dims in coord variable + objs = [Dataset({"x": [0]}, {"y": -1}), Dataset({"x": [1]}, {"y": ("x", [-2])})] + actual = concat(objs, "x") + expected = Dataset({"x": [0, 1]}, {"y": ("x", [-1, -2])}) + assert_identical(actual, expected) + + # scalars with mixed lengths along concat dim -- values should repeat + objs = [Dataset({"x": [0]}, {"y": -1}), Dataset({"x": [1, 2]}, {"y": -2})] + actual = concat(objs, "x") + expected = Dataset({"x": [0, 1, 2]}, {"y": ("x", [-1, -2, -2])}) + assert_identical(actual, expected) + + # broadcast 1d x 1d -> 2d + objs = [ + Dataset({"z": ("x", [-1])}, {"x": [0], "y": [0]}), + Dataset({"z": ("y", [1])}, {"x": [1], "y": [0]}), + ] + actual = concat(objs, "x") + expected = Dataset({"z": (("x", "y"), [[-1], [1]])}, {"x": [0, 1], "y": [0]}) + assert_identical(actual, expected) + + # regression GH6384 + objs = [ + Dataset({}, {"x": pd.Interval(-1, 0, closed="right")}), + Dataset({"x": [pd.Interval(0, 1, closed="right")]}), + ] + actual = concat(objs, "x") + expected = Dataset( + { + "x": [ + pd.Interval(-1, 0, closed="right"), + pd.Interval(0, 1, closed="right"), + ] + } + ) + assert_identical(actual, expected) + + # regression GH6416 (coord dtype) and GH6434 + time_data1 = np.array(["2022-01-01", "2022-02-01"], dtype="datetime64[ns]") + time_data2 = np.array("2022-03-01", dtype="datetime64[ns]") + time_expected = np.array( + ["2022-01-01", "2022-02-01", "2022-03-01"], dtype="datetime64[ns]" + ) + objs = [Dataset({}, {"time": time_data1}), Dataset({}, {"time": time_data2})] + actual = concat(objs, "time") + expected = Dataset({}, {"time": time_expected}) + assert_identical(actual, expected) + assert isinstance(actual.indexes["time"], pd.DatetimeIndex) + + def test_concat_do_not_promote(self) -> None: + # GH438 + objs = [ + Dataset({"y": ("t", [1])}, {"x": 1, "t": [0]}), + Dataset({"y": ("t", [2])}, {"x": 1, "t": [0]}), + ] + expected = Dataset({"y": ("t", [1, 2])}, {"x": 1, "t": [0, 0]}) + actual = concat(objs, "t") + assert_identical(expected, actual) + + objs = [ + Dataset({"y": ("t", [1])}, {"x": 1, "t": [0]}), + Dataset({"y": ("t", [2])}, {"x": 2, "t": [0]}), + ] + with pytest.raises(ValueError): + concat(objs, "t", coords="minimal") + + def test_concat_dim_is_variable(self) -> None: + objs = [Dataset({"x": 0}), Dataset({"x": 1})] + coord = Variable("y", [3, 4], attrs={"foo": "bar"}) + expected = Dataset({"x": ("y", [0, 1]), "y": coord}) + actual = concat(objs, coord) + assert_identical(actual, expected) + + def test_concat_dim_is_dataarray(self) -> None: + objs = [Dataset({"x": 0}), Dataset({"x": 1})] + coord = DataArray([3, 4], dims="y", attrs={"foo": "bar"}) + expected = Dataset({"x": ("y", [0, 1]), "y": coord}) + actual = concat(objs, coord) + assert_identical(actual, expected) + + def test_concat_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + expected = Dataset(coords=midx_coords) + actual = concat( + [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" + ) + assert expected.equals(actual) + assert isinstance(actual.x.to_index(), pd.MultiIndex) + + def test_concat_along_new_dim_multiindex(self) -> None: + # see https://github.com/pydata/xarray/issues/6881 + level_names = ["x_level_0", "x_level_1"] + midx = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]], names=level_names) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = Dataset(coords=midx_coords) + concatenated = concat([ds], "new") + actual = list(concatenated.xindexes.get_all_coords("x")) + expected = ["x"] + level_names + assert actual == expected + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}]) + def test_concat_fill_value(self, fill_value) -> None: + datasets = [ + Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}), + ] + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_a = fill_value_b = np.nan + elif isinstance(fill_value, dict): + fill_value_a = fill_value["a"] + fill_value_b = fill_value["b"] + else: + fill_value_a = fill_value_b = fill_value + expected = Dataset( + { + "a": (("t", "x"), [[fill_value_a, 2, 3], [1, 2, fill_value_a]]), + "b": (("t", "x"), [[fill_value_b, -2, 1], [3, -1, fill_value_b]]), + }, + {"x": [0, 1, 2]}, + ) + actual = concat(datasets, dim="t", fill_value=fill_value) + assert_identical(actual, expected) + + @pytest.mark.parametrize("dtype", [str, bytes]) + @pytest.mark.parametrize("dim", ["x1", "x2"]) + def test_concat_str_dtype(self, dtype, dim) -> None: + data = np.arange(4).reshape([2, 2]) + + da1 = Dataset( + { + "data": (["x1", "x2"], data), + "x1": [0, 1], + "x2": np.array(["a", "b"], dtype=dtype), + } + ) + da2 = Dataset( + { + "data": (["x1", "x2"], data), + "x1": np.array([1, 2]), + "x2": np.array(["c", "d"], dtype=dtype), + } + ) + actual = concat([da1, da2], dim=dim) + + assert np.issubdtype(actual.x2.dtype, dtype) + + def test_concat_avoids_index_auto_creation(self) -> None: + # TODO once passing indexes={} directly to Dataset constructor is allowed then no need to create coords first + coords = Coordinates( + {"x": ConcatenatableArray(np.array([1, 2, 3]))}, indexes={} + ) + datasets = [ + Dataset( + {"a": (["x", "y"], ConcatenatableArray(np.zeros((3, 3))))}, + coords=coords, + ) + for _ in range(2) + ] + # should not raise on concat + combined = concat(datasets, dim="x") + assert combined["a"].shape == (6, 3) + assert combined["a"].dims == ("x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + # should not raise on stack + combined = concat(datasets, dim="z") + assert combined["a"].shape == (2, 3, 3) + assert combined["a"].dims == ("z", "x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + def test_concat_avoids_index_auto_creation_new_1d_coord(self) -> None: + # create 0D coordinates (without indexes) + datasets = [ + Dataset( + coords={"x": ConcatenatableArray(np.array(10))}, + ) + for _ in range(2) + ] + + with pytest.raises(UnexpectedDataAccess): + concat(datasets, dim="x", create_index_for_new_dim=True) + + # should not raise on concat iff create_index_for_new_dim=False + combined = concat(datasets, dim="x", create_index_for_new_dim=False) + assert combined["x"].shape == (2,) + assert combined["x"].dims == ("x",) + + # nor have auto-created any indexes + assert combined.indexes == {} + + def test_concat_promote_shape_without_creating_new_index(self) -> None: + # different shapes but neither have indexes + ds1 = Dataset(coords={"x": 0}) + ds2 = Dataset(data_vars={"x": [1]}).drop_indexes("x") + actual = concat([ds1, ds2], dim="x", create_index_for_new_dim=False) + expected = Dataset(data_vars={"x": [0, 1]}).drop_indexes("x") + assert_identical(actual, expected, check_default_indexes=False) + assert actual.indexes == {} + + +class TestConcatDataArray: + def test_concat(self) -> None: + ds = Dataset( + { + "foo": (["x", "y"], np.random.random((2, 3))), + "bar": (["x", "y"], np.random.random((2, 3))), + }, + {"x": [0, 1]}, + ) + foo = ds["foo"] + bar = ds["bar"] + + # from dataset array: + expected = DataArray( + np.array([foo.values, bar.values]), + dims=["w", "x", "y"], + coords={"x": [0, 1]}, + ) + actual = concat([foo, bar], "w") + assert_equal(expected, actual) + # from iteration: + grouped = [g.squeeze() for _, g in foo.groupby("x", squeeze=False)] + stacked = concat(grouped, ds["x"]) + assert_identical(foo, stacked) + # with an index as the 'dim' argument + stacked = concat(grouped, pd.Index(ds["x"], name="x")) + assert_identical(foo, stacked) + + actual2 = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) + expected = foo[:2].rename({"x": "concat_dim"}) + assert_identical(expected, actual2) + + actual3 = concat([foo[0], foo[1]], [0, 1]).reset_coords(drop=True) + expected = foo[:2].rename({"x": "concat_dim"}) + assert_identical(expected, actual3) + + with pytest.raises(ValueError, match=r"not identical"): + concat([foo, bar], dim="w", compat="identical") + + with pytest.raises(ValueError, match=r"not a valid argument"): + concat([foo, bar], dim="w", data_vars="minimal") + + def test_concat_encoding(self) -> None: + # Regression test for GH1297 + ds = Dataset( + { + "foo": (["x", "y"], np.random.random((2, 3))), + "bar": (["x", "y"], np.random.random((2, 3))), + }, + {"x": [0, 1]}, + ) + foo = ds["foo"] + foo.encoding = {"complevel": 5} + ds.encoding = {"unlimited_dims": "x"} + assert concat([foo, foo], dim="x").encoding == foo.encoding + assert concat([ds, ds], dim="x").encoding == ds.encoding + + @requires_dask + def test_concat_lazy(self) -> None: + import dask.array as da + + arrays = [ + DataArray( + da.from_array(InaccessibleArray(np.zeros((3, 3))), 3), dims=["x", "y"] + ) + for _ in range(2) + ] + # should not raise + combined = concat(arrays, dim="z") + assert combined.shape == (2, 3, 3) + assert combined.dims == ("z", "x", "y") + + def test_concat_avoids_index_auto_creation(self) -> None: + # TODO once passing indexes={} directly to DataArray constructor is allowed then no need to create coords first + coords = Coordinates( + {"x": ConcatenatableArray(np.array([1, 2, 3]))}, indexes={} + ) + arrays = [ + DataArray( + ConcatenatableArray(np.zeros((3, 3))), + dims=["x", "y"], + coords=coords, + ) + for _ in range(2) + ] + # should not raise on concat + combined = concat(arrays, dim="x") + assert combined.shape == (6, 3) + assert combined.dims == ("x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + # should not raise on stack + combined = concat(arrays, dim="z") + assert combined.shape == (2, 3, 3) + assert combined.dims == ("z", "x", "y") + + # nor have auto-created any indexes + assert combined.indexes == {} + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + def test_concat_fill_value(self, fill_value) -> None: + foo = DataArray([1, 2], coords=[("x", [1, 2])]) + bar = DataArray([1, 2], coords=[("x", [1, 3])]) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = DataArray( + [[1, 2, fill_value], [1, fill_value, 2]], + dims=["y", "x"], + coords={"x": [1, 2, 3]}, + ) + actual = concat((foo, bar), dim="y", fill_value=fill_value) + assert_identical(actual, expected) + + def test_concat_join_kwarg(self) -> None: + ds1 = Dataset( + {"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]} + ).to_dataarray() + ds2 = Dataset( + {"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]} + ).to_dataarray() + + expected: dict[JoinOptions, Any] = {} + expected["outer"] = Dataset( + {"a": (("x", "y"), [[0, np.nan], [np.nan, 0]])}, + {"x": [0, 1], "y": [0, 0.0001]}, + ) + expected["inner"] = Dataset( + {"a": (("x", "y"), [[], []])}, {"x": [0, 1], "y": []} + ) + expected["left"] = Dataset( + {"a": (("x", "y"), np.array([0, np.nan], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0]}, + ) + expected["right"] = Dataset( + {"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0.0001]}, + ) + expected["override"] = Dataset( + {"a": (("x", "y"), np.array([0, 0], ndmin=2).T)}, + coords={"x": [0, 1], "y": [0]}, + ) + + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): + actual = concat([ds1, ds2], join="exact", dim="x") + + for join in expected: + actual = concat([ds1, ds2], join=join, dim="x") + assert_equal(actual, expected[join].to_dataarray()) + + def test_concat_combine_attrs_kwarg(self) -> None: + da1 = DataArray([0], coords=[("x", [0])], attrs={"b": 42}) + da2 = DataArray([0], coords=[("x", [1])], attrs={"b": 42, "c": 43}) + + expected: dict[CombineAttrsOptions, Any] = {} + expected["drop"] = DataArray([0, 0], coords=[("x", [0, 1])]) + expected["no_conflicts"] = DataArray( + [0, 0], coords=[("x", [0, 1])], attrs={"b": 42, "c": 43} + ) + expected["override"] = DataArray( + [0, 0], coords=[("x", [0, 1])], attrs={"b": 42} + ) + + with pytest.raises(ValueError, match=r"combine_attrs='identical'"): + actual = concat([da1, da2], dim="x", combine_attrs="identical") + with pytest.raises(ValueError, match=r"combine_attrs='no_conflicts'"): + da3 = da2.copy(deep=True) + da3.attrs["b"] = 44 + actual = concat([da1, da3], dim="x", combine_attrs="no_conflicts") + + for combine_attrs in expected: + actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs) + assert_identical(actual, expected[combine_attrs]) + + @pytest.mark.parametrize("dtype", [str, bytes]) + @pytest.mark.parametrize("dim", ["x1", "x2"]) + def test_concat_str_dtype(self, dtype, dim) -> None: + data = np.arange(4).reshape([2, 2]) + + da1 = DataArray( + data=data, + dims=["x1", "x2"], + coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)}, + ) + da2 = DataArray( + data=data, + dims=["x1", "x2"], + coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)}, + ) + actual = concat([da1, da2], dim=dim) + + assert np.issubdtype(actual.x2.dtype, dtype) + + def test_concat_coord_name(self) -> None: + da = DataArray([0], dims="a") + da_concat = concat([da, da], dim=DataArray([0, 1], dims="b")) + assert list(da_concat.coords) == ["b"] + + da_concat_std = concat([da, da], dim=DataArray([0, 1])) + assert list(da_concat_std.coords) == ["dim_0"] + + +@pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {})) +@pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {})) +def test_concat_attrs_first_variable(attr1, attr2) -> None: + arrs = [ + DataArray([[1], [2]], dims=["x", "y"], attrs=attr1), + DataArray([[3], [4]], dims=["x", "y"], attrs=attr2), + ] + + concat_attrs = concat(arrs, "y").attrs + assert concat_attrs == attr1 + + +def test_concat_merge_single_non_dim_coord(): + # TODO: annotating this func fails + da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) + da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]}) + + expected = DataArray(range(1, 7), dims="x", coords={"x": range(1, 7), "y": 1}) + + for coords in ["different", "minimal"]: + actual = concat([da1, da2], "x", coords=coords) + assert_identical(actual, expected) + + with pytest.raises(ValueError, match=r"'y' not present in all datasets."): + concat([da1, da2], dim="x", coords="all") + + da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) + da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]}) + da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1}) + for coords in ["different", "all"]: + with pytest.raises(ValueError, match=r"'y' not present in all datasets"): + concat([da1, da2, da3], dim="x", coords=coords) + + +def test_concat_preserve_coordinate_order() -> None: + x = np.arange(0, 5) + y = np.arange(0, 10) + time = np.arange(0, 4) + data = np.zeros((4, 10, 5), dtype=bool) + + ds1 = Dataset( + {"data": (["time", "y", "x"], data[0:2])}, + coords={"time": time[0:2], "y": y, "x": x}, + ) + ds2 = Dataset( + {"data": (["time", "y", "x"], data[2:4])}, + coords={"time": time[2:4], "y": y, "x": x}, + ) + + expected = Dataset( + {"data": (["time", "y", "x"], data)}, + coords={"time": time, "y": y, "x": x}, + ) + + actual = concat([ds1, ds2], dim="time") + + # check dimension order + for act, exp in zip(actual.dims, expected.dims): + assert act == exp + assert actual.sizes[act] == expected.sizes[exp] + + # check coordinate order + for act, exp in zip(actual.coords, expected.coords): + assert act == exp + assert_identical(actual.coords[act], expected.coords[exp]) + + +def test_concat_typing_check() -> None: + ds = Dataset({"foo": 1}, {"bar": 2}) + da = Dataset({"foo": 3}, {"bar": 4}).to_dataarray(dim="foo") + + # concatenate a list of non-homogeneous types must raise TypeError + with pytest.raises( + TypeError, + match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", + ): + concat([ds, da], dim="foo") # type: ignore + with pytest.raises( + TypeError, + match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", + ): + concat([da, ds], dim="foo") # type: ignore + + +def test_concat_not_all_indexes() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + # ds2.x has no default index + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + + with pytest.raises( + ValueError, match=r"'x' must have either an index or no index in all datasets.*" + ): + concat([ds1, ds2], dim="x") + + +def test_concat_index_not_same_dim() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + # TODO: use public API for setting a non-default index, when available + ds2._indexes["x"] = PandasIndex([3, 4], "y") + + with pytest.raises( + ValueError, + match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*", + ): + concat([ds1, ds2], dim="x") diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_conventions.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_conventions.py new file mode 100644 index 0000000..fdfea3c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_conventions.py @@ -0,0 +1,573 @@ +from __future__ import annotations + +import contextlib +import warnings + +import numpy as np +import pandas as pd +import pytest + +from xarray import ( + Dataset, + SerializationWarning, + Variable, + cftime_range, + coding, + conventions, + open_dataset, +) +from xarray.backends.common import WritableCFDataStore +from xarray.backends.memory import InMemoryDataStore +from xarray.conventions import decode_cf +from xarray.testing import assert_identical +from xarray.tests import ( + assert_array_equal, + requires_cftime, + requires_dask, + requires_netCDF4, +) +from xarray.tests.test_backends import CFEncodedBase + + +class TestBoolTypeArray: + def test_booltype_array(self) -> None: + x = np.array([1, 0, 1, 1, 0], dtype="i1") + bx = coding.variables.BoolTypeArray(x) + assert bx.dtype == bool + assert_array_equal(bx, np.array([True, False, True, True, False], dtype=bool)) + + +class TestNativeEndiannessArray: + def test(self) -> None: + x = np.arange(5, dtype=">i8") + expected = np.arange(5, dtype="int64") + a = coding.variables.NativeEndiannessArray(x) + assert a.dtype == expected.dtype + assert a.dtype == expected[:].dtype + assert_array_equal(a, expected) + + +def test_decode_cf_with_conflicting_fill_missing_value() -> None: + expected = Variable(["t"], [np.nan, np.nan, 2], {"units": "foobar"}) + var = Variable( + ["t"], np.arange(3), {"units": "foobar", "missing_value": 0, "_FillValue": 1} + ) + with pytest.warns(SerializationWarning, match="has multiple fill"): + actual = conventions.decode_cf_variable("t", var) + assert_identical(actual, expected) + + expected = Variable(["t"], np.arange(10), {"units": "foobar"}) + + var = Variable( + ["t"], + np.arange(10), + {"units": "foobar", "missing_value": np.nan, "_FillValue": np.nan}, + ) + + # the following code issues two warnings, so we need to check for both + with pytest.warns(SerializationWarning) as winfo: + actual = conventions.decode_cf_variable("t", var) + for aw in winfo: + assert "non-conforming" in str(aw.message) + + assert_identical(actual, expected) + + var = Variable( + ["t"], + np.arange(10), + { + "units": "foobar", + "missing_value": np.float32(np.nan), + "_FillValue": np.float32(np.nan), + }, + ) + + # the following code issues two warnings, so we need to check for both + with pytest.warns(SerializationWarning) as winfo: + actual = conventions.decode_cf_variable("t", var) + for aw in winfo: + assert "non-conforming" in str(aw.message) + assert_identical(actual, expected) + + +def test_decode_cf_variable_with_mismatched_coordinates() -> None: + # tests for decoding mismatched coordinates attributes + # see GH #1809 + zeros1 = np.zeros((1, 5, 3)) + orig = Dataset( + { + "XLONG": (["x", "y"], zeros1.squeeze(0), {}), + "XLAT": (["x", "y"], zeros1.squeeze(0), {}), + "foo": (["time", "x", "y"], zeros1, {"coordinates": "XTIME XLONG XLAT"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) + decoded = conventions.decode_cf(orig, decode_coords=True) + assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"] + + decoded = conventions.decode_cf(orig, decode_coords=False) + assert "coordinates" not in decoded["foo"].encoding + assert decoded["foo"].attrs.get("coordinates") == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["time"] + + +@requires_cftime +class TestEncodeCFVariable: + def test_incompatible_attributes(self) -> None: + invalid_vars = [ + Variable( + ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} + ), + Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), + Variable(["t"], [0, 1, 2], {"add_offset": 0}, {"add_offset": 2}), + Variable(["t"], [0, 1, 2], {"_FillValue": 0}, {"_FillValue": 2}), + ] + for var in invalid_vars: + with pytest.raises(ValueError): + conventions.encode_cf_variable(var) + + def test_missing_fillvalue(self) -> None: + v = Variable(["x"], np.array([np.nan, 1, 2, 3])) + v.encoding = {"dtype": "int16"} + with pytest.warns(Warning, match="floating point data as an integer"): + conventions.encode_cf_variable(v) + + def test_multidimensional_coordinates(self) -> None: + # regression test for GH1763 + # Set up test case with coordinates that have overlapping (but not + # identical) dimensions. + zeros1 = np.zeros((1, 5, 3)) + zeros2 = np.zeros((1, 6, 3)) + zeros3 = np.zeros((1, 5, 4)) + orig = Dataset( + { + "lon1": (["x1", "y1"], zeros1.squeeze(0), {}), + "lon2": (["x2", "y1"], zeros2.squeeze(0), {}), + "lon3": (["x1", "y2"], zeros3.squeeze(0), {}), + "lat1": (["x1", "y1"], zeros1.squeeze(0), {}), + "lat2": (["x2", "y1"], zeros2.squeeze(0), {}), + "lat3": (["x1", "y2"], zeros3.squeeze(0), {}), + "foo1": (["time", "x1", "y1"], zeros1, {"coordinates": "lon1 lat1"}), + "foo2": (["time", "x2", "y1"], zeros2, {"coordinates": "lon2 lat2"}), + "foo3": (["time", "x1", "y2"], zeros3, {"coordinates": "lon3 lat3"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) + orig = conventions.decode_cf(orig) + # Encode the coordinates, as they would be in a netCDF output file. + enc, attrs = conventions.encode_dataset_coordinates(orig) + # Make sure we have the right coordinates for each variable. + foo1_coords = enc["foo1"].attrs.get("coordinates", "") + foo2_coords = enc["foo2"].attrs.get("coordinates", "") + foo3_coords = enc["foo3"].attrs.get("coordinates", "") + assert foo1_coords == "lon1 lat1" + assert foo2_coords == "lon2 lat2" + assert foo3_coords == "lon3 lat3" + # Should not have any global coordinates. + assert "coordinates" not in attrs + + def test_var_with_coord_attr(self) -> None: + # regression test for GH6310 + # don't overwrite user-defined "coordinates" attributes + orig = Dataset( + {"values": ("time", np.zeros(2), {"coordinates": "time lon lat"})}, + coords={ + "time": ("time", np.zeros(2)), + "lat": ("time", np.zeros(2)), + "lon": ("time", np.zeros(2)), + }, + ) + # Encode the coordinates, as they would be in a netCDF output file. + enc, attrs = conventions.encode_dataset_coordinates(orig) + # Make sure we have the right coordinates for each variable. + values_coords = enc["values"].attrs.get("coordinates", "") + assert values_coords == "time lon lat" + # Should not have any global coordinates. + assert "coordinates" not in attrs + + def test_do_not_overwrite_user_coordinates(self) -> None: + # don't overwrite user-defined "coordinates" encoding + orig = Dataset( + coords={"x": [0, 1, 2], "y": ("x", [5, 6, 7]), "z": ("x", [8, 9, 10])}, + data_vars={"a": ("x", [1, 2, 3]), "b": ("x", [3, 5, 6])}, + ) + orig["a"].encoding["coordinates"] = "y" + orig["b"].encoding["coordinates"] = "z" + enc, _ = conventions.encode_dataset_coordinates(orig) + assert enc["a"].attrs["coordinates"] == "y" + assert enc["b"].attrs["coordinates"] == "z" + orig["a"].attrs["coordinates"] = "foo" + with pytest.raises(ValueError, match=r"'coordinates' found in both attrs"): + conventions.encode_dataset_coordinates(orig) + + def test_deterministic_coords_encoding(self) -> None: + # the coordinates attribute is sorted when set by xarray.conventions ... + # ... on a variable's coordinates attribute + ds = Dataset({"foo": 0}, coords={"baz": 0, "bar": 0}) + vars, attrs = conventions.encode_dataset_coordinates(ds) + assert vars["foo"].attrs["coordinates"] == "bar baz" + assert attrs.get("coordinates") is None + # ... on the global coordinates attribute + ds = ds.drop_vars("foo") + vars, attrs = conventions.encode_dataset_coordinates(ds) + assert attrs["coordinates"] == "bar baz" + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_emit_coordinates_attribute_in_attrs(self) -> None: + orig = Dataset( + {"a": 1, "b": 1}, + coords={"t": np.array("2004-11-01T00:00:00", dtype=np.datetime64)}, + ) + + orig["a"].attrs["coordinates"] = None + enc, _ = conventions.encode_dataset_coordinates(orig) + + # check coordinate attribute emitted for 'a' + assert "coordinates" not in enc["a"].attrs + assert "coordinates" not in enc["a"].encoding + + # check coordinate attribute not emitted for 'b' + assert enc["b"].attrs.get("coordinates") == "t" + assert "coordinates" not in enc["b"].encoding + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_emit_coordinates_attribute_in_encoding(self) -> None: + orig = Dataset( + {"a": 1, "b": 1}, + coords={"t": np.array("2004-11-01T00:00:00", dtype=np.datetime64)}, + ) + + orig["a"].encoding["coordinates"] = None + enc, _ = conventions.encode_dataset_coordinates(orig) + + # check coordinate attribute emitted for 'a' + assert "coordinates" not in enc["a"].attrs + assert "coordinates" not in enc["a"].encoding + + # check coordinate attribute not emitted for 'b' + assert enc["b"].attrs.get("coordinates") == "t" + assert "coordinates" not in enc["b"].encoding + + @requires_dask + def test_string_object_warning(self) -> None: + original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() + with pytest.warns(SerializationWarning, match="dask array with dtype=object"): + encoded = conventions.encode_cf_variable(original) + assert_identical(original, encoded) + + +@requires_cftime +class TestDecodeCF: + def test_dataset(self) -> None: + original = Dataset( + { + "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), + "foo": ("t", [0, 0, 0], {"coordinates": "y", "units": "bar"}), + "y": ("t", [5, 10, -999], {"_FillValue": -999}), + } + ) + expected = Dataset( + {"foo": ("t", [0, 0, 0], {"units": "bar"})}, + { + "t": pd.date_range("2000-01-01", periods=3), + "y": ("t", [5.0, 10.0, np.nan]), + }, + ) + actual = conventions.decode_cf(original) + assert_identical(expected, actual) + + def test_invalid_coordinates(self) -> None: + # regression test for GH308, GH1809 + original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) + decoded = Dataset({"foo": ("t", [1, 2], {}, {"coordinates": "invalid"})}) + actual = conventions.decode_cf(original) + assert_identical(decoded, actual) + actual = conventions.decode_cf(original, decode_coords=False) + assert_identical(original, actual) + + def test_decode_coordinates(self) -> None: + # regression test for GH610 + original = Dataset( + {"foo": ("t", [1, 2], {"coordinates": "x"}), "x": ("t", [4, 5])} + ) + actual = conventions.decode_cf(original) + assert actual.foo.encoding["coordinates"] == "x" + + def test_0d_int32_encoding(self) -> None: + original = Variable((), np.int32(0), encoding={"dtype": "int64"}) + expected = Variable((), np.int64(0)) + actual = coding.variables.NonStringCoder().encode(original) + assert_identical(expected, actual) + + def test_decode_cf_with_multiple_missing_values(self) -> None: + original = Variable(["t"], [0, 1, 2], {"missing_value": np.array([0, 1])}) + expected = Variable(["t"], [np.nan, np.nan, 2], {}) + with pytest.warns(SerializationWarning, match="has multiple fill"): + actual = conventions.decode_cf_variable("t", original) + assert_identical(expected, actual) + + def test_decode_cf_with_drop_variables(self) -> None: + original = Dataset( + { + "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), + "x": ("x", [9, 8, 7], {"units": "km"}), + "foo": ( + ("t", "x"), + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], + {"units": "bar"}, + ), + "y": ("t", [5, 10, -999], {"_FillValue": -999}), + } + ) + expected = Dataset( + { + "t": pd.date_range("2000-01-01", periods=3), + "foo": ( + ("t", "x"), + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], + {"units": "bar"}, + ), + "y": ("t", [5, 10, np.nan]), + } + ) + actual = conventions.decode_cf(original, drop_variables=("x",)) + actual2 = conventions.decode_cf(original, drop_variables="x") + assert_identical(expected, actual) + assert_identical(expected, actual2) + + @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") + def test_invalid_time_units_raises_eagerly(self) -> None: + ds = Dataset({"time": ("time", [0, 1], {"units": "foobar since 123"})}) + with pytest.raises(ValueError, match=r"unable to decode time"): + decode_cf(ds) + + @pytest.mark.parametrize("decode_times", [True, False]) + def test_invalid_timedelta_units_do_not_decode(self, decode_times) -> None: + # regression test for #8269 + ds = Dataset( + {"time": ("time", [0, 1, 20], {"units": "days invalid", "_FillValue": 20})} + ) + expected = Dataset( + {"time": ("time", [0.0, 1.0, np.nan], {"units": "days invalid"})} + ) + assert_identical(expected, decode_cf(ds, decode_times=decode_times)) + + @requires_cftime + def test_dataset_repr_with_netcdf4_datetimes(self) -> None: + # regression test for #347 + attrs = {"units": "days since 0001-01-01", "calendar": "noleap"} + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "unable to decode time") + ds = decode_cf(Dataset({"time": ("time", [0, 1], attrs)})) + assert "(time) object" in repr(ds) + + attrs = {"units": "days since 1900-01-01"} + ds = decode_cf(Dataset({"time": ("time", [0, 1], attrs)})) + assert "(time) datetime64[ns]" in repr(ds) + + @requires_cftime + def test_decode_cf_datetime_transition_to_invalid(self) -> None: + # manually create dataset with not-decoded date + from datetime import datetime + + ds = Dataset(coords={"time": [0, 266 * 365]}) + units = "days since 2000-01-01 00:00:00" + ds.time.attrs = dict(units=units) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "unable to decode time") + ds_decoded = conventions.decode_cf(ds) + + expected = np.array([datetime(2000, 1, 1, 0, 0), datetime(2265, 10, 28, 0, 0)]) + + assert_array_equal(ds_decoded.time.values, expected) + + @requires_dask + def test_decode_cf_with_dask(self) -> None: + import dask.array as da + + original = Dataset( + { + "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), + "foo": ("t", [0, 0, 0], {"coordinates": "y", "units": "bar"}), + "bar": ("string2", [b"a", b"b"]), + "baz": (("x"), [b"abc"], {"_Encoding": "utf-8"}), + "y": ("t", [5, 10, -999], {"_FillValue": -999}), + } + ).chunk() + decoded = conventions.decode_cf(original) + assert all( + isinstance(var.data, da.Array) + for name, var in decoded.variables.items() + if name not in decoded.xindexes + ) + assert_identical(decoded, conventions.decode_cf(original).compute()) + + @requires_dask + def test_decode_dask_times(self) -> None: + original = Dataset.from_dict( + { + "coords": {}, + "dims": {"time": 5}, + "data_vars": { + "average_T1": { + "dims": ("time",), + "attrs": {"units": "days since 1958-01-01 00:00:00"}, + "data": [87659.0, 88024.0, 88389.0, 88754.0, 89119.0], + } + }, + } + ) + assert_identical( + conventions.decode_cf(original.chunk()), + conventions.decode_cf(original).chunk(), + ) + + def test_decode_cf_time_kwargs(self) -> None: + ds = Dataset.from_dict( + { + "coords": { + "timedelta": { + "data": np.array([1, 2, 3], dtype="int64"), + "dims": "timedelta", + "attrs": {"units": "days"}, + }, + "time": { + "data": np.array([1, 2, 3], dtype="int64"), + "dims": "time", + "attrs": {"units": "days since 2000-01-01"}, + }, + }, + "dims": {"time": 3, "timedelta": 3}, + "data_vars": { + "a": {"dims": ("time", "timedelta"), "data": np.ones((3, 3))}, + }, + } + ) + + dsc = conventions.decode_cf(ds) + assert dsc.timedelta.dtype == np.dtype("m8[ns]") + assert dsc.time.dtype == np.dtype("M8[ns]") + dsc = conventions.decode_cf(ds, decode_times=False) + assert dsc.timedelta.dtype == np.dtype("int64") + assert dsc.time.dtype == np.dtype("int64") + dsc = conventions.decode_cf(ds, decode_times=True, decode_timedelta=False) + assert dsc.timedelta.dtype == np.dtype("int64") + assert dsc.time.dtype == np.dtype("M8[ns]") + dsc = conventions.decode_cf(ds, decode_times=False, decode_timedelta=True) + assert dsc.timedelta.dtype == np.dtype("m8[ns]") + assert dsc.time.dtype == np.dtype("int64") + + +class CFEncodedInMemoryStore(WritableCFDataStore, InMemoryDataStore): + def encode_variable(self, var): + """encode one variable""" + coder = coding.strings.EncodedStringCoder(allows_unicode=True) + var = coder.encode(var) + return var + + +@requires_netCDF4 +class TestCFEncodedDataStore(CFEncodedBase): + @contextlib.contextmanager + def create_store(self): + yield CFEncodedInMemoryStore() + + @contextlib.contextmanager + def roundtrip( + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False + ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + store = CFEncodedInMemoryStore() + data.dump_to_store(store, **save_kwargs) + yield open_dataset(store, **open_kwargs) + + @pytest.mark.skip("cannot roundtrip coordinates yet for CFEncodedInMemoryStore") + def test_roundtrip_coordinates(self) -> None: + pass + + def test_invalid_dataarray_names_raise(self) -> None: + # only relevant for on-disk file formats + pass + + def test_encoding_kwarg(self) -> None: + # we haven't bothered to raise errors yet for unexpected encodings in + # this test dummy + pass + + def test_encoding_kwarg_fixed_width_string(self) -> None: + # CFEncodedInMemoryStore doesn't support explicit string encodings. + pass + + +@pytest.mark.parametrize( + "data", + [ + np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object), + np.array([["x", 1], ["y", 2]], dtype="object"), + ], +) +def test_infer_dtype_error_on_mixed_types(data): + with pytest.raises(ValueError, match="unable to infer dtype on variable"): + conventions._infer_dtype(data, "test") + + +class TestDecodeCFVariableWithArrayUnits: + def test_decode_cf_variable_with_array_units(self) -> None: + v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)}) + v_decoded = conventions.decode_cf_variable("test2", v) + assert_identical(v, v_decoded) + + +def test_decode_cf_variable_timedelta64(): + variable = Variable(["time"], pd.timedelta_range("1D", periods=2)) + decoded = conventions.decode_cf_variable("time", variable) + assert decoded.encoding == {} + assert_identical(decoded, variable) + + +def test_decode_cf_variable_datetime64(): + variable = Variable(["time"], pd.date_range("2000", periods=2)) + decoded = conventions.decode_cf_variable("time", variable) + assert decoded.encoding == {} + assert_identical(decoded, variable) + + +@requires_cftime +def test_decode_cf_variable_cftime(): + variable = Variable(["time"], cftime_range("2000", periods=2)) + decoded = conventions.decode_cf_variable("time", variable) + assert decoded.encoding == {} + assert_identical(decoded, variable) + + +def test_scalar_units() -> None: + # test that scalar units does not raise an exception + var = Variable(["t"], [np.nan, np.nan, 2], {"units": np.nan}) + + actual = conventions.decode_cf_variable("t", var) + assert_identical(actual, var) + + +def test_decode_cf_error_includes_variable_name(): + ds = Dataset({"invalid": ([], 1e36, {"units": "days since 2000-01-01"})}) + with pytest.raises(ValueError, match="Failed to decode variable 'invalid'"): + decode_cf(ds) + + +def test_encode_cf_variable_with_vlen_dtype() -> None: + v = Variable( + ["x"], np.array(["a", "b"], dtype=coding.strings.create_vlen_dtype(str)) + ) + encoded_v = conventions.encode_cf_variable(v) + assert encoded_v.data.dtype.kind == "O" + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str + + # empty array + v = Variable(["x"], np.array([], dtype=coding.strings.create_vlen_dtype(str))) + encoded_v = conventions.encode_cf_variable(v) + assert encoded_v.data.dtype.kind == "O" + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_coordinates.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_coordinates.py new file mode 100644 index 0000000..f88e554 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_coordinates.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.core.variable import IndexVariable, Variable +from xarray.tests import assert_identical, source_ndarray + + +class TestCoordinates: + def test_init_noindex(self) -> None: + coords = Coordinates(coords={"foo": ("x", [0, 1, 2])}) + expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) + assert_identical(coords.to_dataset(), expected) + + def test_init_default_index(self) -> None: + coords = Coordinates(coords={"x": [1, 2]}) + expected = Dataset(coords={"x": [1, 2]}) + assert_identical(coords.to_dataset(), expected) + assert "x" in coords.xindexes + + @pytest.mark.filterwarnings("error:IndexVariable") + def test_init_no_default_index(self) -> None: + # dimension coordinate with no default index (explicit) + coords = Coordinates(coords={"x": [1, 2]}, indexes={}) + assert "x" not in coords.xindexes + assert not isinstance(coords["x"], IndexVariable) + + def test_init_from_coords(self) -> None: + expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) + coords = Coordinates(coords=expected.coords) + assert_identical(coords.to_dataset(), expected) + + # test variables copied + assert coords.variables["foo"] is not expected.variables["foo"] + + # test indexes are extracted + expected = Dataset(coords={"x": [0, 1, 2]}) + coords = Coordinates(coords=expected.coords) + assert_identical(coords.to_dataset(), expected) + assert expected.xindexes == coords.xindexes + + # coords + indexes not supported + with pytest.raises( + ValueError, match="passing both.*Coordinates.*indexes.*not allowed" + ): + coords = Coordinates( + coords=expected.coords, indexes={"x": PandasIndex([0, 1, 2], "x")} + ) + + def test_init_empty(self) -> None: + coords = Coordinates() + assert len(coords) == 0 + + def test_init_index_error(self) -> None: + idx = PandasIndex([1, 2, 3], "x") + with pytest.raises(ValueError, match="no coordinate variables found"): + Coordinates(indexes={"x": idx}) + + with pytest.raises(TypeError, match=".* is not an `xarray.indexes.Index`"): + Coordinates(coords={"x": ("x", [1, 2, 3])}, indexes={"x": "not_an_xarray_index"}) # type: ignore + + def test_init_dim_sizes_conflict(self) -> None: + with pytest.raises(ValueError): + Coordinates(coords={"foo": ("x", [1, 2]), "bar": ("x", [1, 2, 3, 4])}) + + def test_from_pandas_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + assert isinstance(coords.xindexes["x"], PandasMultiIndex) + assert coords.xindexes["x"].index.equals(midx) + assert coords.xindexes["x"].dim == "x" + + expected = PandasMultiIndex(midx, "x").create_variables() + assert list(coords.variables) == list(expected) + for name in ("x", "one", "two"): + assert_identical(expected[name], coords.variables[name]) + + @pytest.mark.filterwarnings("ignore:return type") + def test_dims(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert set(coords.dims) == {"x"} + + def test_sizes(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.sizes == {"x": 3} + + def test_dtypes(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.dtypes == {"x": int} + + def test_getitem(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert_identical( + coords["x"], + DataArray([0, 1, 2], coords={"x": [0, 1, 2]}, name="x"), + ) + + def test_delitem(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + del coords["x"] + assert "x" not in coords + + with pytest.raises( + KeyError, match="'nonexistent' is not in coordinate variables" + ): + del coords["nonexistent"] + + def test_update(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + + coords.update({"y": ("y", [4, 5, 6])}) + assert "y" in coords + assert "y" in coords.xindexes + expected = DataArray([4, 5, 6], coords={"y": [4, 5, 6]}, name="y") + assert_identical(coords["y"], expected) + + def test_equals(self): + coords = Coordinates(coords={"x": [0, 1, 2]}) + + assert coords.equals(coords) + assert not coords.equals("not_a_coords") + + def test_identical(self): + coords = Coordinates(coords={"x": [0, 1, 2]}) + + assert coords.identical(coords) + assert not coords.identical("not_a_coords") + + def test_assign(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + expected = Coordinates(coords={"x": [0, 1, 2], "y": [3, 4]}) + + actual = coords.assign(y=[3, 4]) + assert_identical(actual, expected) + + actual = coords.assign({"y": [3, 4]}) + assert_identical(actual, expected) + + def test_copy(self) -> None: + no_index_coords = Coordinates({"foo": ("x", [1, 2, 3])}) + copied = no_index_coords.copy() + assert_identical(no_index_coords, copied) + v0 = no_index_coords.variables["foo"] + v1 = copied.variables["foo"] + assert v0 is not v1 + assert source_ndarray(v0.data) is source_ndarray(v1.data) + + deep_copied = no_index_coords.copy(deep=True) + assert_identical(no_index_coords.to_dataset(), deep_copied.to_dataset()) + v0 = no_index_coords.variables["foo"] + v1 = deep_copied.variables["foo"] + assert v0 is not v1 + assert source_ndarray(v0.data) is not source_ndarray(v1.data) + + def test_align(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + + left = coords + + # test Coordinates._reindex_callback + right = coords.to_dataset().isel(x=[0, 1]).coords + left2, right2 = align(left, right, join="inner") + assert_identical(left2, right2) + + # test Coordinates._overwrite_indexes + right.update({"x": ("x", [4, 5, 6])}) + left2, right2 = align(left, right, join="override") + assert_identical(left2, left) + assert_identical(left2, right2) + + def test_dataset_from_coords_with_multidim_var_same_name(self): + # regression test for GH #8883 + var = Variable(data=np.arange(6).reshape(2, 3), dims=["x", "y"]) + coords = Coordinates(coords={"x": var}, indexes={}) + ds = Dataset(coords=coords) + assert ds.coords["x"].dims == ("x", "y") diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_cupy.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_cupy.py new file mode 100644 index 0000000..9477690 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_cupy.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr + +cp = pytest.importorskip("cupy") + + +@pytest.fixture +def toy_weather_data(): + """Construct the example DataSet from the Toy weather data example. + + https://docs.xarray.dev/en/stable/examples/weather-data.html + + Here we construct the DataSet exactly as shown in the example and then + convert the numpy arrays to cupy. + + """ + np.random.seed(123) + times = pd.date_range("2000-01-01", "2001-12-31", name="time") + annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) + + base = 10 + 15 * annual_cycle.reshape(-1, 1) + tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3) + tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3) + + ds = xr.Dataset( + { + "tmin": (("time", "location"), tmin_values), + "tmax": (("time", "location"), tmax_values), + }, + {"time": times, "location": ["IA", "IN", "IL"]}, + ) + + ds.tmax.data = cp.asarray(ds.tmax.data) + ds.tmin.data = cp.asarray(ds.tmin.data) + + return ds + + +def test_cupy_import() -> None: + """Check the import worked.""" + assert cp + + +def test_check_data_stays_on_gpu(toy_weather_data) -> None: + """Perform some operations and check the data stays on the GPU.""" + freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time") + assert isinstance(freeze.data, cp.ndarray) + + +def test_where() -> None: + from xarray.core.duck_array_ops import where + + data = cp.zeros(10) + + output = where(data < 1, 1, data).all() + assert output + assert isinstance(output, cp.ndarray) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_dask.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_dask.py new file mode 100644 index 0000000..517fc0c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_dask.py @@ -0,0 +1,1788 @@ +from __future__ import annotations + +import operator +import pickle +import sys +from contextlib import suppress +from textwrap import dedent + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Dataset, Variable +from xarray.core import duck_array_ops +from xarray.core.duck_array_ops import lazy_array_equiv +from xarray.testing import assert_chunks_equal +from xarray.tests import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_frame_equal, + assert_identical, + mock, + raise_if_dask_computes, + requires_pint, + requires_scipy_or_netCDF4, +) +from xarray.tests.test_backends import create_tmp_file + +dask = pytest.importorskip("dask") +da = pytest.importorskip("dask.array") +dd = pytest.importorskip("dask.dataframe") + +ON_WINDOWS = sys.platform == "win32" + + +def test_raise_if_dask_computes(): + data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) + with pytest.raises(RuntimeError, match=r"Too many computes"): + with raise_if_dask_computes(): + data.compute() + + +class DaskTestCase: + def assertLazyAnd(self, expected, actual, test): + with dask.config.set(scheduler="synchronous"): + test(actual, expected) + + if isinstance(actual, Dataset): + for k, v in actual.variables.items(): + if k in actual.xindexes: + assert isinstance(v.data, np.ndarray) + else: + assert isinstance(v.data, da.Array) + elif isinstance(actual, DataArray): + assert isinstance(actual.data, da.Array) + for k, v in actual.coords.items(): + if k in actual.xindexes: + assert isinstance(v.data, np.ndarray) + else: + assert isinstance(v.data, da.Array) + elif isinstance(actual, Variable): + assert isinstance(actual.data, da.Array) + else: + assert False + + +class TestVariable(DaskTestCase): + def assertLazyAndIdentical(self, expected, actual): + self.assertLazyAnd(expected, actual, assert_identical) + + def assertLazyAndAllClose(self, expected, actual): + self.assertLazyAnd(expected, actual, assert_allclose) + + @pytest.fixture(autouse=True) + def setUp(self): + self.values = np.random.RandomState(0).randn(4, 6) + self.data = da.from_array(self.values, chunks=(2, 2)) + + self.eager_var = Variable(("x", "y"), self.values) + self.lazy_var = Variable(("x", "y"), self.data) + + def test_basics(self): + v = self.lazy_var + assert self.data is v.data + assert self.data.chunks == v.chunks + assert_array_equal(self.values, v) + + def test_copy(self): + self.assertLazyAndIdentical(self.eager_var, self.lazy_var.copy()) + self.assertLazyAndIdentical(self.eager_var, self.lazy_var.copy(deep=True)) + + def test_chunk(self): + for chunks, expected in [ + ({}, ((2, 2), (2, 2, 2))), + (3, ((3, 1), (3, 3))), + ({"x": 3, "y": 3}, ((3, 1), (3, 3))), + ({"x": 3}, ((3, 1), (2, 2, 2))), + ({"x": (3, 1)}, ((3, 1), (2, 2, 2))), + ]: + rechunked = self.lazy_var.chunk(chunks) + assert rechunked.chunks == expected + self.assertLazyAndIdentical(self.eager_var, rechunked) + + expected_chunksizes = { + dim: chunks for dim, chunks in zip(self.lazy_var.dims, expected) + } + assert rechunked.chunksizes == expected_chunksizes + + def test_indexing(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(u[0], v[0]) + self.assertLazyAndIdentical(u[:1], v[:1]) + self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]]) + + @pytest.mark.parametrize( + "expected_data, index", + [ + (da.array([99, 2, 3, 4]), 0), + (da.array([99, 99, 99, 4]), slice(2, None, -1)), + (da.array([99, 99, 3, 99]), [0, -1, 1]), + (da.array([99, 99, 99, 4]), np.arange(3)), + (da.array([1, 99, 99, 99]), [False, True, True, True]), + (da.array([1, 99, 99, 99]), np.array([False, True, True, True])), + (da.array([99, 99, 99, 99]), Variable(("x"), np.array([True] * 4))), + ], + ) + def test_setitem_dask_array(self, expected_data, index): + arr = Variable(("x"), da.array([1, 2, 3, 4])) + expected = Variable(("x"), expected_data) + with raise_if_dask_computes(): + arr[index] = 99 + assert_identical(arr, expected) + + def test_squeeze(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(u[0].squeeze(), v[0].squeeze()) + + def test_equals(self): + v = self.lazy_var + assert v.equals(v) + assert isinstance(v.data, da.Array) + assert v.identical(v) + assert isinstance(v.data, da.Array) + + def test_transpose(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(u.T, v.T) + + def test_shift(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(u.shift(x=2), v.shift(x=2)) + self.assertLazyAndIdentical(u.shift(x=-2), v.shift(x=-2)) + assert v.data.chunks == v.shift(x=1).data.chunks + + def test_roll(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(u.roll(x=2), v.roll(x=2)) + assert v.data.chunks == v.roll(x=1).data.chunks + + def test_unary_op(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(-u, -v) + self.assertLazyAndIdentical(abs(u), abs(v)) + self.assertLazyAndIdentical(u.round(), v.round()) + + def test_binary_op(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(2 * u, 2 * v) + self.assertLazyAndIdentical(u + u, v + v) + self.assertLazyAndIdentical(u[0] + u, v[0] + v) + + def test_binary_op_bitshift(self) -> None: + # bit shifts only work on ints so we need to generate + # new eager and lazy vars + rng = np.random.default_rng(0) + values = rng.integers(low=-10000, high=10000, size=(4, 6)) + data = da.from_array(values, chunks=(2, 2)) + u = Variable(("x", "y"), values) + v = Variable(("x", "y"), data) + self.assertLazyAndIdentical(u << 2, v << 2) + self.assertLazyAndIdentical(u << 5, v << 5) + self.assertLazyAndIdentical(u >> 2, v >> 2) + self.assertLazyAndIdentical(u >> 5, v >> 5) + + def test_repr(self): + expected = dedent( + f"""\ + Size: 192B + {self.lazy_var.data!r}""" + ) + assert expected == repr(self.lazy_var) + + def test_pickle(self): + # Test that pickling/unpickling does not convert the dask + # backend to numpy + a1 = Variable(["x"], build_dask_array("x")) + a1.compute() + assert not a1._in_memory + assert kernel_call_count == 1 + a2 = pickle.loads(pickle.dumps(a1)) + assert kernel_call_count == 1 + assert_identical(a1, a2) + assert not a1._in_memory + assert not a2._in_memory + + def test_reduce(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(u.mean(), v.mean()) + self.assertLazyAndAllClose(u.std(), v.std()) + with raise_if_dask_computes(): + actual = v.argmax(dim="x") + self.assertLazyAndAllClose(u.argmax(dim="x"), actual) + with raise_if_dask_computes(): + actual = v.argmin(dim="x") + self.assertLazyAndAllClose(u.argmin(dim="x"), actual) + self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) + self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) + with pytest.raises(NotImplementedError, match=r"only works along an axis"): + v.median() + with pytest.raises(NotImplementedError, match=r"only works along an axis"): + v.median(v.dims) + with raise_if_dask_computes(): + v.reduce(duck_array_ops.mean) + + def test_missing_values(self): + values = np.array([0, 1, np.nan, 3]) + data = da.from_array(values, chunks=(2,)) + + eager_var = Variable("x", values) + lazy_var = Variable("x", data) + self.assertLazyAndIdentical(eager_var, lazy_var.fillna(lazy_var)) + self.assertLazyAndIdentical(Variable("x", range(4)), lazy_var.fillna(2)) + self.assertLazyAndIdentical(eager_var.count(), lazy_var.count()) + + def test_concat(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndIdentical(u, Variable.concat([v[:2], v[2:]], "x")) + self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], v[1]], "x")) + self.assertLazyAndIdentical(u[:2], Variable.concat([u[0], v[1]], "x")) + self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], u[1]], "x")) + self.assertLazyAndIdentical( + u[:3], Variable.concat([v[[0, 2]], v[[1]]], "x", positions=[[0, 2], [1]]) + ) + + def test_missing_methods(self): + v = self.lazy_var + try: + v.argsort() + except NotImplementedError as err: + assert "dask" in str(err) + try: + v[0].item() + except NotImplementedError as err: + assert "dask" in str(err) + + def test_univariate_ufunc(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(np.sin(u), np.sin(v)) + + def test_bivariate_ufunc(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(v, 0)) + self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(0, v)) + + def test_compute(self): + u = self.eager_var + v = self.lazy_var + + assert dask.is_dask_collection(v) + (v2,) = dask.compute(v + 1) + assert not dask.is_dask_collection(v2) + + assert ((u + 1).data == v2.data).all() + + def test_persist(self): + u = self.eager_var + v = self.lazy_var + 1 + + (v2,) = dask.persist(v) + assert v is not v2 + assert len(v2.__dask_graph__()) < len(v.__dask_graph__()) + assert v2.__dask_keys__() == v.__dask_keys__() + assert dask.is_dask_collection(v) + assert dask.is_dask_collection(v2) + + self.assertLazyAndAllClose(u + 1, v) + self.assertLazyAndAllClose(u + 1, v2) + + @requires_pint + def test_tokenize_duck_dask_array(self): + import pint + + unit_registry = pint.UnitRegistry() + + q = unit_registry.Quantity(self.data, "meter") + variable = xr.Variable(("x", "y"), q) + + token = dask.base.tokenize(variable) + post_op = variable + 5 * unit_registry.meter + + assert dask.base.tokenize(variable) != dask.base.tokenize(post_op) + # Immutability check + assert dask.base.tokenize(variable) == token + + +class TestDataArrayAndDataset(DaskTestCase): + def assertLazyAndIdentical(self, expected, actual): + self.assertLazyAnd(expected, actual, assert_identical) + + def assertLazyAndAllClose(self, expected, actual): + self.assertLazyAnd(expected, actual, assert_allclose) + + def assertLazyAndEqual(self, expected, actual): + self.assertLazyAnd(expected, actual, assert_equal) + + @pytest.fixture(autouse=True) + def setUp(self): + self.values = np.random.randn(4, 6) + self.data = da.from_array(self.values, chunks=(2, 2)) + self.eager_array = DataArray( + self.values, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + self.lazy_array = DataArray( + self.data, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + + def test_chunk(self): + for chunks, expected in [ + ({}, ((2, 2), (2, 2, 2))), + (3, ((3, 1), (3, 3))), + ({"x": 3, "y": 3}, ((3, 1), (3, 3))), + ({"x": 3}, ((3, 1), (2, 2, 2))), + ({"x": (3, 1)}, ((3, 1), (2, 2, 2))), + ]: + # Test DataArray + rechunked = self.lazy_array.chunk(chunks) + assert rechunked.chunks == expected + self.assertLazyAndIdentical(self.eager_array, rechunked) + + expected_chunksizes = { + dim: chunks for dim, chunks in zip(self.lazy_array.dims, expected) + } + assert rechunked.chunksizes == expected_chunksizes + + # Test Dataset + lazy_dataset = self.lazy_array.to_dataset() + eager_dataset = self.eager_array.to_dataset() + expected_chunksizes = { + dim: chunks for dim, chunks in zip(lazy_dataset.dims, expected) + } + rechunked = lazy_dataset.chunk(chunks) + + # Dataset.chunks has a different return type to DataArray.chunks - see issue #5843 + assert rechunked.chunks == expected_chunksizes + self.assertLazyAndIdentical(eager_dataset, rechunked) + + assert rechunked.chunksizes == expected_chunksizes + + def test_rechunk(self): + chunked = self.eager_array.chunk({"x": 2}).chunk({"y": 2}) + assert chunked.chunks == ((2,) * 2, (2,) * 3) + self.assertLazyAndIdentical(self.lazy_array, chunked) + + def test_new_chunk(self): + chunked = self.eager_array.chunk() + assert chunked.data.name.startswith("xarray-") + + def test_lazy_dataset(self): + lazy_ds = Dataset({"foo": (("x", "y"), self.data)}) + assert isinstance(lazy_ds.foo.variable.data, da.Array) + + def test_lazy_array(self): + u = self.eager_array + v = self.lazy_array + + self.assertLazyAndAllClose(u, v) + self.assertLazyAndAllClose(-u, -v) + self.assertLazyAndAllClose(u.T, v.T) + self.assertLazyAndAllClose(u.mean(), v.mean()) + self.assertLazyAndAllClose(1 + u, 1 + v) + + actual = xr.concat([v[:2], v[2:]], "x") + self.assertLazyAndAllClose(u, actual) + + def test_compute(self): + u = self.eager_array + v = self.lazy_array + + assert dask.is_dask_collection(v) + (v2,) = dask.compute(v + 1) + assert not dask.is_dask_collection(v2) + + assert ((u + 1).data == v2.data).all() + + def test_persist(self): + u = self.eager_array + v = self.lazy_array + 1 + + (v2,) = dask.persist(v) + assert v is not v2 + assert len(v2.__dask_graph__()) < len(v.__dask_graph__()) + assert v2.__dask_keys__() == v.__dask_keys__() + assert dask.is_dask_collection(v) + assert dask.is_dask_collection(v2) + + self.assertLazyAndAllClose(u + 1, v) + self.assertLazyAndAllClose(u + 1, v2) + + def test_concat_loads_variables(self): + # Test that concat() computes not-in-memory variables at most once + # and loads them in the output, while leaving the input unaltered. + d1 = build_dask_array("d1") + c1 = build_dask_array("c1") + d2 = build_dask_array("d2") + c2 = build_dask_array("c2") + d3 = build_dask_array("d3") + c3 = build_dask_array("c3") + # Note: c is a non-index coord. + # Index coords are loaded by IndexVariable.__init__. + ds1 = Dataset(data_vars={"d": ("x", d1)}, coords={"c": ("x", c1)}) + ds2 = Dataset(data_vars={"d": ("x", d2)}, coords={"c": ("x", c2)}) + ds3 = Dataset(data_vars={"d": ("x", d3)}, coords={"c": ("x", c3)}) + + assert kernel_call_count == 0 + out = xr.concat( + [ds1, ds2, ds3], dim="n", data_vars="different", coords="different" + ) + # each kernel is computed exactly once + assert kernel_call_count == 6 + # variables are loaded in the output + assert isinstance(out["d"].data, np.ndarray) + assert isinstance(out["c"].data, np.ndarray) + + out = xr.concat([ds1, ds2, ds3], dim="n", data_vars="all", coords="all") + # no extra kernel calls + assert kernel_call_count == 6 + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + + out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=["d"], coords=["c"]) + # no extra kernel calls + assert kernel_call_count == 6 + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + + out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=[], coords=[]) + # variables are loaded once as we are validating that they're identical + assert kernel_call_count == 12 + assert isinstance(out["d"].data, np.ndarray) + assert isinstance(out["c"].data, np.ndarray) + + out = xr.concat( + [ds1, ds2, ds3], + dim="n", + data_vars="different", + coords="different", + compat="identical", + ) + # compat=identical doesn't do any more kernel calls than compat=equals + assert kernel_call_count == 18 + assert isinstance(out["d"].data, np.ndarray) + assert isinstance(out["c"].data, np.ndarray) + + # When the test for different turns true halfway through, + # stop computing variables as it would not have any benefit + ds4 = Dataset(data_vars={"d": ("x", [2.0])}, coords={"c": ("x", [2.0])}) + out = xr.concat( + [ds1, ds2, ds4, ds3], dim="n", data_vars="different", coords="different" + ) + # the variables of ds1 and ds2 were computed, but those of ds3 didn't + assert kernel_call_count == 22 + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + # the data of ds1 and ds2 was loaded into numpy and then + # concatenated to the data of ds3. Thus, only ds3 is computed now. + out.compute() + assert kernel_call_count == 24 + + # Finally, test that originals are unaltered + assert ds1["d"].data is d1 + assert ds1["c"].data is c1 + assert ds2["d"].data is d2 + assert ds2["c"].data is c2 + assert ds3["d"].data is d3 + assert ds3["c"].data is c3 + + # now check that concat() is correctly using dask name equality to skip loads + out = xr.concat( + [ds1, ds1, ds1], dim="n", data_vars="different", coords="different" + ) + assert kernel_call_count == 24 + # variables are not loaded in the output + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + + out = xr.concat( + [ds1, ds1, ds1], dim="n", data_vars=[], coords=[], compat="identical" + ) + assert kernel_call_count == 24 + # variables are not loaded in the output + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + + out = xr.concat( + [ds1, ds2.compute(), ds3], + dim="n", + data_vars="all", + coords="different", + compat="identical", + ) + # c1,c3 must be computed for comparison since c2 is numpy; + # d2 is computed too + assert kernel_call_count == 28 + + out = xr.concat( + [ds1, ds2.compute(), ds3], + dim="n", + data_vars="all", + coords="all", + compat="identical", + ) + # no extra computes + assert kernel_call_count == 30 + + # Finally, test that originals are unaltered + assert ds1["d"].data is d1 + assert ds1["c"].data is c1 + assert ds2["d"].data is d2 + assert ds2["c"].data is c2 + assert ds3["d"].data is d3 + assert ds3["c"].data is c3 + + def test_groupby(self): + u = self.eager_array + v = self.lazy_array + + expected = u.groupby("x").mean(...) + with raise_if_dask_computes(): + actual = v.groupby("x").mean(...) + self.assertLazyAndAllClose(expected, actual) + + def test_rolling(self): + u = self.eager_array + v = self.lazy_array + + expected = u.rolling(x=2).mean() + with raise_if_dask_computes(): + actual = v.rolling(x=2).mean() + self.assertLazyAndAllClose(expected, actual) + + @pytest.mark.parametrize("func", ["first", "last"]) + def test_groupby_first_last(self, func): + method = operator.methodcaller(func) + u = self.eager_array + v = self.lazy_array + + for coords in [u.coords, v.coords]: + coords["ab"] = ("x", ["a", "a", "b", "b"]) + expected = method(u.groupby("ab")) + + with raise_if_dask_computes(): + actual = method(v.groupby("ab")) + self.assertLazyAndAllClose(expected, actual) + + with raise_if_dask_computes(): + actual = method(v.groupby("ab")) + self.assertLazyAndAllClose(expected, actual) + + def test_reindex(self): + u = self.eager_array.assign_coords(y=range(6)) + v = self.lazy_array.assign_coords(y=range(6)) + + for kwargs in [ + {"x": [2, 3, 4]}, + {"x": [1, 100, 2, 101, 3]}, + {"x": [2.5, 3, 3.5], "y": [2, 2.5, 3]}, + ]: + expected = u.reindex(**kwargs) + actual = v.reindex(**kwargs) + self.assertLazyAndAllClose(expected, actual) + + def test_to_dataset_roundtrip(self): + u = self.eager_array + v = self.lazy_array + + expected = u.assign_coords(x=u["x"]) + self.assertLazyAndEqual(expected, v.to_dataset("x").to_dataarray("x")) + + def test_merge(self): + def duplicate_and_merge(array): + return xr.merge([array, array.rename("bar")]).to_dataarray() + + expected = duplicate_and_merge(self.eager_array) + actual = duplicate_and_merge(self.lazy_array) + self.assertLazyAndEqual(expected, actual) + + def test_ufuncs(self): + u = self.eager_array + v = self.lazy_array + self.assertLazyAndAllClose(np.sin(u), np.sin(v)) + + def test_where_dispatching(self): + a = np.arange(10) + b = a > 3 + x = da.from_array(a, 5) + y = da.from_array(b, 5) + expected = DataArray(a).where(b) + self.assertLazyAndEqual(expected, DataArray(a).where(y)) + self.assertLazyAndEqual(expected, DataArray(x).where(b)) + self.assertLazyAndEqual(expected, DataArray(x).where(y)) + + def test_simultaneous_compute(self): + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk() + + count = [0] + + def counting_get(*args, **kwargs): + count[0] += 1 + return dask.get(*args, **kwargs) + + ds.load(scheduler=counting_get) + + assert count[0] == 1 + + def test_stack(self): + data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4)) + arr = DataArray(data, dims=("w", "x", "y")) + stacked = arr.stack(z=("x", "y")) + z = pd.MultiIndex.from_product([np.arange(3), np.arange(4)], names=["x", "y"]) + expected = DataArray(data.reshape(2, -1), {"z": z}, dims=["w", "z"]) + assert stacked.data.chunks == expected.data.chunks + self.assertLazyAndEqual(expected, stacked) + + def test_dot(self): + eager = self.eager_array.dot(self.eager_array[0]) + lazy = self.lazy_array.dot(self.lazy_array[0]) + self.assertLazyAndAllClose(eager, lazy) + + def test_dataarray_repr(self): + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) + expected = dedent( + f"""\ + Size: 8B + {data!r} + Coordinates: + y (x) int64 8B dask.array + Dimensions without coordinates: x""" + ) + assert expected == repr(a) + assert kernel_call_count == 0 # should not evaluate dask array + + def test_dataset_repr(self): + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) + expected = dedent( + """\ + Size: 16B + Dimensions: (x: 1) + Coordinates: + y (x) int64 8B dask.array + Dimensions without coordinates: x + Data variables: + a (x) int64 8B dask.array""" + ) + assert expected == repr(ds) + assert kernel_call_count == 0 # should not evaluate dask array + + def test_dataarray_pickle(self): + # Test that pickling/unpickling converts the dask backend + # to numpy in neither the data variable nor the non-index coords + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + a1 = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) + a1.compute() + assert not a1._in_memory + assert not a1.coords["y"]._in_memory + assert kernel_call_count == 2 + a2 = pickle.loads(pickle.dumps(a1)) + assert kernel_call_count == 2 + assert_identical(a1, a2) + assert not a1._in_memory + assert not a2._in_memory + assert not a1.coords["y"]._in_memory + assert not a2.coords["y"]._in_memory + + def test_dataset_pickle(self): + # Test that pickling/unpickling converts the dask backend + # to numpy in neither the data variables nor the non-index coords + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + ds1 = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) + ds1.compute() + assert not ds1["a"]._in_memory + assert not ds1["y"]._in_memory + assert kernel_call_count == 2 + ds2 = pickle.loads(pickle.dumps(ds1)) + assert kernel_call_count == 2 + assert_identical(ds1, ds2) + assert not ds1["a"]._in_memory + assert not ds2["a"]._in_memory + assert not ds1["y"]._in_memory + assert not ds2["y"]._in_memory + + def test_dataarray_getattr(self): + # ipython/jupyter does a long list of getattr() calls to when trying to + # represent an object. + # Make sure we're not accidentally computing dask variables. + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) + with suppress(AttributeError): + getattr(a, "NOTEXIST") + assert kernel_call_count == 0 + + def test_dataset_getattr(self): + # Test that pickling/unpickling converts the dask backend + # to numpy in neither the data variables nor the non-index coords + data = build_dask_array("data") + nonindex_coord = build_dask_array("coord") + ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) + with suppress(AttributeError): + getattr(ds, "NOTEXIST") + assert kernel_call_count == 0 + + def test_values(self): + # Test that invoking the values property does not convert the dask + # backend to numpy + a = DataArray([1, 2]).chunk() + assert not a._in_memory + assert a.values.tolist() == [1, 2] + assert not a._in_memory + + def test_from_dask_variable(self): + # Test array creation from Variable with dask backend. + # This is used e.g. in broadcast() + a = DataArray(self.lazy_array.variable, coords={"x": range(4)}, name="foo") + self.assertLazyAndIdentical(self.lazy_array, a) + + @requires_pint + def test_tokenize_duck_dask_array(self): + import pint + + unit_registry = pint.UnitRegistry() + + q = unit_registry.Quantity(self.data, unit_registry.meter) + data_array = xr.DataArray( + data=q, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + + token = dask.base.tokenize(data_array) + post_op = data_array + 5 * unit_registry.meter + + assert dask.base.tokenize(data_array) != dask.base.tokenize(post_op) + # Immutability check + assert dask.base.tokenize(data_array) == token + + +class TestToDaskDataFrame: + def test_to_dask_dataframe(self): + # Test conversion of Datasets to dask DataFrames + x = np.random.randn(10) + y = np.arange(10, dtype="uint8") + t = list("abcdefghij") + + ds = Dataset( + {"a": ("t", da.from_array(x, chunks=4)), "b": ("t", y), "t": ("t", t)} + ) + + expected_pd = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t")) + + # test if 1-D index is correctly set up + expected = dd.from_pandas(expected_pd, chunksize=4) + actual = ds.to_dask_dataframe(set_index=True) + # test if we have dask dataframes + assert isinstance(actual, dd.DataFrame) + + # use the .equals from pandas to check dataframes are equivalent + assert_frame_equal(actual.compute(), expected.compute()) + + # test if no index is given + expected = dd.from_pandas(expected_pd.reset_index(drop=False), chunksize=4) + + actual = ds.to_dask_dataframe(set_index=False) + + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(actual.compute(), expected.compute()) + + @pytest.mark.xfail( + reason="Currently pandas with pyarrow installed will return a `string[pyarrow]` type, " + "which causes the `y` column to have a different type depending on whether pyarrow is installed" + ) + def test_to_dask_dataframe_2D(self): + # Test if 2-D dataset is supplied + w = np.random.randn(2, 3) + ds = Dataset({"w": (("x", "y"), da.from_array(w, chunks=(1, 2)))}) + ds["x"] = ("x", np.array([0, 1], np.int64)) + ds["y"] = ("y", list("abc")) + + # dask dataframes do not (yet) support multiindex, + # but when it does, this would be the expected index: + exp_index = pd.MultiIndex.from_arrays( + [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] + ) + expected = pd.DataFrame({"w": w.reshape(-1)}, index=exp_index) + # so for now, reset the index + expected = expected.reset_index(drop=False) + actual = ds.to_dask_dataframe(set_index=False) + + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(actual.compute(), expected) + + @pytest.mark.xfail(raises=NotImplementedError) + def test_to_dask_dataframe_2D_set_index(self): + # This will fail until dask implements MultiIndex support + w = da.from_array(np.random.randn(2, 3), chunks=(1, 2)) + ds = Dataset({"w": (("x", "y"), w)}) + ds["x"] = ("x", np.array([0, 1], np.int64)) + ds["y"] = ("y", list("abc")) + + expected = ds.compute().to_dataframe() + actual = ds.to_dask_dataframe(set_index=True) + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected, actual.compute()) + + def test_to_dask_dataframe_coordinates(self): + # Test if coordinate is also a dask array + x = np.random.randn(10) + t = np.arange(10) * 2 + + ds = Dataset( + { + "a": ("t", da.from_array(x, chunks=4)), + "t": ("t", da.from_array(t, chunks=4)), + } + ) + + expected_pd = pd.DataFrame({"a": x}, index=pd.Index(t, name="t")) + expected = dd.from_pandas(expected_pd, chunksize=4) + actual = ds.to_dask_dataframe(set_index=True) + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected.compute(), actual.compute()) + + @pytest.mark.xfail( + reason="Currently pandas with pyarrow installed will return a `string[pyarrow]` type, " + "which causes the index to have a different type depending on whether pyarrow is installed" + ) + def test_to_dask_dataframe_not_daskarray(self): + # Test if DataArray is not a dask array + x = np.random.randn(10) + y = np.arange(10, dtype="uint8") + t = list("abcdefghij") + + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + + expected = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t")) + + actual = ds.to_dask_dataframe(set_index=True) + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected, actual.compute()) + + def test_to_dask_dataframe_no_coordinate(self): + x = da.from_array(np.random.randn(10), chunks=4) + ds = Dataset({"x": ("dim_0", x)}) + + expected = ds.compute().to_dataframe().reset_index() + actual = ds.to_dask_dataframe() + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected, actual.compute()) + + expected = ds.compute().to_dataframe() + actual = ds.to_dask_dataframe(set_index=True) + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected, actual.compute()) + + def test_to_dask_dataframe_dim_order(self): + values = np.array([[1, 2], [3, 4]], dtype=np.int64) + ds = Dataset({"w": (("x", "y"), values)}).chunk(1) + + expected = ds["w"].to_series().reset_index() + actual = ds.to_dask_dataframe(dim_order=["x", "y"]) + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected, actual.compute()) + + expected = ds["w"].T.to_series().reset_index() + actual = ds.to_dask_dataframe(dim_order=["y", "x"]) + assert isinstance(actual, dd.DataFrame) + assert_frame_equal(expected, actual.compute()) + + with pytest.raises(ValueError, match=r"does not match the set of dimensions"): + ds.to_dask_dataframe(dim_order=["x"]) + + +@pytest.mark.parametrize("method", ["load", "compute"]) +def test_dask_kwargs_variable(method): + chunked_array = da.from_array(np.arange(3), chunks=(2,)) + x = Variable("y", chunked_array) + # args should be passed on to dask.compute() (via DaskManager.compute()) + with mock.patch.object(da, "compute", return_value=(np.arange(3),)) as mock_compute: + getattr(x, method)(foo="bar") + mock_compute.assert_called_with(chunked_array, foo="bar") + + +@pytest.mark.parametrize("method", ["load", "compute", "persist"]) +def test_dask_kwargs_dataarray(method): + data = da.from_array(np.arange(3), chunks=(2,)) + x = DataArray(data) + if method in ["load", "compute"]: + dask_func = "dask.array.compute" + else: + dask_func = "dask.persist" + # args should be passed on to "dask_func" + with mock.patch(dask_func) as mock_func: + getattr(x, method)(foo="bar") + mock_func.assert_called_with(data, foo="bar") + + +@pytest.mark.parametrize("method", ["load", "compute", "persist"]) +def test_dask_kwargs_dataset(method): + data = da.from_array(np.arange(3), chunks=(2,)) + x = Dataset({"x": (("y"), data)}) + if method in ["load", "compute"]: + dask_func = "dask.array.compute" + else: + dask_func = "dask.persist" + # args should be passed on to "dask_func" + with mock.patch(dask_func) as mock_func: + getattr(x, method)(foo="bar") + mock_func.assert_called_with(data, foo="bar") + + +kernel_call_count = 0 + + +def kernel(name): + """Dask kernel to test pickling/unpickling and __repr__. + Must be global to make it pickleable. + """ + global kernel_call_count + kernel_call_count += 1 + return np.ones(1, dtype=np.int64) + + +def build_dask_array(name): + global kernel_call_count + kernel_call_count = 0 + return dask.array.Array( + dask={(name, 0): (kernel, name)}, name=name, chunks=((1,),), dtype=np.int64 + ) + + +@pytest.mark.parametrize( + "persist", [lambda x: x.persist(), lambda x: dask.persist(x)[0]] +) +def test_persist_Dataset(persist): + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk() + ds = ds + 1 + n = len(ds.foo.data.dask) + + ds2 = persist(ds) + + assert len(ds2.foo.data.dask) == 1 + assert len(ds.foo.data.dask) == n # doesn't mutate in place + + +@pytest.mark.parametrize( + "persist", [lambda x: x.persist(), lambda x: dask.persist(x)[0]] +) +def test_persist_DataArray(persist): + x = da.arange(10, chunks=(5,)) + y = DataArray(x) + z = y + 1 + n = len(z.data.dask) + + zz = persist(z) + + assert len(z.data.dask) == n + assert len(zz.data.dask) == zz.data.npartitions + + +def test_dataarray_with_dask_coords(): + import toolz + + x = xr.Variable("x", da.arange(8, chunks=(4,))) + y = xr.Variable("y", da.arange(8, chunks=(4,)) * 2) + data = da.random.random((8, 8), chunks=(4, 4)) + 1 + array = xr.DataArray(data, dims=["x", "y"]) + array.coords["xx"] = x + array.coords["yy"] = y + + assert dict(array.__dask_graph__()) == toolz.merge( + data.__dask_graph__(), x.__dask_graph__(), y.__dask_graph__() + ) + + (array2,) = dask.compute(array) + assert not dask.is_dask_collection(array2) + + assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values()) + + +def test_basic_compute(): + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk({"x": 2}) + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with dask.config.set(scheduler=get): + ds.compute() + ds.foo.compute() + ds.foo.variable.compute() + + +def test_dask_layers_and_dependencies(): + ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk() + + x = dask.delayed(ds) + assert set(x.__dask_graph__().dependencies).issuperset( + ds.__dask_graph__().dependencies + ) + assert set(x.foo.__dask_graph__().dependencies).issuperset( + ds.__dask_graph__().dependencies + ) + + +def make_da(): + da = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(100, 120)}, + name="a", + ).chunk({"x": 4, "y": 5}) + da.x.attrs["long_name"] = "x" + da.attrs["test"] = "test" + da.coords["c2"] = 0.5 + da.coords["ndcoord"] = da.x * 2 + da.coords["cxy"] = (da.x * da.y).chunk({"x": 4, "y": 5}) + + return da + + +def make_ds(): + map_ds = xr.Dataset() + map_ds["a"] = make_da() + map_ds["b"] = map_ds.a + 50 + map_ds["c"] = map_ds.x + 20 + map_ds = map_ds.chunk({"x": 4, "y": 5}) + map_ds["d"] = ("z", [1, 1, 1, 1]) + map_ds["z"] = [0, 1, 2, 3] + map_ds["e"] = map_ds.x + map_ds.y + map_ds.coords["c1"] = 0.5 + map_ds.coords["cx"] = ("x", np.arange(len(map_ds.x))) + map_ds.coords["cx"].attrs["test2"] = "test2" + map_ds.attrs["test"] = "test" + map_ds.coords["xx"] = map_ds["a"] * map_ds.y + + map_ds.x.attrs["long_name"] = "x" + map_ds.y.attrs["long_name"] = "y" + + return map_ds + + +# fixtures cannot be used in parametrize statements +# instead use this workaround +# https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly +@pytest.fixture +def map_da(): + return make_da() + + +@pytest.fixture +def map_ds(): + return make_ds() + + +def test_unify_chunks(map_ds): + ds_copy = map_ds.copy() + ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) + + with pytest.raises(ValueError, match=r"inconsistent chunks"): + ds_copy.chunks + + expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)} + with raise_if_dask_computes(): + actual_chunks = ds_copy.unify_chunks().chunks + assert actual_chunks == expected_chunks + assert_identical(map_ds, ds_copy.unify_chunks()) + + out_a, out_b = xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy")) + assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5)) + assert out_b.chunks == expected_chunks + + # Test unordered dims + da = ds_copy["cxy"] + out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1})) + assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5)) + assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2)) + + # Test mismatch + with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"): + xr.unify_chunks(da, da.isel(x=slice(2))) + + +@pytest.mark.parametrize("obj", [make_ds(), make_da()]) +@pytest.mark.parametrize( + "transform", [lambda x: x.compute(), lambda x: x.unify_chunks()] +) +def test_unify_chunks_shallow_copy(obj, transform): + obj = transform(obj) + unified = obj.unify_chunks() + assert_identical(obj, unified) and obj is not obj.unify_chunks() + + +@pytest.mark.parametrize("obj", [make_da()]) +def test_auto_chunk_da(obj): + actual = obj.chunk("auto").data + expected = obj.data.rechunk("auto") + np.testing.assert_array_equal(actual, expected) + assert actual.chunks == expected.chunks + + +def test_map_blocks_error(map_da, map_ds): + def bad_func(darray): + return (darray * darray.x + 5 * darray.y)[:1, :1] + + with pytest.raises(ValueError, match=r"Received dimension 'x' of length 1"): + xr.map_blocks(bad_func, map_da).compute() + + def returns_numpy(darray): + return (darray * darray.x + 5 * darray.y).values + + with pytest.raises(TypeError, match=r"Function must return an xarray DataArray"): + xr.map_blocks(returns_numpy, map_da) + + with pytest.raises(TypeError, match=r"args must be"): + xr.map_blocks(operator.add, map_da, args=10) + + with pytest.raises(TypeError, match=r"kwargs must be"): + xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20]) + + def really_bad_func(darray): + raise ValueError("couldn't do anything.") + + with pytest.raises(Exception, match=r"Cannot infer"): + xr.map_blocks(really_bad_func, map_da) + + ds_copy = map_ds.copy() + ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) + + with pytest.raises(ValueError, match=r"inconsistent chunks"): + xr.map_blocks(bad_func, ds_copy) + + with pytest.raises(TypeError, match=r"Cannot pass dask collections"): + xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks(obj): + def func(obj): + result = obj + obj.x + 5 * obj.y + return result + + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj) + expected = func(obj) + assert_chunks_equal(expected.chunk(), actual) + assert_identical(actual, expected) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_mixed_type_inputs(obj): + def func(obj1, non_xarray_input, obj2): + result = obj1 + obj1.x + 5 * obj1.y + return result + + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj, args=["non_xarray_input", obj]) + expected = func(obj, "non_xarray_input", obj) + assert_chunks_equal(expected.chunk(), actual) + assert_identical(actual, expected) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_convert_args_to_list(obj): + expected = obj + 10 + with raise_if_dask_computes(): + actual = xr.map_blocks(operator.add, obj, [10]) + assert_chunks_equal(expected.chunk(), actual) + assert_identical(actual, expected) + + +def test_map_blocks_dask_args(): + da1 = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(20)}, + ).chunk({"x": 5, "y": 4}) + + # check that block shapes are the same + def sumda(da1, da2): + assert da1.shape == da2.shape + return da1 + da2 + + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(sumda, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # one dimension in common + da2 = (da1 + 1).isel(x=1, drop=True) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # test that everything works when dimension names are different + da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"}) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + with pytest.raises(ValueError, match=r"Chunk sizes along dimension 'x'"): + xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) + + with pytest.raises(ValueError, match=r"cannot align.*index.*are not equal"): + xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) + + # reduction + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(lambda a, b: (a + b).sum("x"), da1, args=[da2]) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + # reduction with template + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks( + lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x") + ) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + # bad template: not chunked + with pytest.raises(ValueError, match="Provided template has no dask arrays"): + xr.map_blocks( + lambda a, b: (a + b).sum("x"), + da1, + args=[da2], + template=da1.sum("x").compute(), + ) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_add_attrs(obj): + def add_attrs(obj): + obj = obj.copy(deep=True) + obj.attrs["new"] = "new" + obj.cxy.attrs["new2"] = "new2" + return obj + + expected = add_attrs(obj) + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj) + + assert_identical(actual, expected) + + # when template is specified, attrs are copied from template, not set by function + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj, template=obj) + assert_identical(actual, obj) + + +def test_map_blocks_change_name(map_da): + def change_name(obj): + obj = obj.copy(deep=True) + obj.name = "new" + return obj + + expected = change_name(map_da) + with raise_if_dask_computes(): + actual = xr.map_blocks(change_name, map_da) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_kwargs(obj): + expected = xr.full_like(obj, fill_value=np.nan) + with raise_if_dask_computes(): + actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) + assert_chunks_equal(expected.chunk(), actual) + assert_identical(actual, expected) + + +def test_map_blocks_to_dataarray(map_ds): + with raise_if_dask_computes(): + actual = xr.map_blocks(lambda x: x.to_dataarray(), map_ds) + + # to_dataarray does not preserve name, so cannot use assert_identical + assert_equal(actual, map_ds.to_dataarray()) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x.to_dataset(), + lambda x: x.drop_vars("x"), + lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.expand_dims(k=3), + lambda x: x.assign_coords(new_coord=("y", x.y.data * 2)), + lambda x: x.astype(np.int32), + lambda x: x.x, + ], +) +def test_map_blocks_da_transformations(func, map_da): + with raise_if_dask_computes(): + actual = xr.map_blocks(func, map_da) + + assert_identical(actual, func(map_da)) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x.drop_vars("cxy"), + lambda x: x.drop_vars("a"), + lambda x: x.drop_vars("x"), + lambda x: x.expand_dims(k=[1, 2, 3]), + lambda x: x.expand_dims(k=3), + lambda x: x.rename({"a": "new1", "b": "new2"}), + lambda x: x.x, + ], +) +def test_map_blocks_ds_transformations(func, map_ds): + with raise_if_dask_computes(): + actual = xr.map_blocks(func, map_ds) + + assert_identical(actual, func(map_ds)) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_da_ds_with_template(obj): + func = lambda x: x.isel(x=[1]) + template = obj.isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj, template=template) + assert_identical(actual, template) + + with raise_if_dask_computes(): + actual = obj.map_blocks(func, template=template) + assert_identical(actual, template) + + +def test_map_blocks_roundtrip_string_index(): + ds = xr.Dataset( + {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]} + ).chunk(label=1) + assert ds.label.dtype == np.dtype(" None: + v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) + v = v.astype(np.uint64) + coords = {"x": np.arange(3, dtype=np.uint64), "other": np.uint64(0)} + data_array = DataArray(v, coords, name="my_variable") + expected = dedent( + """\ + Size: 48B + array([[1, 2, 3], + [4, 5, 6]], dtype=uint64) + Coordinates: + * x (x) uint64 24B 0 1 2 + other uint64 8B 0 + Dimensions without coordinates: time + Attributes: + foo: bar""" + ) + assert expected == repr(data_array) + + def test_repr_multiindex(self) -> None: + expected = dedent( + """\ + Size: 32B + array([0, 1, 2, 3], dtype=uint64) + Coordinates: + * x (x) object 32B MultiIndex + * level_1 (x) object 32B 'a' 'a' 'b' 'b' + * level_2 (x) int64 32B 1 2 1 2""" + ) + assert expected == repr(self.mda) + + def test_repr_multiindex_long(self) -> None: + mindex_long = pd.MultiIndex.from_product( + [["a", "b", "c", "d"], [1, 2, 3, 4, 5, 6, 7, 8]], + names=("level_1", "level_2"), + ) + mda_long = DataArray( + list(range(32)), coords={"x": mindex_long}, dims="x" + ).astype(np.uint64) + expected = dedent( + """\ + Size: 256B + array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + dtype=uint64) + Coordinates: + * x (x) object 256B MultiIndex + * level_1 (x) object 256B 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' + * level_2 (x) int64 256B 1 2 3 4 5 6 7 8 1 2 3 4 ... 5 6 7 8 1 2 3 4 5 6 7 8""" + ) + assert expected == repr(mda_long) + + def test_properties(self) -> None: + assert_equal(self.dv.variable, self.v) + assert_array_equal(self.dv.values, self.v.values) + for attr in ["dims", "dtype", "shape", "size", "nbytes", "ndim", "attrs"]: + assert getattr(self.dv, attr) == getattr(self.v, attr) + assert len(self.dv) == len(self.v) + assert_equal(self.dv.variable, self.v) + assert set(self.dv.coords) == set(self.ds.coords) + for k, v in self.dv.coords.items(): + assert_array_equal(v, self.ds.coords[k]) + with pytest.raises(AttributeError): + self.dv.dataset + assert isinstance(self.ds["x"].to_index(), pd.Index) + with pytest.raises(ValueError, match=r"must be 1-dimensional"): + self.ds["foo"].to_index() + with pytest.raises(AttributeError): + self.dv.variable = self.v + + def test_data_property(self) -> None: + array = DataArray(np.zeros((3, 4))) + actual = array.copy() + actual.values = np.ones((3, 4)) + assert_array_equal(np.ones((3, 4)), actual.values) + actual.data = 2 * np.ones((3, 4)) + assert_array_equal(2 * np.ones((3, 4)), actual.data) + assert_array_equal(actual.data, actual.values) + + def test_indexes(self) -> None: + array = DataArray(np.zeros((2, 3)), [("x", [0, 1]), ("y", ["a", "b", "c"])]) + expected_indexes = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])} + expected_xindexes = { + k: PandasIndex(idx, k) for k, idx in expected_indexes.items() + } + assert array.xindexes.keys() == expected_xindexes.keys() + assert array.indexes.keys() == expected_indexes.keys() + assert all([isinstance(idx, pd.Index) for idx in array.indexes.values()]) + assert all([isinstance(idx, Index) for idx in array.xindexes.values()]) + for k in expected_indexes: + assert array.xindexes[k].equals(expected_xindexes[k]) + assert array.indexes[k].equals(expected_indexes[k]) + + def test_get_index(self) -> None: + array = DataArray(np.zeros((2, 3)), coords={"x": ["a", "b"]}, dims=["x", "y"]) + assert array.get_index("x").equals(pd.Index(["a", "b"])) + assert array.get_index("y").equals(pd.Index([0, 1, 2])) + with pytest.raises(KeyError): + array.get_index("z") + + def test_get_index_size_zero(self) -> None: + array = DataArray(np.zeros((0,)), dims=["x"]) + actual = array.get_index("x") + expected = pd.Index([], dtype=np.int64) + assert actual.equals(expected) + assert actual.dtype == expected.dtype + + def test_struct_array_dims(self) -> None: + """ + This test checks subtraction of two DataArrays for the case + when dimension is a structured array. + """ + # GH837, GH861 + # checking array subtraction when dims are the same + p_data = np.array( + [("Abe", 180), ("Stacy", 150), ("Dick", 200)], + dtype=[("name", "|S256"), ("height", object)], + ) + weights_0 = DataArray( + [80, 56, 120], dims=["participant"], coords={"participant": p_data} + ) + weights_1 = DataArray( + [81, 52, 115], dims=["participant"], coords={"participant": p_data} + ) + actual = weights_1 - weights_0 + + expected = DataArray( + [1, -4, -5], dims=["participant"], coords={"participant": p_data} + ) + + assert_identical(actual, expected) + + # checking array subtraction when dims are not the same + p_data_alt = np.array( + [("Abe", 180), ("Stacy", 151), ("Dick", 200)], + dtype=[("name", "|S256"), ("height", object)], + ) + weights_1 = DataArray( + [81, 52, 115], dims=["participant"], coords={"participant": p_data_alt} + ) + actual = weights_1 - weights_0 + + expected = DataArray( + [1, -5], dims=["participant"], coords={"participant": p_data[[0, 2]]} + ) + + assert_identical(actual, expected) + + # checking array subtraction when dims are not the same and one + # is np.nan + p_data_nan = np.array( + [("Abe", 180), ("Stacy", np.nan), ("Dick", 200)], + dtype=[("name", "|S256"), ("height", object)], + ) + weights_1 = DataArray( + [81, 52, 115], dims=["participant"], coords={"participant": p_data_nan} + ) + actual = weights_1 - weights_0 + + expected = DataArray( + [1, -5], dims=["participant"], coords={"participant": p_data[[0, 2]]} + ) + + assert_identical(actual, expected) + + def test_name(self) -> None: + arr = self.dv + assert arr.name == "foo" + + copied = arr.copy() + arr.name = "bar" + assert arr.name == "bar" + assert_equal(copied, arr) + + actual = DataArray(IndexVariable("x", [3])) + actual.name = "y" + expected = DataArray([3], [("x", [3])], name="y") + assert_identical(actual, expected) + + def test_dims(self) -> None: + arr = self.dv + assert arr.dims == ("x", "y") + + with pytest.raises(AttributeError, match=r"you cannot assign"): + arr.dims = ("w", "z") + + def test_sizes(self) -> None: + array = DataArray(np.zeros((3, 4)), dims=["x", "y"]) + assert array.sizes == {"x": 3, "y": 4} + assert tuple(array.sizes) == array.dims + with pytest.raises(TypeError): + array.sizes["foo"] = 5 # type: ignore + + def test_encoding(self) -> None: + expected = {"foo": "bar"} + self.dv.encoding["foo"] = "bar" + assert expected == self.dv.encoding + + expected2 = {"baz": 0} + self.dv.encoding = expected2 + assert expected2 is not self.dv.encoding + + def test_drop_encoding(self) -> None: + array = self.mda + encoding = {"scale_factor": 10} + array.encoding = encoding + array["x"].encoding = encoding + + assert array.encoding == encoding + assert array["x"].encoding == encoding + + actual = array.drop_encoding() + + # did not modify in place + assert array.encoding == encoding + assert array["x"].encoding == encoding + + # variable and coord encoding is empty + assert actual.encoding == {} + assert actual["x"].encoding == {} + + def test_constructor(self) -> None: + data = np.random.random((2, 3)) + + # w/o coords, w/o dims + actual = DataArray(data) + expected = Dataset({None: (["dim_0", "dim_1"], data)})[None] + assert_identical(expected, actual) + + actual = DataArray(data, [["a", "b"], [-1, -2, -3]]) + expected = Dataset( + { + None: (["dim_0", "dim_1"], data), + "dim_0": ("dim_0", ["a", "b"]), + "dim_1": ("dim_1", [-1, -2, -3]), + } + )[None] + assert_identical(expected, actual) + + # pd.Index coords, w/o dims + actual = DataArray( + data, [pd.Index(["a", "b"], name="x"), pd.Index([-1, -2, -3], name="y")] + ) + expected = Dataset( + {None: (["x", "y"], data), "x": ("x", ["a", "b"]), "y": ("y", [-1, -2, -3])} + )[None] + assert_identical(expected, actual) + + # list coords, w dims + coords1 = [["a", "b"], [-1, -2, -3]] + actual = DataArray(data, coords1, ["x", "y"]) + assert_identical(expected, actual) + + # pd.Index coords, w dims + coords2 = [pd.Index(["a", "b"], name="A"), pd.Index([-1, -2, -3], name="B")] + actual = DataArray(data, coords2, ["x", "y"]) + assert_identical(expected, actual) + + # dict coords, w dims + coords3 = {"x": ["a", "b"], "y": [-1, -2, -3]} + actual = DataArray(data, coords3, ["x", "y"]) + assert_identical(expected, actual) + + # dict coords, w/o dims + actual = DataArray(data, coords3) + assert_identical(expected, actual) + + # tuple[dim, list] coords, w/o dims + coords4 = [("x", ["a", "b"]), ("y", [-1, -2, -3])] + actual = DataArray(data, coords4) + assert_identical(expected, actual) + + # partial dict coords, w dims + expected = Dataset({None: (["x", "y"], data), "x": ("x", ["a", "b"])})[None] + actual = DataArray(data, {"x": ["a", "b"]}, ["x", "y"]) + assert_identical(expected, actual) + + # w/o coords, w dims + actual = DataArray(data, dims=["x", "y"]) + expected = Dataset({None: (["x", "y"], data)})[None] + assert_identical(expected, actual) + + # w/o coords, w dims, w name + actual = DataArray(data, dims=["x", "y"], name="foo") + expected = Dataset({"foo": (["x", "y"], data)})["foo"] + assert_identical(expected, actual) + + # w/o coords, w/o dims, w name + actual = DataArray(data, name="foo") + expected = Dataset({"foo": (["dim_0", "dim_1"], data)})["foo"] + assert_identical(expected, actual) + + # w/o coords, w dims, w attrs + actual = DataArray(data, dims=["x", "y"], attrs={"bar": 2}) + expected = Dataset({None: (["x", "y"], data, {"bar": 2})})[None] + assert_identical(expected, actual) + + # w/o coords, w dims (ds has attrs) + actual = DataArray(data, dims=["x", "y"]) + expected = Dataset({None: (["x", "y"], data, {}, {"bar": 2})})[None] + assert_identical(expected, actual) + + # data is list, w coords + actual = DataArray([1, 2, 3], coords={"x": [0, 1, 2]}) + expected = DataArray([1, 2, 3], coords=[("x", [0, 1, 2])]) + assert_identical(expected, actual) + + def test_constructor_invalid(self) -> None: + data = np.random.randn(3, 2) + + with pytest.raises(ValueError, match=r"coords is not dict-like"): + DataArray(data, [[0, 1, 2]], ["x", "y"]) + + with pytest.raises(ValueError, match=r"not a subset of the .* dim"): + DataArray(data, {"x": [0, 1, 2]}, ["a", "b"]) + with pytest.raises(ValueError, match=r"not a subset of the .* dim"): + DataArray(data, {"x": [0, 1, 2]}) + + with pytest.raises(TypeError, match=r"is not hashable"): + DataArray(data, dims=["x", []]) # type: ignore[list-item] + + with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + DataArray([1, 2, 3], coords=[("x", [0, 1])]) + with pytest.raises(ValueError, match=r"conflicting sizes for dim"): + DataArray([1, 2], coords={"x": [0, 1], "y": ("x", [1])}, dims="x") + + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + DataArray(np.random.rand(4, 4), [("x", self.mindex), ("y", self.mindex)]) + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) + + def test_constructor_from_self_described(self) -> None: + data = [[-0.1, 21], [0, 2]] + expected = DataArray( + data, + coords={"x": ["a", "b"], "y": [-1, -2]}, + dims=["x", "y"], + name="foobar", + attrs={"bar": 2}, + ) + actual = DataArray(expected) + assert_identical(expected, actual) + + actual = DataArray(expected.values, actual.coords) + assert_equal(expected, actual) + + frame = pd.DataFrame( + data, + index=pd.Index(["a", "b"], name="x"), + columns=pd.Index([-1, -2], name="y"), + ) + actual = DataArray(frame) + assert_equal(expected, actual) + + series = pd.Series(data[0], index=pd.Index([-1, -2], name="y")) + actual = DataArray(series) + assert_equal(expected[0].reset_coords("x", drop=True), actual) + + expected = DataArray( + data, + coords={"x": ["a", "b"], "y": [-1, -2], "a": 0, "z": ("x", [-0.5, 0.5])}, + dims=["x", "y"], + ) + actual = DataArray(expected) + assert_identical(expected, actual) + + actual = DataArray(expected.values, expected.coords) + assert_identical(expected, actual) + + expected = Dataset({"foo": ("foo", ["a", "b"])})["foo"] + actual = DataArray(pd.Index(["a", "b"], name="foo")) + assert_identical(expected, actual) + + actual = DataArray(IndexVariable("foo", ["a", "b"])) + assert_identical(expected, actual) + + @requires_dask + def test_constructor_from_self_described_chunked(self) -> None: + expected = DataArray( + [[-0.1, 21], [0, 2]], + coords={"x": ["a", "b"], "y": [-1, -2]}, + dims=["x", "y"], + name="foobar", + attrs={"bar": 2}, + ).chunk() + actual = DataArray(expected) + assert_identical(expected, actual) + assert_chunks_equal(expected, actual) + + def test_constructor_from_0d(self) -> None: + expected = Dataset({None: ([], 0)})[None] + actual = DataArray(0) + assert_identical(expected, actual) + + @requires_dask + def test_constructor_dask_coords(self) -> None: + # regression test for GH1684 + import dask.array as da + + coord = da.arange(8, chunks=(4,)) + data = da.random.random((8, 8), chunks=(4, 4)) + 1 + actual = DataArray(data, coords={"x": coord, "y": coord}, dims=["x", "y"]) + + ecoord = np.arange(8) + expected = DataArray(data, coords={"x": ecoord, "y": ecoord}, dims=["x", "y"]) + assert_equal(actual, expected) + + def test_constructor_no_default_index(self) -> None: + # explicitly passing a Coordinates object skips the creation of default index + da = DataArray(range(3), coords=Coordinates({"x": [1, 2, 3]}, indexes={})) + assert "x" in da.coords + assert "x" not in da.xindexes + + def test_constructor_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + da = DataArray(range(4), coords=coords, dims="x") + assert_identical(da.coords, coords) + + def test_constructor_custom_index(self) -> None: + class CustomIndex(Index): ... + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + da = DataArray(range(3), coords=coords) + assert isinstance(da.xindexes["x"], CustomIndex) + + # test coordinate variables copied + assert da.coords["x"] is not coords.variables["x"] + + def test_equals_and_identical(self) -> None: + orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") + + expected = orig + actual = orig.copy() + assert expected.equals(actual) + assert expected.identical(actual) + + actual = expected.rename("baz") + assert expected.equals(actual) + assert not expected.identical(actual) + + actual = expected.rename({"x": "xxx"}) + assert not expected.equals(actual) + assert not expected.identical(actual) + + actual = expected.copy() + actual.attrs["foo"] = "bar" + assert expected.equals(actual) + assert not expected.identical(actual) + + actual = expected.copy() + actual["x"] = ("x", -np.arange(5)) + assert not expected.equals(actual) + assert not expected.identical(actual) + + actual = expected.reset_coords(drop=True) + assert not expected.equals(actual) + assert not expected.identical(actual) + + actual = orig.copy() + actual[0] = np.nan + expected = actual.copy() + assert expected.equals(actual) + assert expected.identical(actual) + + actual[:] = np.nan + assert not expected.equals(actual) + assert not expected.identical(actual) + + actual = expected.copy() + actual["a"] = 100000 + assert not expected.equals(actual) + assert not expected.identical(actual) + + def test_equals_failures(self) -> None: + orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") + assert not orig.equals(np.arange(5)) # type: ignore[arg-type] + assert not orig.identical(123) # type: ignore + assert not orig.broadcast_equals({1: 2}) # type: ignore + + def test_broadcast_equals(self) -> None: + a = DataArray([0, 0], {"y": 0}, dims="x") + b = DataArray([0, 0], {"y": ("x", [0, 0])}, dims="x") + assert a.broadcast_equals(b) + assert b.broadcast_equals(a) + assert not a.equals(b) + assert not a.identical(b) + + c = DataArray([0], coords={"x": 0}, dims="y") + assert not a.broadcast_equals(c) + assert not c.broadcast_equals(a) + + def test_getitem(self) -> None: + # strings pull out dataarrays + assert_identical(self.dv, self.ds["foo"]) + x = self.dv["x"] + y = self.dv["y"] + assert_identical(self.ds["x"], x) + assert_identical(self.ds["y"], y) + + arr = ReturnItem() + for i in [ + arr[:], + arr[...], + arr[x.values], + arr[x.variable], + arr[x], + arr[x, y], + arr[x.values > -1], + arr[x.variable > -1], + arr[x > -1], + arr[x > -1, y > -1], + ]: + assert_equal(self.dv, self.dv[i]) + for i in [ + arr[0], + arr[:, 0], + arr[:3, :2], + arr[x.values[:3]], + arr[x.variable[:3]], + arr[x[:3]], + arr[x[:3], y[:4]], + arr[x.values > 3], + arr[x.variable > 3], + arr[x > 3], + arr[x > 3, y > 3], + ]: + assert_array_equal(self.v[i], self.dv[i]) + + def test_getitem_dict(self) -> None: + actual = self.dv[{"x": slice(3), "y": 0}] + expected = self.dv.isel(x=slice(3), y=0) + assert_identical(expected, actual) + + def test_getitem_coords(self) -> None: + orig = DataArray( + [[10], [20]], + { + "x": [1, 2], + "y": [3], + "z": 4, + "x2": ("x", ["a", "b"]), + "y2": ("y", ["c"]), + "xy": (["y", "x"], [["d", "e"]]), + }, + dims=["x", "y"], + ) + + assert_identical(orig, orig[:]) + assert_identical(orig, orig[:, :]) + assert_identical(orig, orig[...]) + assert_identical(orig, orig[:2, :1]) + assert_identical(orig, orig[[0, 1], [0]]) + + actual = orig[0, 0] + expected = DataArray( + 10, {"x": 1, "y": 3, "z": 4, "x2": "a", "y2": "c", "xy": "d"} + ) + assert_identical(expected, actual) + + actual = orig[0, :] + expected = DataArray( + [10], + { + "x": 1, + "y": [3], + "z": 4, + "x2": "a", + "y2": ("y", ["c"]), + "xy": ("y", ["d"]), + }, + dims="y", + ) + assert_identical(expected, actual) + + actual = orig[:, 0] + expected = DataArray( + [10, 20], + { + "x": [1, 2], + "y": 3, + "z": 4, + "x2": ("x", ["a", "b"]), + "y2": "c", + "xy": ("x", ["d", "e"]), + }, + dims="x", + ) + assert_identical(expected, actual) + + def test_getitem_dataarray(self) -> None: + # It should not conflict + da = DataArray(np.arange(12).reshape((3, 4)), dims=["x", "y"]) + ind = DataArray([[0, 1], [0, 1]], dims=["x", "z"]) + actual = da[ind] + assert_array_equal(actual, da.values[[[0, 1], [0, 1]], :]) + + da = DataArray( + np.arange(12).reshape((3, 4)), + dims=["x", "y"], + coords={"x": [0, 1, 2], "y": ["a", "b", "c", "d"]}, + ) + ind = xr.DataArray([[0, 1], [0, 1]], dims=["X", "Y"]) + actual = da[ind] + expected = da.values[[[0, 1], [0, 1]], :] + assert_array_equal(actual, expected) + assert actual.dims == ("X", "Y", "y") + + # boolean indexing + ind = xr.DataArray([True, True, False], dims=["x"]) + assert_equal(da[ind], da[[0, 1], :]) + assert_equal(da[ind], da[[0, 1]]) + assert_equal(da[ind], da[ind.values]) + + def test_getitem_empty_index(self) -> None: + da = DataArray(np.arange(12).reshape((3, 4)), dims=["x", "y"]) + assert_identical(da[{"x": []}], DataArray(np.zeros((0, 4)), dims=["x", "y"])) + assert_identical( + da.loc[{"y": []}], DataArray(np.zeros((3, 0)), dims=["x", "y"]) + ) + assert_identical(da[[]], DataArray(np.zeros((0, 4)), dims=["x", "y"])) + + def test_setitem(self) -> None: + # basic indexing should work as numpy's indexing + tuples = [ + (0, 0), + (0, slice(None, None)), + (slice(None, None), slice(None, None)), + (slice(None, None), 0), + ([1, 0], slice(None, None)), + (slice(None, None), [1, 0]), + ] + for t in tuples: + expected = np.arange(6).reshape(3, 2) + orig = DataArray( + np.arange(6).reshape(3, 2), + { + "x": [1, 2, 3], + "y": ["a", "b"], + "z": 4, + "x2": ("x", ["a", "b", "c"]), + "y2": ("y", ["d", "e"]), + }, + dims=["x", "y"], + ) + orig[t] = 1 + expected[t] = 1 + assert_array_equal(orig.values, expected) + + def test_setitem_fancy(self) -> None: + # vectorized indexing + da = DataArray(np.ones((3, 2)), dims=["x", "y"]) + ind = Variable(["a"], [0, 1]) + da[dict(x=ind, y=ind)] = 0 + expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=["x", "y"]) + assert_identical(expected, da) + # assign another 0d-variable + da[dict(x=ind, y=ind)] = Variable((), 0) + expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=["x", "y"]) + assert_identical(expected, da) + # assign another 1d-variable + da[dict(x=ind, y=ind)] = Variable(["a"], [2, 3]) + expected = DataArray([[2, 1], [1, 3], [1, 1]], dims=["x", "y"]) + assert_identical(expected, da) + + # 2d-vectorized indexing + da = DataArray(np.ones((3, 2)), dims=["x", "y"]) + ind_x = DataArray([[0, 1]], dims=["a", "b"]) + ind_y = DataArray([[1, 0]], dims=["a", "b"]) + da[dict(x=ind_x, y=ind_y)] = 0 + expected = DataArray([[1, 0], [0, 1], [1, 1]], dims=["x", "y"]) + assert_identical(expected, da) + + da = DataArray(np.ones((3, 2)), dims=["x", "y"]) + ind = Variable(["a"], [0, 1]) + da[ind] = 0 + expected = DataArray([[0, 0], [0, 0], [1, 1]], dims=["x", "y"]) + assert_identical(expected, da) + + def test_setitem_dataarray(self) -> None: + def get_data(): + return DataArray( + np.ones((4, 3, 2)), + dims=["x", "y", "z"], + coords={ + "x": np.arange(4), + "y": ["a", "b", "c"], + "non-dim": ("x", [1, 3, 4, 2]), + }, + ) + + da = get_data() + # indexer with inconsistent coordinates. + ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.random.randn(3)}) + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): + da[dict(x=ind)] = 0 + + # indexer with consistent coordinates. + ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.arange(1, 4)}) + da[dict(x=ind)] = 0 # should not raise + assert np.allclose(da[dict(x=ind)].values, 0) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) + + da = get_data() + # conflict in the assigning values + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, + ) + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): + da[dict(x=ind)] = value + + # consistent coordinate in the assigning values + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) + da[dict(x=ind)] = value + assert np.allclose(da[dict(x=ind)].values, 0) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) + + # Conflict in the non-dimension coordinate + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) + da[dict(x=ind)] = value # should not raise + + # conflict in the assigning values + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, + ) + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): + da[dict(x=ind)] = value + + # consistent coordinate in the assigning values + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) + da[dict(x=ind)] = value # should not raise + + def test_setitem_vectorized(self) -> None: + # Regression test for GH:7030 + # Positional indexing + v = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + b = xr.DataArray([[0, 0], [1, 0]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + v[index] = w + assert (v[index] == w).all() + + # Indexing with coordinates + v = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + v.coords["b"] = [2, 4, 6] + b = xr.DataArray([[2, 2], [4, 2]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + v.loc[index] = w + assert (v.loc[index] == w).all() + + def test_contains(self) -> None: + data_array = DataArray([1, 2]) + assert 1 in data_array + assert 3 not in data_array + + def test_pickle(self) -> None: + data = DataArray(np.random.random((3, 3)), dims=("id", "time")) + roundtripped = pickle.loads(pickle.dumps(data)) + assert_identical(data, roundtripped) + + @requires_dask + def test_chunk(self) -> None: + unblocked = DataArray(np.ones((3, 4))) + assert unblocked.chunks is None + + blocked = unblocked.chunk() + assert blocked.chunks == ((3,), (4,)) + first_dask_name = blocked.data.name + + with pytest.warns(DeprecationWarning): + blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) # type: ignore + assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name + + blocked = unblocked.chunk(chunks=(3, 3)) + assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + # name doesn't change when rechunking by same amount + # this fails if ReprObject doesn't have __dask_tokenize__ defined + assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name + + assert blocked.load().chunks is None + + # Check that kwargs are passed + import dask.array as da + + blocked = unblocked.chunk(name_prefix="testname_") + assert isinstance(blocked.data, da.Array) + assert "testname_" in blocked.data.name + + # test kwargs form of chunks + blocked = unblocked.chunk(dim_0=3, dim_1=3) + assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + def test_isel(self) -> None: + assert_identical(self.dv[0], self.dv.isel(x=0)) + assert_identical(self.dv, self.dv.isel(x=slice(None))) + assert_identical(self.dv[:3], self.dv.isel(x=slice(3))) + assert_identical(self.dv[:3, :5], self.dv.isel(x=slice(3), y=slice(5))) + with pytest.raises( + ValueError, + match=r"Dimensions {'not_a_dim'} do not exist. Expected " + r"one or more of \('x', 'y'\)", + ): + self.dv.isel(not_a_dim=0) + with pytest.warns( + UserWarning, + match=r"Dimensions {'not_a_dim'} do not exist. " + r"Expected one or more of \('x', 'y'\)", + ): + self.dv.isel(not_a_dim=0, missing_dims="warn") + assert_identical(self.dv, self.dv.isel(not_a_dim=0, missing_dims="ignore")) + + def test_isel_types(self) -> None: + # regression test for #1405 + da = DataArray([1, 2, 3], dims="x") + # uint64 + assert_identical( + da.isel(x=np.array([0], dtype="uint64")), da.isel(x=np.array([0])) + ) + # uint32 + assert_identical( + da.isel(x=np.array([0], dtype="uint32")), da.isel(x=np.array([0])) + ) + # int64 + assert_identical( + da.isel(x=np.array([0], dtype="int64")), da.isel(x=np.array([0])) + ) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") + def test_isel_fancy(self) -> None: + shape = (10, 7, 6) + np_array = np.random.random(shape) + da = DataArray( + np_array, dims=["time", "y", "x"], coords={"time": np.arange(0, 100, 10)} + ) + y = [1, 3] + x = [3, 0] + + expected = da.values[:, y, x] + + actual = da.isel(y=(("test_coord",), y), x=(("test_coord",), x)) + assert actual.coords["test_coord"].shape == (len(y),) + assert list(actual.coords) == ["time"] + assert actual.dims == ("time", "test_coord") + + np.testing.assert_equal(actual, expected) + + # a few corner cases + da.isel( + time=(("points",), [1, 2]), x=(("points",), [2, 2]), y=(("points",), [3, 4]) + ) + np.testing.assert_allclose( + da.isel( + time=(("p",), [1]), x=(("p",), [2]), y=(("p",), [4]) + ).values.squeeze(), + np_array[1, 4, 2].squeeze(), + ) + da.isel(time=(("points",), [1, 2])) + y = [-1, 0] + x = [-2, 2] + expected2 = da.values[:, y, x] + actual2 = da.isel(x=(("points",), x), y=(("points",), y)).values + np.testing.assert_equal(actual2, expected2) + + # test that the order of the indexers doesn't matter + assert_identical( + da.isel(y=(("points",), y), x=(("points",), x)), + da.isel(x=(("points",), x), y=(("points",), y)), + ) + + # make sure we're raising errors in the right places + with pytest.raises(IndexError, match=r"Dimensions of indexers mismatch"): + da.isel(y=(("points",), [1, 2]), x=(("points",), [1, 2, 3])) + + # tests using index or DataArray as indexers + stations = Dataset() + stations["station"] = (("station",), ["A", "B", "C"]) + stations["dim1s"] = (("station",), [1, 2, 3]) + stations["dim2s"] = (("station",), [4, 5, 1]) + + actual3 = da.isel(x=stations["dim1s"], y=stations["dim2s"]) + assert "station" in actual3.coords + assert "station" in actual3.dims + assert_identical(actual3["station"], stations["station"]) + + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): + da.isel( + x=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 2]}), + y=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 3]}), + ) + + # multi-dimensional selection + stations = Dataset() + stations["a"] = (("a",), ["A", "B", "C"]) + stations["b"] = (("b",), [0, 1]) + stations["dim1s"] = (("a", "b"), [[1, 2], [2, 3], [3, 4]]) + stations["dim2s"] = (("a",), [4, 5, 1]) + + actual4 = da.isel(x=stations["dim1s"], y=stations["dim2s"]) + assert "a" in actual4.coords + assert "a" in actual4.dims + assert "b" in actual4.coords + assert "b" in actual4.dims + assert_identical(actual4["a"], stations["a"]) + assert_identical(actual4["b"], stations["b"]) + expected4 = da.variable[ + :, stations["dim2s"].variable, stations["dim1s"].variable + ] + assert_array_equal(actual4, expected4) + + def test_sel(self) -> None: + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] + assert_identical(da, da.sel(x=slice(None))) + assert_identical(da[1], da.sel(x="b")) + assert_identical(da[:3], da.sel(x=slice("c"))) + assert_identical(da[:3], da.sel(x=["a", "b", "c"])) + assert_identical(da[:, :4], da.sel(y=(self.ds["y"] < 4))) + # verify that indexing with a dataarray works + b = DataArray("b") + assert_identical(da[1], da.sel(x=b)) + assert_identical(da[[1]], da.sel(x=slice(b, b))) + + def test_sel_dataarray(self) -> None: + # indexing with DataArray + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] + + ind = DataArray(["a", "b", "c"], dims=["x"]) + actual = da.sel(x=ind) + assert_identical(actual, da.isel(x=[0, 1, 2])) + + # along new dimension + ind = DataArray(["a", "b", "c"], dims=["new_dim"]) + actual = da.sel(x=ind) + assert_array_equal(actual, da.isel(x=[0, 1, 2])) + assert "new_dim" in actual.dims + + # with coordinate + ind = DataArray( + ["a", "b", "c"], dims=["new_dim"], coords={"new_dim": [0, 1, 2]} + ) + actual = da.sel(x=ind) + assert_array_equal(actual, da.isel(x=[0, 1, 2])) + assert "new_dim" in actual.dims + assert "new_dim" in actual.coords + assert_equal(actual["new_dim"].drop_vars("x"), ind["new_dim"]) + + def test_sel_invalid_slice(self) -> None: + array = DataArray(np.arange(10), [("x", np.arange(10))]) + with pytest.raises(ValueError, match=r"cannot use non-scalar arrays"): + array.sel(x=slice(array.x)) + + def test_sel_dataarray_datetime_slice(self) -> None: + # regression test for GH1240 + times = pd.date_range("2000-01-01", freq="D", periods=365) + array = DataArray(np.arange(365), [("time", times)]) + result = array.sel(time=slice(array.time[0], array.time[-1])) + assert_equal(result, array) + + array = DataArray(np.arange(365), [("delta", times - times[0])]) + result = array.sel(delta=slice(array.delta[0], array.delta[-1])) + assert_equal(result, array) + + @pytest.mark.parametrize( + ["coord_values", "indices"], + ( + pytest.param( + np.array([0.0, 0.111, 0.222, 0.333], dtype="float64"), + slice(1, 3), + id="float64", + ), + pytest.param( + np.array([0.0, 0.111, 0.222, 0.333], dtype="float32"), + slice(1, 3), + id="float32", + ), + pytest.param( + np.array([0.0, 0.111, 0.222, 0.333], dtype="float32"), [2], id="scalar" + ), + ), + ) + def test_sel_float(self, coord_values, indices) -> None: + data_values = np.arange(4) + + arr = DataArray(data_values, coords={"x": coord_values}, dims="x") + + actual = arr.sel(x=coord_values[indices]) + expected = DataArray( + data_values[indices], coords={"x": coord_values[indices]}, dims="x" + ) + + assert_equal(actual, expected) + + def test_sel_float16(self) -> None: + data_values = np.arange(4) + coord_values = np.array([0.0, 0.111, 0.222, 0.333], dtype="float16") + indices = slice(1, 3) + + message = "`pandas.Index` does not support the `float16` dtype.*" + + with pytest.warns(DeprecationWarning, match=message): + arr = DataArray(data_values, coords={"x": coord_values}, dims="x") + with pytest.warns(DeprecationWarning, match=message): + expected = DataArray( + data_values[indices], coords={"x": coord_values[indices]}, dims="x" + ) + + actual = arr.sel(x=coord_values[indices]) + + assert_equal(actual, expected) + + def test_sel_float_multiindex(self) -> None: + # regression test https://github.com/pydata/xarray/issues/5691 + # test multi-index created from coordinates, one with dtype=float32 + lvl1 = ["a", "a", "b", "b"] + lvl2 = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + da = xr.DataArray( + [1, 2, 3, 4], dims="x", coords={"lvl1": ("x", lvl1), "lvl2": ("x", lvl2)} + ) + da = da.set_index(x=["lvl1", "lvl2"]) + + actual = da.sel(lvl1="a", lvl2=0.1) + expected = da.isel(x=0) + + assert_equal(actual, expected) + + def test_sel_no_index(self) -> None: + array = DataArray(np.arange(10), dims="x") + assert_identical(array[0], array.sel(x=0)) + assert_identical(array[:5], array.sel(x=slice(5))) + assert_identical(array[[0, -1]], array.sel(x=[0, -1])) + assert_identical(array[array < 5], array.sel(x=(array < 5))) + + def test_sel_method(self) -> None: + data = DataArray(np.random.randn(3, 4), [("x", [0, 1, 2]), ("y", list("abcd"))]) + + with pytest.raises(KeyError, match="Try setting the `method`"): + data.sel(y="ab") + + expected = data.sel(y=["a", "b"]) + actual = data.sel(y=["ab", "ba"], method="pad") + assert_identical(expected, actual) + + expected = data.sel(x=[1, 2]) + actual = data.sel(x=[0.9, 1.9], method="backfill", tolerance=1) + assert_identical(expected, actual) + + def test_sel_drop(self) -> None: + data = DataArray([1, 2, 3], [("x", [0, 1, 2])]) + expected = DataArray(1) + selected = data.sel(x=0, drop=True) + assert_identical(expected, selected) + + expected = DataArray(1, {"x": 0}) + selected = data.sel(x=0, drop=False) + assert_identical(expected, selected) + + data = DataArray([1, 2, 3], dims=["x"]) + expected = DataArray(1) + selected = data.sel(x=0, drop=True) + assert_identical(expected, selected) + + def test_isel_drop(self) -> None: + data = DataArray([1, 2, 3], [("x", [0, 1, 2])]) + expected = DataArray(1) + selected = data.isel(x=0, drop=True) + assert_identical(expected, selected) + + expected = DataArray(1, {"x": 0}) + selected = data.isel(x=0, drop=False) + assert_identical(expected, selected) + + def test_head(self) -> None: + assert_equal(self.dv.isel(x=slice(5)), self.dv.head(x=5)) + assert_equal(self.dv.isel(x=slice(0)), self.dv.head(x=0)) + assert_equal( + self.dv.isel({dim: slice(6) for dim in self.dv.dims}), self.dv.head(6) + ) + assert_equal( + self.dv.isel({dim: slice(5) for dim in self.dv.dims}), self.dv.head() + ) + with pytest.raises(TypeError, match=r"either dict-like or a single int"): + self.dv.head([3]) + with pytest.raises(TypeError, match=r"expected integer type"): + self.dv.head(x=3.1) + with pytest.raises(ValueError, match=r"expected positive int"): + self.dv.head(-3) + + def test_tail(self) -> None: + assert_equal(self.dv.isel(x=slice(-5, None)), self.dv.tail(x=5)) + assert_equal(self.dv.isel(x=slice(0)), self.dv.tail(x=0)) + assert_equal( + self.dv.isel({dim: slice(-6, None) for dim in self.dv.dims}), + self.dv.tail(6), + ) + assert_equal( + self.dv.isel({dim: slice(-5, None) for dim in self.dv.dims}), self.dv.tail() + ) + with pytest.raises(TypeError, match=r"either dict-like or a single int"): + self.dv.tail([3]) + with pytest.raises(TypeError, match=r"expected integer type"): + self.dv.tail(x=3.1) + with pytest.raises(ValueError, match=r"expected positive int"): + self.dv.tail(-3) + + def test_thin(self) -> None: + assert_equal(self.dv.isel(x=slice(None, None, 5)), self.dv.thin(x=5)) + assert_equal( + self.dv.isel({dim: slice(None, None, 6) for dim in self.dv.dims}), + self.dv.thin(6), + ) + with pytest.raises(TypeError, match=r"either dict-like or a single int"): + self.dv.thin([3]) + with pytest.raises(TypeError, match=r"expected integer type"): + self.dv.thin(x=3.1) + with pytest.raises(ValueError, match=r"expected positive int"): + self.dv.thin(-3) + with pytest.raises(ValueError, match=r"cannot be zero"): + self.dv.thin(time=0) + + def test_loc(self) -> None: + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] + # typing issue: see https://github.com/python/mypy/issues/2410 + assert_identical(da[:3], da.loc[:"c"]) # type: ignore[misc] + assert_identical(da[1], da.loc["b"]) + assert_identical(da[1], da.loc[{"x": "b"}]) + assert_identical(da[1], da.loc["b", ...]) + assert_identical(da[:3], da.loc[["a", "b", "c"]]) + assert_identical(da[:3, :4], da.loc[["a", "b", "c"], np.arange(4)]) + assert_identical(da[:, :4], da.loc[:, self.ds["y"] < 4]) + + def test_loc_datetime64_value(self) -> None: + # regression test for https://github.com/pydata/xarray/issues/4283 + t = np.array(["2017-09-05T12", "2017-09-05T15"], dtype="datetime64[ns]") + array = DataArray(np.ones(t.shape), dims=("time",), coords=(t,)) + assert_identical(array.loc[{"time": t[0]}], array[0]) + + def test_loc_assign(self) -> None: + self.ds["x"] = ("x", np.array(list("abcdefghij"))) + da = self.ds["foo"] + # assignment + # typing issue: see https://github.com/python/mypy/issues/2410 + da.loc["a":"j"] = 0 # type: ignore[misc] + assert np.all(da.values == 0) + da.loc[{"x": slice("a", "j")}] = 2 + assert np.all(da.values == 2) + + da.loc[{"x": slice("a", "j")}] = 2 + assert np.all(da.values == 2) + + # Multi dimensional case + da = DataArray(np.arange(12).reshape(3, 4), dims=["x", "y"]) + da.loc[0, 0] = 0 + assert da.values[0, 0] == 0 + assert da.values[0, 1] != 0 + + da = DataArray(np.arange(12).reshape(3, 4), dims=["x", "y"]) + da.loc[0] = 0 + assert np.all(da.values[0] == np.zeros(4)) + assert da.values[1, 0] != 0 + + def test_loc_assign_dataarray(self) -> None: + def get_data(): + return DataArray( + np.ones((4, 3, 2)), + dims=["x", "y", "z"], + coords={ + "x": np.arange(4), + "y": ["a", "b", "c"], + "non-dim": ("x", [1, 3, 4, 2]), + }, + ) + + da = get_data() + # indexer with inconsistent coordinates. + ind = DataArray(np.arange(1, 4), dims=["y"], coords={"y": np.random.randn(3)}) + with pytest.raises(IndexError, match=r"dimension coordinate 'y'"): + da.loc[dict(x=ind)] = 0 + + # indexer with consistent coordinates. + ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.arange(1, 4)}) + da.loc[dict(x=ind)] = 0 # should not raise + assert np.allclose(da[dict(x=ind)].values, 0) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) + + da = get_data() + # conflict in the assigning values + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, + ) + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): + da.loc[dict(x=ind)] = value + + # consistent coordinate in the assigning values + value = xr.DataArray( + np.zeros((3, 3, 2)), + dims=["x", "y", "z"], + coords={"x": [1, 2, 3], "non-dim": ("x", [0, 2, 4])}, + ) + da.loc[dict(x=ind)] = value + assert np.allclose(da[dict(x=ind)].values, 0) + assert_identical(da["x"], get_data()["x"]) + assert_identical(da["non-dim"], get_data()["non-dim"]) + + def test_loc_single_boolean(self) -> None: + data = DataArray([0, 1], coords=[[True, False]]) + assert data.loc[True] == 0 + assert data.loc[False] == 1 + + def test_loc_dim_name_collision_with_sel_params(self) -> None: + da = xr.DataArray( + [[0, 0], [1, 1]], + dims=["dim1", "method"], + coords={"dim1": ["x", "y"], "method": ["a", "b"]}, + ) + np.testing.assert_array_equal( + da.loc[dict(dim1=["x", "y"], method=["a"])], [[0], [1]] + ) + + def test_selection_multiindex(self) -> None: + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + mdata = DataArray(range(8), [("x", mindex)]) + + def test_sel( + lab_indexer, pos_indexer, replaced_idx=False, renamed_dim=None + ) -> None: + da = mdata.sel(x=lab_indexer) + expected_da = mdata.isel(x=pos_indexer) + if not replaced_idx: + assert_identical(da, expected_da) + else: + if renamed_dim: + assert da.dims[0] == renamed_dim + da = da.rename({renamed_dim: "x"}) + assert_identical(da.variable, expected_da.variable) + assert not da["x"].equals(expected_da["x"]) + + test_sel(("a", 1, -1), 0) + test_sel(("b", 2, -2), -1) + test_sel(("a", 1), [0, 1], replaced_idx=True, renamed_dim="three") + test_sel(("a",), range(4), replaced_idx=True) + test_sel("a", range(4), replaced_idx=True) + test_sel([("a", 1, -1), ("b", 2, -2)], [0, 7]) + test_sel(slice("a", "b"), range(8)) + test_sel(slice(("a", 1), ("b", 1)), range(6)) + test_sel({"one": "a", "two": 1, "three": -1}, 0) + test_sel({"one": "a", "two": 1}, [0, 1], replaced_idx=True, renamed_dim="three") + test_sel({"one": "a"}, range(4), replaced_idx=True) + + assert_identical(mdata.loc["a"], mdata.sel(x="a")) + assert_identical(mdata.loc[("a", 1), ...], mdata.sel(x=("a", 1))) + assert_identical(mdata.loc[{"one": "a"}, ...], mdata.sel(x={"one": "a"})) + with pytest.raises(IndexError): + mdata.loc[("a", 1)] + + assert_identical(mdata.sel(x={"one": "a", "two": 1}), mdata.sel(one="a", two=1)) + + def test_selection_multiindex_remove_unused(self) -> None: + # GH2619. For MultiIndex, we need to call remove_unused. + ds = xr.DataArray( + np.arange(40).reshape(8, 5), + dims=["x", "y"], + coords={"x": np.arange(8), "y": np.arange(5)}, + ) + ds = ds.stack(xy=["x", "y"]) + ds_isel = ds.isel(xy=ds["x"] < 4) + with pytest.raises(KeyError): + ds_isel.sel(x=5) + + actual = ds_isel.unstack() + expected = ds.reset_index("xy").isel(xy=ds["x"] < 4) + expected = expected.set_index(xy=["x", "y"]).unstack() + assert_identical(expected, actual) + + def test_selection_multiindex_from_level(self) -> None: + # GH: 3512 + da = DataArray([0, 1], dims=["x"], coords={"x": [0, 1], "y": "a"}) + db = DataArray([2, 3], dims=["x"], coords={"x": [0, 1], "y": "b"}) + data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"]) + assert data.dims == ("xy",) + actual = data.sel(y="a") + expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y") + assert_equal(actual, expected) + + def test_virtual_default_coords(self) -> None: + array = DataArray(np.zeros((5,)), dims="x") + expected = DataArray(range(5), dims="x", name="x") + assert_identical(expected, array["x"]) + assert_identical(expected, array.coords["x"]) + + def test_virtual_time_components(self) -> None: + dates = pd.date_range("2000-01-01", periods=10) + da = DataArray(np.arange(1, 11), [("time", dates)]) + + assert_array_equal(da["time.dayofyear"], da.values) + assert_array_equal(da.coords["time.dayofyear"], da.values) + + def test_coords(self) -> None: + # use int64 to ensure repr() consistency on windows + coords = [ + IndexVariable("x", np.array([-1, -2], "int64")), + IndexVariable("y", np.array([0, 1, 2], "int64")), + ] + da = DataArray(np.random.randn(2, 3), coords, name="foo") + + # len + assert len(da.coords) == 2 + + # iter + assert list(da.coords) == ["x", "y"] + + assert coords[0].identical(da.coords["x"]) + assert coords[1].identical(da.coords["y"]) + + assert "x" in da.coords + assert 0 not in da.coords + assert "foo" not in da.coords + + with pytest.raises(KeyError): + da.coords[0] + with pytest.raises(KeyError): + da.coords["foo"] + + # repr + expected_repr = dedent( + """\ + Coordinates: + * x (x) int64 16B -1 -2 + * y (y) int64 24B 0 1 2""" + ) + actual = repr(da.coords) + assert expected_repr == actual + + # dtypes + assert da.coords.dtypes == {"x": np.dtype("int64"), "y": np.dtype("int64")} + + del da.coords["x"] + da._indexes = filter_indexes_from_coords(da.xindexes, set(da.coords)) + expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") + assert_identical(da, expected) + + with pytest.raises( + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " + ): + self.mda["level_1"] = ("x", np.arange(4)) + self.mda.coords["level_1"] = ("x", np.arange(4)) + + def test_coords_to_index(self) -> None: + da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) + + with pytest.raises(ValueError, match=r"no valid index"): + da[0, 0].coords.to_index() + + expected = pd.Index(["a", "b", "c"], name="y") + actual = da[0].coords.to_index() + assert expected.equals(actual) + + expected = pd.MultiIndex.from_product( + [[1, 2], ["a", "b", "c"]], names=["x", "y"] + ) + actual = da.coords.to_index() + assert expected.equals(actual) + + expected = pd.MultiIndex.from_product( + [["a", "b", "c"], [1, 2]], names=["y", "x"] + ) + actual = da.coords.to_index(["y", "x"]) + assert expected.equals(actual) + + with pytest.raises(ValueError, match=r"ordered_dims must match"): + da.coords.to_index(["x"]) + + def test_coord_coords(self) -> None: + orig = DataArray( + [10, 20], {"x": [1, 2], "x2": ("x", ["a", "b"]), "z": 4}, dims="x" + ) + + actual = orig.coords["x"] + expected = DataArray( + [1, 2], {"z": 4, "x2": ("x", ["a", "b"]), "x": [1, 2]}, dims="x", name="x" + ) + assert_identical(expected, actual) + + del actual.coords["x2"] + assert_identical(expected.reset_coords("x2", drop=True), actual) + + actual.coords["x3"] = ("x", ["a", "b"]) + expected = DataArray( + [1, 2], {"z": 4, "x3": ("x", ["a", "b"]), "x": [1, 2]}, dims="x", name="x" + ) + assert_identical(expected, actual) + + def test_reset_coords(self) -> None: + data = DataArray( + np.zeros((3, 4)), + {"bar": ("x", ["a", "b", "c"]), "baz": ("y", range(4)), "y": range(4)}, + dims=["x", "y"], + name="foo", + ) + + actual1 = data.reset_coords() + expected1 = Dataset( + { + "foo": (["x", "y"], np.zeros((3, 4))), + "bar": ("x", ["a", "b", "c"]), + "baz": ("y", range(4)), + "y": range(4), + } + ) + assert_identical(actual1, expected1) + + actual2 = data.reset_coords(["bar", "baz"]) + assert_identical(actual2, expected1) + + actual3 = data.reset_coords("bar") + expected3 = Dataset( + {"foo": (["x", "y"], np.zeros((3, 4))), "bar": ("x", ["a", "b", "c"])}, + {"baz": ("y", range(4)), "y": range(4)}, + ) + assert_identical(actual3, expected3) + + actual4 = data.reset_coords(["bar"]) + assert_identical(actual4, expected3) + + actual5 = data.reset_coords(drop=True) + expected5 = DataArray( + np.zeros((3, 4)), coords={"y": range(4)}, dims=["x", "y"], name="foo" + ) + assert_identical(actual5, expected5) + + actual6 = data.copy().reset_coords(drop=True) + assert_identical(actual6, expected5) + + actual7 = data.reset_coords("bar", drop=True) + expected7 = DataArray( + np.zeros((3, 4)), + {"baz": ("y", range(4)), "y": range(4)}, + dims=["x", "y"], + name="foo", + ) + assert_identical(actual7, expected7) + + with pytest.raises(ValueError, match=r"cannot be found"): + data.reset_coords("foo", drop=True) + with pytest.raises(ValueError, match=r"cannot be found"): + data.reset_coords("not_found") + with pytest.raises(ValueError, match=r"cannot remove index"): + data.reset_coords("y") + + # non-dimension index coordinate + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + data = DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x", name="foo") + with pytest.raises(ValueError, match=r"cannot remove index"): + data.reset_coords("lvl1") + + def test_assign_coords(self) -> None: + array = DataArray(10) + actual = array.assign_coords(c=42) + expected = DataArray(10, {"c": 42}) + assert_identical(actual, expected) + + with pytest.raises( + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " + ): + self.mda.assign_coords(level_1=("x", range(4))) + + # GH: 2112 + da = xr.DataArray([0, 1, 2], dims="x") + with pytest.raises(ValueError): + da["x"] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords["x"] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray + + def test_assign_coords_existing_multiindex(self) -> None: + data = self.mda + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + data.assign_coords(x=range(4)) + + def test_assign_coords_custom_index(self) -> None: + class CustomIndex(Index): + pass + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + da = xr.DataArray([0, 1, 2], dims="x") + actual = da.assign_coords(coords) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_no_default_index(self) -> None: + coords = Coordinates({"y": [1, 2, 3]}, indexes={}) + da = DataArray([1, 2, 3], dims="y") + actual = da.assign_coords(coords) + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "y" not in actual.xindexes + + def test_coords_alignment(self) -> None: + lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) + rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) + lhs.coords["rhs"] = rhs + + expected = DataArray( + [1, 2, 3], coords={"rhs": ("x", [np.nan, 2, 3]), "x": [0, 1, 2]}, dims="x" + ) + assert_identical(lhs, expected) + + def test_set_coords_update_index(self) -> None: + actual = DataArray([1, 2, 3], [("x", [1, 2, 3])]) + actual.coords["x"] = ["a", "b", "c"] + assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) + + def test_set_coords_multiindex_level(self) -> None: + with pytest.raises( + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " + ): + self.mda["level_1"] = range(4) + + def test_coords_replacement_alignment(self) -> None: + # regression test for GH725 + arr = DataArray([0, 1, 2], dims=["abc"]) + new_coord = DataArray([1, 2, 3], dims=["abc"], coords=[[1, 2, 3]]) + arr["abc"] = new_coord + expected = DataArray([0, 1, 2], coords=[("abc", [1, 2, 3])]) + assert_identical(arr, expected) + + def test_coords_non_string(self) -> None: + arr = DataArray(0, coords={1: 2}) + actual = arr.coords[1] + expected = DataArray(2, coords={1: 2}, name=1) + assert_identical(actual, expected) + + def test_coords_delitem_delete_indexes(self) -> None: + # regression test for GH3746 + arr = DataArray(np.ones((2,)), dims="x", coords={"x": [0, 1]}) + del arr.coords["x"] + assert "x" not in arr.xindexes + + def test_coords_delitem_multiindex_level(self) -> None: + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del self.mda.coords["level_1"] + + def test_broadcast_like(self) -> None: + arr1 = DataArray( + np.ones((2, 3)), + dims=["x", "y"], + coords={"x": ["a", "b"], "y": ["a", "b", "c"]}, + ) + arr2 = DataArray( + np.ones((3, 2)), + dims=["x", "y"], + coords={"x": ["a", "b", "c"], "y": ["a", "b"]}, + ) + orig1, orig2 = broadcast(arr1, arr2) + new1 = arr1.broadcast_like(arr2) + new2 = arr2.broadcast_like(arr1) + + assert_identical(orig1, new1) + assert_identical(orig2, new2) + + orig3 = DataArray(np.random.randn(5), [("x", range(5))]) + orig4 = DataArray(np.random.randn(6), [("y", range(6))]) + new3, new4 = broadcast(orig3, orig4) + + assert_identical(orig3.broadcast_like(orig4), new3.transpose("y", "x")) + assert_identical(orig4.broadcast_like(orig3), new4) + + def test_reindex_like(self) -> None: + foo = DataArray(np.random.randn(5, 6), [("x", range(5)), ("y", range(6))]) + bar = foo[:2, :2] + assert_identical(foo.reindex_like(bar), bar) + + expected = foo.copy() + expected[:] = np.nan + expected[:2, :2] = bar + assert_identical(bar.reindex_like(foo), expected) + + def test_reindex_like_no_index(self) -> None: + foo = DataArray(np.random.randn(5, 6), dims=["x", "y"]) + assert_identical(foo, foo.reindex_like(foo)) + + bar = foo[:4] + with pytest.raises(ValueError, match=r"different size for unlabeled"): + foo.reindex_like(bar) + + def test_reindex_regressions(self) -> None: + da = DataArray(np.random.randn(5), coords=[("time", range(5))]) + time2 = DataArray(np.arange(5), dims="time2") + with pytest.raises(ValueError): + da.reindex(time=time2) + + # regression test for #736, reindex can not change complex nums dtype + xnp = np.array([1, 2, 3], dtype=complex) + x = DataArray(xnp, coords=[[0.1, 0.2, 0.3]]) + y = DataArray([2, 5, 6, 7, 8], coords=[[-1.1, 0.21, 0.31, 0.41, 0.51]]) + re_dtype = x.reindex_like(y, method="pad").dtype + assert x.dtype == re_dtype + + def test_reindex_method(self) -> None: + x = DataArray([10, 20], dims="y", coords={"y": [0, 1]}) + y = [-0.1, 0.5, 1.1] + actual = x.reindex(y=y, method="backfill", tolerance=0.2) + expected = DataArray([10, np.nan, np.nan], coords=[("y", y)]) + assert_identical(expected, actual) + + actual = x.reindex(y=y, method="backfill", tolerance=[0.1, 0.1, 0.01]) + expected = DataArray([10, np.nan, np.nan], coords=[("y", y)]) + assert_identical(expected, actual) + + alt = Dataset({"y": y}) + actual = x.reindex_like(alt, method="backfill") + expected = DataArray([10, 20, np.nan], coords=[("y", y)]) + assert_identical(expected, actual) + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {None: 2, "u": 1}]) + def test_reindex_fill_value(self, fill_value) -> None: + x = DataArray([10, 20], dims="y", coords={"y": [0, 1], "u": ("y", [1, 2])}) + y = [0, 1, 2] + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_var = fill_value_u = np.nan + elif isinstance(fill_value, dict): + fill_value_var = fill_value[None] + fill_value_u = fill_value["u"] + else: + fill_value_var = fill_value_u = fill_value + actual = x.reindex(y=y, fill_value=fill_value) + expected = DataArray( + [10, 20, fill_value_var], + dims="y", + coords={"y": y, "u": ("y", [1, 2, fill_value_u])}, + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_reindex_str_dtype(self, dtype) -> None: + data = DataArray( + [1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)} + ) + + actual = data.reindex(x=data.x) + expected = data + + assert_identical(expected, actual) + assert actual.dtype == expected.dtype + + def test_reindex_empty_array_dtype(self) -> None: + # Dtype of reindex result should match dtype of the original DataArray. + # See GH issue #7299 + x = xr.DataArray([], dims=("x",), coords={"x": []}).astype("float32") + y = x.reindex(x=[1.0, 2.0]) + + assert ( + x.dtype == y.dtype + ), "Dtype of reindexed DataArray should match dtype of the original DataArray" + assert ( + y.dtype == np.float32 + ), "Dtype of reindexed DataArray should remain float32" + + def test_rename(self) -> None: + da = xr.DataArray( + [1, 2, 3], dims="dim", name="name", coords={"coord": ("dim", [5, 6, 7])} + ) + + # change name + renamed_name = da.rename("name_new") + assert renamed_name.name == "name_new" + expected_name = da.copy() + expected_name.name = "name_new" + assert_identical(renamed_name, expected_name) + + # change name to None? + renamed_noname = da.rename(None) + assert renamed_noname.name is None + expected_noname = da.copy() + expected_noname.name = None + assert_identical(renamed_noname, expected_noname) + renamed_noname = da.rename() + assert renamed_noname.name is None + assert_identical(renamed_noname, expected_noname) + + # change dim + renamed_dim = da.rename({"dim": "dim_new"}) + assert renamed_dim.dims == ("dim_new",) + expected_dim = xr.DataArray( + [1, 2, 3], + dims="dim_new", + name="name", + coords={"coord": ("dim_new", [5, 6, 7])}, + ) + assert_identical(renamed_dim, expected_dim) + + # change dim with kwargs + renamed_dimkw = da.rename(dim="dim_new") + assert renamed_dimkw.dims == ("dim_new",) + assert_identical(renamed_dimkw, expected_dim) + + # change coords + renamed_coord = da.rename({"coord": "coord_new"}) + assert "coord_new" in renamed_coord.coords + expected_coord = xr.DataArray( + [1, 2, 3], dims="dim", name="name", coords={"coord_new": ("dim", [5, 6, 7])} + ) + assert_identical(renamed_coord, expected_coord) + + # change coords with kwargs + renamed_coordkw = da.rename(coord="coord_new") + assert "coord_new" in renamed_coordkw.coords + assert_identical(renamed_coordkw, expected_coord) + + # change coord and dim + renamed_both = da.rename({"dim": "dim_new", "coord": "coord_new"}) + assert renamed_both.dims == ("dim_new",) + assert "coord_new" in renamed_both.coords + expected_both = xr.DataArray( + [1, 2, 3], + dims="dim_new", + name="name", + coords={"coord_new": ("dim_new", [5, 6, 7])}, + ) + assert_identical(renamed_both, expected_both) + + # change coord and dim with kwargs + renamed_bothkw = da.rename(dim="dim_new", coord="coord_new") + assert renamed_bothkw.dims == ("dim_new",) + assert "coord_new" in renamed_bothkw.coords + assert_identical(renamed_bothkw, expected_both) + + # change all + renamed_all = da.rename("name_new", dim="dim_new", coord="coord_new") + assert renamed_all.name == "name_new" + assert renamed_all.dims == ("dim_new",) + assert "coord_new" in renamed_all.coords + expected_all = xr.DataArray( + [1, 2, 3], + dims="dim_new", + name="name_new", + coords={"coord_new": ("dim_new", [5, 6, 7])}, + ) + assert_identical(renamed_all, expected_all) + + def test_rename_dimension_coord_warnings(self) -> None: + # create a dimension coordinate by renaming a dimension or coordinate + # should raise a warning (no index created) + da = DataArray([0, 0], coords={"x": ("y", [0, 1])}, dims="y") + + with pytest.warns( + UserWarning, match="rename 'x' to 'y' does not create an index.*" + ): + da.rename(x="y") + + da = xr.DataArray([0, 0], coords={"y": ("x", [0, 1])}, dims="x") + + with pytest.warns( + UserWarning, match="rename 'x' to 'y' does not create an index.*" + ): + da.rename(x="y") + + # No operation should not raise a warning + da = xr.DataArray( + data=np.ones((2, 3)), + dims=["x", "y"], + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + da.rename(x="x") + + def test_init_value(self) -> None: + expected = DataArray( + np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)] + ) + actual = DataArray(3, dims=["x", "y"], coords=[range(3), range(4)]) + assert_identical(expected, actual) + + expected = DataArray( + np.full((1, 10, 2), 0), + dims=["w", "x", "y"], + coords={"x": np.arange(10), "y": ["north", "south"]}, + ) + actual = DataArray(0, dims=expected.dims, coords=expected.coords) + assert_identical(expected, actual) + + expected = DataArray( + np.full((10, 2), np.nan), coords=[("x", np.arange(10)), ("y", ["a", "b"])] + ) + actual = DataArray(coords=[("x", np.arange(10)), ("y", ["a", "b"])]) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"different number of dim"): + DataArray(np.array(1), coords={"x": np.arange(10)}, dims=["x"]) + with pytest.raises(ValueError, match=r"does not match the 0 dim"): + DataArray(np.array(1), coords=[("x", np.arange(10))]) + + def test_swap_dims(self) -> None: + array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") + expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") + actual = array.swap_dims({"x": "y"}) + assert_identical(expected, actual) + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) + + # as kwargs + array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") + expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") + actual = array.swap_dims(x="y") + assert_identical(expected, actual) + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) + + # multiindex case + idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) + array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x") + expected = DataArray(array.values, {"y": idx}, "y") + actual = array.swap_dims({"x": "y"}) + assert_identical(expected, actual) + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) + + def test_expand_dims_error(self) -> None: + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + + with pytest.raises(TypeError, match=r"dim should be Hashable or"): + array.expand_dims(0) + with pytest.raises(ValueError, match=r"lengths of dim and axis"): + # dims and axis argument should be the same length + array.expand_dims(dim=["a", "b"], axis=[1, 2, 3]) + with pytest.raises(ValueError, match=r"Dimension x already"): + # Should not pass the already existing dimension. + array.expand_dims(dim=["x"]) + # raise if duplicate + with pytest.raises(ValueError, match=r"duplicate values"): + array.expand_dims(dim=["y", "y"]) + with pytest.raises(ValueError, match=r"duplicate values"): + array.expand_dims(dim=["y", "z"], axis=[1, 1]) + with pytest.raises(ValueError, match=r"duplicate values"): + array.expand_dims(dim=["y", "z"], axis=[2, -2]) + + # out of bounds error, axis must be in [-4, 3] + with pytest.raises(IndexError): + array.expand_dims(dim=["y", "z"], axis=[2, 4]) + with pytest.raises(IndexError): + array.expand_dims(dim=["y", "z"], axis=[2, -5]) + # Does not raise an IndexError + array.expand_dims(dim=["y", "z"], axis=[2, -4]) + array.expand_dims(dim=["y", "z"], axis=[2, 3]) + + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + with pytest.raises(TypeError): + array.expand_dims({"new_dim": 3.2}) + + # Attempt to use both dim and kwargs + with pytest.raises(ValueError): + array.expand_dims({"d": 4}, e=4) + + def test_expand_dims(self) -> None: + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + # pass only dim label + actual = array.expand_dims(dim="y") + expected = DataArray( + np.expand_dims(array.values, 0), + dims=["y", "x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + roundtripped = actual.squeeze("y", drop=True) + assert_identical(array, roundtripped) + + # pass multiple dims + actual = array.expand_dims(dim=["y", "z"]) + expected = DataArray( + np.expand_dims(np.expand_dims(array.values, 0), 0), + dims=["y", "z", "x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + roundtripped = actual.squeeze(["y", "z"], drop=True) + assert_identical(array, roundtripped) + + # pass multiple dims and axis. Axis is out of order + actual = array.expand_dims(dim=["z", "y"], axis=[2, 1]) + expected = DataArray( + np.expand_dims(np.expand_dims(array.values, 1), 2), + dims=["x", "y", "z", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + # make sure the attrs are tracked + assert actual.attrs["key"] == "entry" + roundtripped = actual.squeeze(["z", "y"], drop=True) + assert_identical(array, roundtripped) + + # Negative axis and they are out of order + actual = array.expand_dims(dim=["y", "z"], axis=[-1, -2]) + expected = DataArray( + np.expand_dims(np.expand_dims(array.values, -1), -1), + dims=["x", "dim_0", "z", "y"], + coords={"x": np.linspace(0.0, 1.0, 3)}, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + assert actual.attrs["key"] == "entry" + roundtripped = actual.squeeze(["y", "z"], drop=True) + assert_identical(array, roundtripped) + + def test_expand_dims_with_scalar_coordinate(self) -> None: + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3), "z": 1.0}, + attrs={"key": "entry"}, + ) + actual = array.expand_dims(dim="z") + expected = DataArray( + np.expand_dims(array.values, 0), + dims=["z", "x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3), "z": np.ones(1)}, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + roundtripped = actual.squeeze(["z"], drop=False) + assert_identical(array, roundtripped) + + def test_expand_dims_with_greater_dim_size(self) -> None: + array = DataArray( + np.random.randn(3, 4), + dims=["x", "dim_0"], + coords={"x": np.linspace(0.0, 1.0, 3), "z": 1.0}, + attrs={"key": "entry"}, + ) + actual = array.expand_dims({"y": 2, "z": 1, "dim_1": ["a", "b", "c"]}) + + expected_coords = { + "y": [0, 1], + "z": [1.0], + "dim_1": ["a", "b", "c"], + "x": np.linspace(0, 1, 3), + "dim_0": range(4), + } + expected = DataArray( + array.values * np.ones([2, 1, 3, 3, 4]), + coords=expected_coords, + dims=list(expected_coords.keys()), + attrs={"key": "entry"}, + ).drop_vars(["y", "dim_0"]) + assert_identical(expected, actual) + + # Test with kwargs instead of passing dict to dim arg. + + other_way = array.expand_dims(dim_1=["a", "b", "c"]) + + other_way_expected = DataArray( + array.values * np.ones([3, 3, 4]), + coords={ + "dim_1": ["a", "b", "c"], + "x": np.linspace(0, 1, 3), + "dim_0": range(4), + "z": 1.0, + }, + dims=["dim_1", "x", "dim_0"], + attrs={"key": "entry"}, + ).drop_vars("dim_0") + assert_identical(other_way_expected, other_way) + + def test_set_index(self) -> None: + indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] + coords = {idx.name: ("x", idx) for idx in indexes} + array = DataArray(self.mda.values, coords=coords, dims="x") + expected = self.mda.copy() + level_3 = ("x", [1, 2, 3, 4]) + array["level_3"] = level_3 + expected["level_3"] = level_3 + + obj = array.set_index(x=self.mindex.names) + assert_identical(obj, expected) + + obj = obj.set_index(x="level_3", append=True) + expected = array.set_index(x=["level_1", "level_2", "level_3"]) + assert_identical(obj, expected) + + array = array.set_index(x=["level_1", "level_2", "level_3"]) + assert_identical(array, expected) + + array2d = DataArray( + np.random.rand(2, 2), + coords={"x": ("x", [0, 1]), "level": ("y", [1, 2])}, + dims=("x", "y"), + ) + with pytest.raises(ValueError, match=r"dimension mismatch"): + array2d.set_index(x="level") + + # Issue 3176: Ensure clear error message on key error. + with pytest.raises(ValueError, match=r".*variable\(s\) do not exist"): + obj.set_index(x="level_4") + + def test_reset_index(self) -> None: + indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] + coords = {idx.name: ("x", idx) for idx in indexes} + expected = DataArray(self.mda.values, coords=coords, dims="x") + + obj = self.mda.reset_index("x") + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 + obj = self.mda.reset_index(self.mindex.names) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 + obj = self.mda.reset_index(["x", "level_1"]) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 + + coords = { + "x": ("x", self.mindex.droplevel("level_1")), + "level_1": ("x", self.mindex.get_level_values("level_1")), + } + expected = DataArray(self.mda.values, coords=coords, dims="x") + obj = self.mda.reset_index(["level_1"]) + assert_identical(obj, expected, check_default_indexes=False) + assert list(obj.xindexes) == ["x"] + assert type(obj.xindexes["x"]) is PandasIndex + + expected = DataArray(self.mda.values, dims="x") + obj = self.mda.reset_index("x", drop=True) + assert_identical(obj, expected, check_default_indexes=False) + + array = self.mda.copy() + array = array.reset_index(["x"], drop=True) + assert_identical(array, expected, check_default_indexes=False) + + # single index + array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x") + obj = array.reset_index("x") + print(obj.x.variable) + print(array.x.variable) + assert_equal(obj.x.variable, array.x.variable.to_base_variable()) + assert len(obj.xindexes) == 0 + + def test_reset_index_keep_attrs(self) -> None: + coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) + da = DataArray([1, 0], [coord_1]) + obj = da.reset_index("coord_1") + assert obj.coord_1.attrs == da.coord_1.attrs + assert len(obj.xindexes) == 0 + + def test_reorder_levels(self) -> None: + midx = self.mindex.reorder_levels(["level_2", "level_1"]) + expected = DataArray(self.mda.values, coords={"x": midx}, dims="x") + + obj = self.mda.reorder_levels(x=["level_2", "level_1"]) + assert_identical(obj, expected) + + array = DataArray([1, 2], dims="x") + with pytest.raises(KeyError): + array.reorder_levels(x=["level_1", "level_2"]) + + array["x"] = [0, 1] + with pytest.raises(ValueError, match=r"has no MultiIndex"): + array.reorder_levels(x=["level_1", "level_2"]) + + def test_set_xindex(self) -> None: + da = DataArray( + [1, 2, 3, 4], coords={"foo": ("x", ["a", "a", "b", "b"])}, dims="x" + ) + + class IndexWithOptions(Index): + def __init__(self, opt): + self.opt = opt + + @classmethod + def from_variables(cls, variables, options): + return cls(options["opt"]) + + indexed = da.set_xindex("foo", IndexWithOptions, opt=1) + assert "foo" in indexed.xindexes + assert getattr(indexed.xindexes["foo"], "opt") == 1 + + def test_dataset_getitem(self) -> None: + dv = self.ds["foo"] + assert_identical(dv, self.dv) + + def test_array_interface(self) -> None: + assert_array_equal(np.asarray(self.dv), self.x) + # test patched in methods + assert_array_equal(self.dv.astype(float), self.v.astype(float)) + assert_array_equal(self.dv.argsort(), self.v.argsort()) + assert_array_equal(self.dv.clip(2, 3), self.v.clip(2, 3)) + # test ufuncs + expected = deepcopy(self.ds) + expected["foo"][:] = np.sin(self.x) + assert_equal(expected["foo"], np.sin(self.dv)) + assert_array_equal(self.dv, np.maximum(self.v, self.dv)) + bar = Variable(["x", "y"], np.zeros((10, 20))) + assert_equal(self.dv, np.maximum(self.dv, bar)) + + def test_astype_attrs(self) -> None: + for v in [self.va.copy(), self.mda.copy(), self.ds.copy()]: + v.attrs["foo"] = "bar" + assert v.attrs == v.astype(float).attrs + assert not v.astype(float, keep_attrs=False).attrs + + def test_astype_dtype(self) -> None: + original = DataArray([-1, 1, 2, 3, 1000]) + converted = original.astype(float) + assert_array_equal(original, converted) + assert np.issubdtype(original.dtype, np.integer) + assert np.issubdtype(converted.dtype, np.floating) + + def test_astype_order(self) -> None: + original = DataArray([[1, 2], [3, 4]]) + converted = original.astype("d", order="F") + assert_equal(original, converted) + assert original.values.flags["C_CONTIGUOUS"] + assert converted.values.flags["F_CONTIGUOUS"] + + def test_astype_subok(self) -> None: + class NdArraySubclass(np.ndarray): + pass + + original = DataArray(NdArraySubclass(np.arange(3))) + converted_not_subok = original.astype("d", subok=False) + converted_subok = original.astype("d", subok=True) + if not isinstance(original.data, NdArraySubclass): + pytest.xfail("DataArray cannot be backed yet by a subclasses of np.ndarray") + assert isinstance(converted_not_subok.data, np.ndarray) + assert not isinstance(converted_not_subok.data, NdArraySubclass) + assert isinstance(converted_subok.data, NdArraySubclass) + + def test_is_null(self) -> None: + x = np.random.RandomState(42).randn(5, 6) + x[x < 0] = np.nan + original = DataArray(x, [-np.arange(5), np.arange(6)], ["x", "y"]) + expected = DataArray(pd.isnull(x), [-np.arange(5), np.arange(6)], ["x", "y"]) + assert_identical(expected, original.isnull()) + assert_identical(~expected, original.notnull()) + + def test_math(self) -> None: + x = self.x + v = self.v + a = self.dv + # variable math was already tested extensively, so let's just make sure + # that all types are properly converted here + assert_equal(a, +a) + assert_equal(a, a + 0) + assert_equal(a, 0 + a) + assert_equal(a, a + 0 * v) + assert_equal(a, 0 * v + a) + assert_equal(a, a + 0 * x) + assert_equal(a, 0 * x + a) + assert_equal(a, a + 0 * a) + assert_equal(a, 0 * a + a) + + def test_math_automatic_alignment(self) -> None: + a = DataArray(range(5), [("x", range(5))]) + b = DataArray(range(5), [("x", range(1, 6))]) + expected = DataArray(np.ones(4), [("x", [1, 2, 3, 4])]) + assert_identical(a - b, expected) + + def test_non_overlapping_dataarrays_return_empty_result(self) -> None: + a = DataArray(range(5), [("x", range(5))]) + result = a.isel(x=slice(2)) + a.isel(x=slice(2, None)) + assert len(result["x"]) == 0 + + def test_empty_dataarrays_return_empty_result(self) -> None: + a = DataArray(data=[]) + result = a * a + assert len(result["dim_0"]) == 0 + + def test_inplace_math_basics(self) -> None: + x = self.x + a = self.dv + v = a.variable + b = a + b += 1 + assert b is a + assert b.variable is v + assert_array_equal(b.values, x) + assert source_ndarray(b.values) is x + + def test_inplace_math_error(self) -> None: + data = np.random.rand(4) + times = np.arange(4) + foo = DataArray(data, coords=[times], dims=["time"]) + b = times.copy() + with pytest.raises( + TypeError, match=r"Values of an IndexVariable are immutable" + ): + foo.coords["time"] += 1 + # Check error throwing prevented inplace operation + assert_array_equal(foo.coords["time"], b) + + def test_inplace_math_automatic_alignment(self) -> None: + a = DataArray(range(5), [("x", range(5))]) + b = DataArray(range(1, 6), [("x", range(1, 6))]) + with pytest.raises(xr.MergeError, match="Automatic alignment is not supported"): + a += b + with pytest.raises(xr.MergeError, match="Automatic alignment is not supported"): + b += a + + def test_math_name(self) -> None: + # Verify that name is preserved only when it can be done unambiguously. + # The rule (copied from pandas.Series) is keep the current name only if + # the other object has the same name or no name attribute and this + # object isn't a coordinate; otherwise reset to None. + a = self.dv + assert (+a).name == "foo" + assert (a + 0).name == "foo" + assert (a + a.rename(None)).name is None + assert (a + a.rename("bar")).name is None + assert (a + a).name == "foo" + assert (+a["x"]).name == "x" + assert (a["x"] + 0).name == "x" + assert (a + a["x"]).name is None + + def test_math_with_coords(self) -> None: + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray(np.random.randn(2, 3), coords, dims=["x", "y"]) + + actual = orig + 1 + expected = DataArray(orig.values + 1, orig.coords) + assert_identical(expected, actual) + + actual = 1 + orig + assert_identical(expected, actual) + + actual = orig + orig[0, 0] + exp_coords = {k: v for k, v in coords.items() if k != "lat"} + expected = DataArray( + orig.values + orig.values[0, 0], exp_coords, dims=["x", "y"] + ) + assert_identical(expected, actual) + + actual = orig[0, 0] + orig + assert_identical(expected, actual) + + actual = orig[0, 0] + orig[-1, -1] + expected = DataArray(orig.values[0, 0] + orig.values[-1, -1], {"c": -999}) + assert_identical(expected, actual) + + actual = orig[:, 0] + orig[0, :] + exp_values = orig[:, 0].values[:, None] + orig[0, :].values[None, :] + expected = DataArray(exp_values, exp_coords, dims=["x", "y"]) + assert_identical(expected, actual) + + actual = orig[0, :] + orig[:, 0] + assert_identical(expected.transpose(transpose_coords=True), actual) + + actual = orig - orig.transpose(transpose_coords=True) + expected = DataArray(np.zeros((2, 3)), orig.coords) + assert_identical(expected, actual) + + actual = orig.transpose(transpose_coords=True) - orig + assert_identical(expected.transpose(transpose_coords=True), actual) + + alt = DataArray([1, 1], {"x": [-1, -2], "c": "foo", "d": 555}, "x") + actual = orig + alt + expected = orig + 1 + expected.coords["d"] = 555 + del expected.coords["c"] + assert_identical(expected, actual) + + actual = alt + orig + assert_identical(expected, actual) + + def test_index_math(self) -> None: + orig = DataArray(range(3), dims="x", name="x") + actual = orig + 1 + expected = DataArray(1 + np.arange(3), dims="x", name="x") + assert_identical(expected, actual) + + # regression tests for #254 + actual = orig[0] < orig + expected = DataArray([False, True, True], dims="x", name="x") + assert_identical(expected, actual) + + actual = orig > orig[0] + assert_identical(expected, actual) + + def test_dataset_math(self) -> None: + # more comprehensive tests with multiple dataset variables + obs = Dataset( + {"tmin": ("x", np.arange(5)), "tmax": ("x", 10 + np.arange(5))}, + {"x": ("x", 0.5 * np.arange(5)), "loc": ("x", range(-2, 3))}, + ) + + actual1 = 2 * obs["tmax"] + expected1 = DataArray(2 * (10 + np.arange(5)), obs.coords, name="tmax") + assert_identical(actual1, expected1) + + actual2 = obs["tmax"] - obs["tmin"] + expected2 = DataArray(10 * np.ones(5), obs.coords) + assert_identical(actual2, expected2) + + sim = Dataset( + { + "tmin": ("x", 1 + np.arange(5)), + "tmax": ("x", 11 + np.arange(5)), + # does *not* include 'loc' as a coordinate + "x": ("x", 0.5 * np.arange(5)), + } + ) + + actual3 = sim["tmin"] - obs["tmin"] + expected3 = DataArray(np.ones(5), obs.coords, name="tmin") + assert_identical(actual3, expected3) + + actual4 = -obs["tmin"] + sim["tmin"] + assert_identical(actual4, expected3) + + actual5 = sim["tmin"].copy() + actual5 -= obs["tmin"] + assert_identical(actual5, expected3) + + actual6 = sim.copy() + actual6["tmin"] = sim["tmin"] - obs["tmin"] + expected6 = Dataset( + {"tmin": ("x", np.ones(5)), "tmax": ("x", sim["tmax"].values)}, obs.coords + ) + assert_identical(actual6, expected6) + + actual7 = sim.copy() + actual7["tmin"] -= obs["tmin"] + assert_identical(actual7, expected6) + + def test_stack_unstack(self) -> None: + orig = DataArray( + [[0, 1], [2, 3]], + dims=["x", "y"], + attrs={"foo": 2}, + ) + assert_identical(orig, orig.unstack()) + + # test GH3000 + a = orig[:0, :1].stack(new_dim=("x", "y")).indexes["new_dim"] + b = pd.MultiIndex( + levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], + codes=[[], []], + names=["x", "y"], + ) + pd.testing.assert_index_equal(a, b) + + actual = orig.stack(z=["x", "y"]).unstack("z").drop_vars(["x", "y"]) + assert_identical(orig, actual) + + actual = orig.stack(z=[...]).unstack("z").drop_vars(["x", "y"]) + assert_identical(orig, actual) + + dims = ["a", "b", "c", "d", "e"] + coords = { + "a": [0], + "b": [1, 2], + "c": [3, 4, 5], + "d": [6, 7], + "e": [8], + } + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), coords=coords, dims=dims) + stacked = orig.stack(ab=["a", "b"], cd=["c", "d"]) + + unstacked = stacked.unstack(["ab", "cd"]) + assert_identical(orig, unstacked.transpose(*dims)) + + unstacked = stacked.unstack() + assert_identical(orig, unstacked.transpose(*dims)) + + def test_stack_unstack_decreasing_coordinate(self) -> None: + # regression test for GH980 + orig = DataArray( + np.random.rand(3, 4), + dims=("y", "x"), + coords={"x": np.arange(4), "y": np.arange(3, 0, -1)}, + ) + stacked = orig.stack(allpoints=["y", "x"]) + actual = stacked.unstack("allpoints") + assert_identical(orig, actual) + + def test_unstack_pandas_consistency(self) -> None: + df = pd.DataFrame({"foo": range(3), "x": ["a", "b", "b"], "y": [0, 0, 1]}) + s = df.set_index(["x", "y"])["foo"] + expected = DataArray(s.unstack(), name="foo") + actual = DataArray(s, dims="z").unstack("z") + assert_identical(expected, actual) + + def test_unstack_requires_unique(self) -> None: + df = pd.DataFrame({"foo": range(2), "x": ["a", "a"], "y": [0, 0]}) + s = df.set_index(["x", "y"])["foo"] + + with pytest.raises( + ValueError, match="Cannot unstack MultiIndex containing duplicates" + ): + DataArray(s, dims="z").unstack("z") + + @pytest.mark.filterwarnings("error") + def test_unstack_roundtrip_integer_array(self) -> None: + arr = xr.DataArray( + np.arange(6).reshape(2, 3), + coords={"x": ["a", "b"], "y": [0, 1, 2]}, + dims=["x", "y"], + ) + + stacked = arr.stack(z=["x", "y"]) + roundtripped = stacked.unstack() + + assert_identical(arr, roundtripped) + + def test_stack_nonunique_consistency(self, da) -> None: + da = da.isel(time=0, drop=True) # 2D + actual = da.stack(z=["a", "x"]) + expected = DataArray(da.to_pandas().stack(), dims="z") + assert_identical(expected, actual) + + def test_to_unstacked_dataset_raises_value_error(self) -> None: + data = DataArray([0, 1], dims="x", coords={"x": [0, 1]}) + with pytest.raises(ValueError, match="'x' is not a stacked coordinate"): + data.to_unstacked_dataset("x", 0) + + def test_transpose(self) -> None: + da = DataArray( + np.random.randn(3, 4, 5), + dims=("x", "y", "z"), + coords={ + "x": range(3), + "y": range(4), + "z": range(5), + "xy": (("x", "y"), np.random.randn(3, 4)), + }, + ) + + actual = da.transpose(transpose_coords=False) + expected = DataArray(da.values.T, dims=("z", "y", "x"), coords=da.coords) + assert_equal(expected, actual) + + actual = da.transpose("z", "y", "x", transpose_coords=True) + expected = DataArray( + da.values.T, + dims=("z", "y", "x"), + coords={ + "x": da.x.values, + "y": da.y.values, + "z": da.z.values, + "xy": (("y", "x"), da.xy.values.T), + }, + ) + assert_equal(expected, actual) + + # same as previous but with ellipsis + actual = da.transpose("z", ..., "x", transpose_coords=True) + assert_equal(expected, actual) + + # same as previous but with a missing dimension + actual = da.transpose( + "z", "y", "x", "not_a_dim", transpose_coords=True, missing_dims="ignore" + ) + assert_equal(expected, actual) + + with pytest.raises(ValueError): + da.transpose("x", "y") + + with pytest.raises(ValueError): + da.transpose("not_a_dim", "z", "x", ...) + + with pytest.warns(UserWarning): + da.transpose("not_a_dim", "y", "x", ..., missing_dims="warn") + + def test_squeeze(self) -> None: + assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable) + + def test_squeeze_drop(self) -> None: + array = DataArray([1], [("x", [0])]) + expected = DataArray(1) + actual = array.squeeze(drop=True) + assert_identical(expected, actual) + + expected = DataArray(1, {"x": 0}) + actual = array.squeeze(drop=False) + assert_identical(expected, actual) + + array = DataArray([[[0.0, 1.0]]], dims=["dim_0", "dim_1", "dim_2"]) + expected = DataArray([[0.0, 1.0]], dims=["dim_1", "dim_2"]) + actual = array.squeeze(axis=0) + assert_identical(expected, actual) + + array = DataArray([[[[0.0, 1.0]]]], dims=["dim_0", "dim_1", "dim_2", "dim_3"]) + expected = DataArray([[0.0, 1.0]], dims=["dim_1", "dim_3"]) + actual = array.squeeze(axis=(0, 2)) + assert_identical(expected, actual) + + array = DataArray([[[0.0, 1.0]]], dims=["dim_0", "dim_1", "dim_2"]) + with pytest.raises(ValueError): + array.squeeze(axis=0, dim="dim_1") + + def test_drop_coordinates(self) -> None: + expected = DataArray(np.random.randn(2, 3), dims=["x", "y"]) + arr = expected.copy() + arr.coords["z"] = 2 + actual = arr.drop_vars("z") + assert_identical(expected, actual) + + with pytest.raises(ValueError): + arr.drop_vars("not found") + + actual = expected.drop_vars("not found", errors="ignore") + assert_identical(actual, expected) + + with pytest.raises(ValueError, match=r"cannot be found"): + arr.drop_vars("w") + + actual = expected.drop_vars("w", errors="ignore") + assert_identical(actual, expected) + + renamed = arr.rename("foo") + with pytest.raises(ValueError, match=r"cannot be found"): + renamed.drop_vars("foo") + + actual = renamed.drop_vars("foo", errors="ignore") + assert_identical(actual, renamed) + + def test_drop_vars_callable(self) -> None: + A = DataArray( + np.random.randn(2, 3), dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4, 5]} + ) + expected = A.drop_vars(["x", "y"]) + actual = A.drop_vars(lambda x: x.indexes) + assert_identical(expected, actual) + + def test_drop_multiindex_level(self) -> None: + # GH6505 + expected = self.mda.drop_vars(["x", "level_1", "level_2"]) + with pytest.warns(DeprecationWarning): + actual = self.mda.drop_vars("level_1") + assert_identical(expected, actual) + + def test_drop_all_multiindex_levels(self) -> None: + dim_levels = ["x", "level_1", "level_2"] + actual = self.mda.drop_vars(dim_levels) + # no error, multi-index dropped + for key in dim_levels: + assert key not in actual.xindexes + + def test_drop_index_labels(self) -> None: + arr = DataArray(np.random.randn(2, 3), coords={"y": [0, 1, 2]}, dims=["x", "y"]) + actual = arr.drop_sel(y=[0, 1]) + expected = arr[:, 2:] + assert_identical(actual, expected) + + with pytest.raises((KeyError, ValueError), match=r"not .* in axis"): + actual = arr.drop_sel(y=[0, 1, 3]) + + actual = arr.drop_sel(y=[0, 1, 3], errors="ignore") + assert_identical(actual, expected) + + with pytest.warns(DeprecationWarning): + arr.drop([0, 1, 3], dim="y", errors="ignore") # type: ignore + + def test_drop_index_positions(self) -> None: + arr = DataArray(np.random.randn(2, 3), dims=["x", "y"]) + actual = arr.drop_isel(y=[0, 1]) + expected = arr[:, 2:] + assert_identical(actual, expected) + + def test_drop_indexes(self) -> None: + arr = DataArray([1, 2, 3], coords={"x": ("x", [1, 2, 3])}, dims="x") + actual = arr.drop_indexes("x") + assert "x" not in actual.xindexes + + actual = arr.drop_indexes("not_a_coord", errors="ignore") + assert_identical(actual, arr) + + def test_dropna(self) -> None: + x = np.random.randn(4, 4) + x[::2, 0] = np.nan + arr = DataArray(x, dims=["a", "b"]) + + actual = arr.dropna("a") + expected = arr[1::2] + assert_identical(actual, expected) + + actual = arr.dropna("b", how="all") + assert_identical(actual, arr) + + actual = arr.dropna("a", thresh=1) + assert_identical(actual, arr) + + actual = arr.dropna("b", thresh=3) + expected = arr[:, 1:] + assert_identical(actual, expected) + + def test_where(self) -> None: + arr = DataArray(np.arange(4), dims="x") + expected = arr.sel(x=slice(2)) + actual = arr.where(arr.x < 2, drop=True) + assert_identical(actual, expected) + + def test_where_lambda(self) -> None: + arr = DataArray(np.arange(4), dims="y") + expected = arr.sel(y=slice(2)) + actual = arr.where(lambda x: x.y < 2, drop=True) + assert_identical(actual, expected) + + def test_where_other_lambda(self) -> None: + arr = DataArray(np.arange(4), dims="y") + expected = xr.concat( + [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y" + ) + actual = arr.where(lambda x: x.y < 2, lambda x: x + 1) + assert_identical(actual, expected) + + def test_where_string(self) -> None: + array = DataArray(["a", "b"]) + expected = DataArray(np.array(["a", np.nan], dtype=object)) + actual = array.where([True, False]) + assert_identical(actual, expected) + + def test_cumops(self) -> None: + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) + + actual = orig.cumsum() + expected = DataArray([[-1, -1, 0], [-4, -4, 0]], coords, dims=["x", "y"]) + assert_identical(expected, actual) + + actual = orig.cumsum("x") + expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, dims=["x", "y"]) + assert_identical(expected, actual) + + actual = orig.cumsum("y") + expected = DataArray([[-1, -1, 0], [-3, -3, 0]], coords, dims=["x", "y"]) + assert_identical(expected, actual) + + actual = orig.cumprod("x") + expected = DataArray([[-1, 0, 1], [3, 0, 3]], coords, dims=["x", "y"]) + assert_identical(expected, actual) + + actual = orig.cumprod("y") + expected = DataArray([[-1, 0, 0], [-3, 0, 0]], coords, dims=["x", "y"]) + assert_identical(expected, actual) + + def test_reduce(self) -> None: + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) + + actual = orig.mean() + expected = DataArray(0, {"c": -999}) + assert_identical(expected, actual) + + actual = orig.mean(["x", "y"]) + assert_identical(expected, actual) + + actual = orig.mean("x") + expected = DataArray([-2, 0, 2], {"y": coords["y"], "c": -999}, "y") + assert_identical(expected, actual) + + actual = orig.mean(["x"]) + assert_identical(expected, actual) + + actual = orig.mean("y") + expected = DataArray([0, 0], {"x": coords["x"], "c": -999}, "x") + assert_identical(expected, actual) + + assert_equal(self.dv.reduce(np.mean, "x").variable, self.v.reduce(np.mean, "x")) + + orig = DataArray([[1, 0, np.nan], [3, 0, 3]], coords, dims=["x", "y"]) + actual = orig.count() + expected = DataArray(5, {"c": -999}) + assert_identical(expected, actual) + + # uint support + orig = DataArray(np.arange(6).reshape(3, 2).astype("uint"), dims=["x", "y"]) + assert orig.dtype.kind == "u" + actual = orig.mean(dim="x", skipna=True) + expected = DataArray(orig.values.astype(int), dims=["x", "y"]).mean("x") + assert_equal(actual, expected) + + def test_reduce_keepdims(self) -> None: + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) + + # Mean on all axes loses non-constant coordinates + actual = orig.mean(keepdims=True) + expected = DataArray( + orig.data.mean(keepdims=True), + dims=orig.dims, + coords={k: v for k, v in coords.items() if k in ["c"]}, + ) + assert_equal(actual, expected) + + assert actual.sizes["x"] == 1 + assert actual.sizes["y"] == 1 + + # Mean on specific axes loses coordinates not involving that axis + actual = orig.mean("y", keepdims=True) + expected = DataArray( + orig.data.mean(axis=1, keepdims=True), + dims=orig.dims, + coords={k: v for k, v in coords.items() if k not in ["y", "lat"]}, + ) + assert_equal(actual, expected) + + @requires_bottleneck + def test_reduce_keepdims_bottleneck(self) -> None: + import bottleneck + + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) + + # Bottleneck does not have its own keepdims implementation + actual = orig.reduce(bottleneck.nanmean, keepdims=True) + expected = orig.mean(keepdims=True) + assert_equal(actual, expected) + + def test_reduce_dtype(self) -> None: + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) + + for dtype in [np.float16, np.float32, np.float64]: + assert orig.astype(float).mean(dtype=dtype).dtype == dtype + + def test_reduce_out(self) -> None: + coords = { + "x": [-1, -2], + "y": ["ab", "cd", "ef"], + "lat": (["x", "y"], [[1, 2, 3], [-1, -2, -3]]), + "c": -999, + } + orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=["x", "y"]) + + with pytest.raises(TypeError): + orig.mean(out=np.ones(orig.shape)) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize("skipna", [True, False, None]) + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + @pytest.mark.parametrize( + "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) + ) + def test_quantile(self, q, axis, dim, skipna, compute_backend) -> None: + va = self.va.copy(deep=True) + va[0, 0] = np.nan + + actual = DataArray(va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna) + _percentile_func = np.nanpercentile if skipna in (True, None) else np.percentile + expected = _percentile_func(va.values, np.array(q) * 100, axis=axis) + np.testing.assert_allclose(actual.values, expected) + if is_scalar(q): + assert "quantile" not in actual.dims + else: + assert "quantile" in actual.dims + + assert actual.attrs == self.attrs + + @pytest.mark.parametrize("method", ["midpoint", "lower"]) + def test_quantile_method(self, method) -> None: + q = [0.25, 0.5, 0.75] + actual = DataArray(self.va).quantile(q, method=method) + + expected = np.nanquantile(self.dv.values, np.array(q), method=method) + + np.testing.assert_allclose(actual.values, expected) + + @pytest.mark.parametrize("method", ["midpoint", "lower"]) + def test_quantile_interpolation_deprecated(self, method) -> None: + da = DataArray(self.va) + q = [0.25, 0.5, 0.75] + + with pytest.warns( + FutureWarning, + match="`interpolation` argument to quantile was renamed to `method`", + ): + actual = da.quantile(q, interpolation=method) + + expected = da.quantile(q, method=method) + + np.testing.assert_allclose(actual.values, expected.values) + + with warnings.catch_warnings(record=True): + with pytest.raises(TypeError, match="interpolation and method keywords"): + da.quantile(q, method=method, interpolation=method) + + def test_reduce_keep_attrs(self) -> None: + # Test dropped attrs + vm = self.va.mean() + assert len(vm.attrs) == 0 + assert vm.attrs == {} + + # Test kept attrs + vm = self.va.mean(keep_attrs=True) + assert len(vm.attrs) == len(self.attrs) + assert vm.attrs == self.attrs + + def test_assign_attrs(self) -> None: + expected = DataArray([], attrs=dict(a=1, b=2)) + expected.attrs["a"] = 1 + expected.attrs["b"] = 2 + new = DataArray([]) + actual = DataArray([]).assign_attrs(a=1, b=2) + assert_identical(actual, expected) + assert new.attrs == {} + + expected.attrs["c"] = 3 + new_actual = actual.assign_attrs({"c": 3}) + assert_identical(new_actual, expected) + assert actual.attrs == {"a": 1, "b": 2} + + @pytest.mark.parametrize( + "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] + ) + def test_propagate_attrs(self, func) -> None: + da = DataArray(self.va) + + # test defaults + assert func(da).attrs == da.attrs + + with set_options(keep_attrs=False): + assert func(da).attrs == {} + + with set_options(keep_attrs=True): + assert func(da).attrs == da.attrs + + def test_fillna(self) -> None: + a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") + actual = a.fillna(-1) + expected = DataArray([-1, 1, -1, 3], coords={"x": range(4)}, dims="x") + assert_identical(expected, actual) + + b = DataArray(range(4), coords={"x": range(4)}, dims="x") + actual = a.fillna(b) + expected = b.copy() + assert_identical(expected, actual) + + actual = a.fillna(np.arange(4)) + assert_identical(expected, actual) + + actual = a.fillna(b[:3]) + assert_identical(expected, actual) + + actual = a.fillna(b[:0]) + assert_identical(a, actual) + + with pytest.raises(TypeError, match=r"fillna on a DataArray"): + a.fillna({0: 0}) + + with pytest.raises(ValueError, match=r"broadcast"): + a.fillna(np.array([1, 2])) + + def test_align(self) -> None: + array = DataArray( + np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"] + ) + array1, array2 = align(array, array[:5], join="inner") + assert_identical(array1, array[:5]) + assert_identical(array2, array[:5]) + + def test_align_dtype(self) -> None: + # regression test for #264 + x1 = np.arange(30) + x2 = np.arange(5, 35) + a = DataArray(np.random.random((30,)).astype(np.float32), [("x", x1)]) + b = DataArray(np.random.random((30,)).astype(np.float32), [("x", x2)]) + c, d = align(a, b, join="outer") + assert c.dtype == np.float32 + + def test_align_copy(self) -> None: + x = DataArray([1, 2, 3], coords=[("a", [1, 2, 3])]) + y = DataArray([1, 2], coords=[("a", [3, 1])]) + + expected_x2 = x + expected_y2 = DataArray([2, np.nan, 1], coords=[("a", [1, 2, 3])]) + + x2, y2 = align(x, y, join="outer", copy=False) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + assert source_ndarray(x2.data) is source_ndarray(x.data) + + x2, y2 = align(x, y, join="outer", copy=True) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + assert source_ndarray(x2.data) is not source_ndarray(x.data) + + # Trivial align - 1 element + x = DataArray([1, 2, 3], coords=[("a", [1, 2, 3])]) + (x2,) = align(x, copy=False) + assert_identical(x, x2) + assert source_ndarray(x2.data) is source_ndarray(x.data) + + (x2,) = align(x, copy=True) + assert_identical(x, x2) + assert source_ndarray(x2.data) is not source_ndarray(x.data) + + def test_align_override(self) -> None: + left = DataArray([1, 2, 3], dims="x", coords={"x": [0, 1, 2]}) + right = DataArray( + np.arange(9).reshape((3, 3)), + dims=["x", "y"], + coords={"x": [0.1, 1.1, 2.1], "y": [1, 2, 3]}, + ) + + expected_right = DataArray( + np.arange(9).reshape(3, 3), + dims=["x", "y"], + coords={"x": [0, 1, 2], "y": [1, 2, 3]}, + ) + + new_left, new_right = align(left, right, join="override") + assert_identical(left, new_left) + assert_identical(new_right, expected_right) + + new_left, new_right = align(left, right, exclude="x", join="override") + assert_identical(left, new_left) + assert_identical(right, new_right) + + new_left, new_right = xr.align( + left.isel(x=0, drop=True), right, exclude="x", join="override" + ) + assert_identical(left.isel(x=0, drop=True), new_left) + assert_identical(right, new_right) + + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): + align(left.isel(x=0).expand_dims("x"), right, join="override") + + @pytest.mark.parametrize( + "darrays", + [ + [ + DataArray(0), + DataArray([1], [("x", [1])]), + DataArray([2, 3], [("x", [2, 3])]), + ], + [ + DataArray([2, 3], [("x", [2, 3])]), + DataArray([1], [("x", [1])]), + DataArray(0), + ], + ], + ) + def test_align_override_error(self, darrays) -> None: + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): + xr.align(*darrays, join="override") + + def test_align_exclude(self) -> None: + x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])]) + y = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, 20]), ("b", [5, 6])]) + z = DataArray([1], dims=["a"], coords={"a": [20], "b": 7}) + + x2, y2, z2 = align(x, y, z, join="outer", exclude=["b"]) + expected_x2 = DataArray( + [[3, 4], [1, 2], [np.nan, np.nan]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + expected_y2 = DataArray( + [[np.nan, np.nan], [1, 2], [3, 4]], + coords=[("a", [-2, -1, 20]), ("b", [5, 6])], + ) + expected_z2 = DataArray( + [np.nan, np.nan, 1], dims=["a"], coords={"a": [-2, -1, 20], "b": 7} + ) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + assert_identical(expected_z2, z2) + + def test_align_indexes(self) -> None: + x = DataArray([1, 2, 3], coords=[("a", [-1, 10, -2])]) + y = DataArray([1, 2], coords=[("a", [-2, -1])]) + + x2, y2 = align(x, y, join="outer", indexes={"a": [10, -1, -2]}) + expected_x2 = DataArray([2, 1, 3], coords=[("a", [10, -1, -2])]) + expected_y2 = DataArray([np.nan, 2, 1], coords=[("a", [10, -1, -2])]) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + (x2,) = align(x, join="outer", indexes={"a": [-2, 7, 10, -1]}) + expected_x2 = DataArray([3, np.nan, 2, 1], coords=[("a", [-2, 7, 10, -1])]) + assert_identical(expected_x2, x2) + + def test_align_without_indexes_exclude(self) -> None: + arrays = [DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])] + result0, result1 = align(*arrays, exclude=["x"]) + assert_identical(result0, arrays[0]) + assert_identical(result1, arrays[1]) + + def test_align_mixed_indexes(self) -> None: + array_no_coord = DataArray([1, 2], dims=["x"]) + array_with_coord = DataArray([1, 2], coords=[("x", ["a", "b"])]) + result0, result1 = align(array_no_coord, array_with_coord) + assert_identical(result0, array_with_coord) + assert_identical(result1, array_with_coord) + + result0, result1 = align(array_no_coord, array_with_coord, exclude=["x"]) + assert_identical(result0, array_no_coord) + assert_identical(result1, array_with_coord) + + def test_align_without_indexes_errors(self) -> None: + with pytest.raises( + ValueError, + match=r"cannot.*align.*dimension.*conflicting.*sizes.*", + ): + align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])) + + with pytest.raises( + ValueError, + match=r"cannot.*align.*dimension.*conflicting.*sizes.*", + ): + align( + DataArray([1, 2, 3], dims=["x"]), + DataArray([1, 2], coords=[("x", [0, 1])]), + ) + + def test_align_str_dtype(self) -> None: + a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]}) + b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]}) + + expected_a = DataArray( + [0, 1, np.nan], dims=["x"], coords={"x": ["a", "b", "c"]} + ) + expected_b = DataArray( + [np.nan, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} + ) + + actual_a, actual_b = xr.align(a, b, join="outer") + + assert_identical(expected_a, actual_a) + assert expected_a.x.dtype == actual_a.x.dtype + + assert_identical(expected_b, actual_b) + assert expected_b.x.dtype == actual_b.x.dtype + + def test_broadcast_on_vs_off_global_option_different_dims(self) -> None: + xda_1 = xr.DataArray([1], dims="x1") + xda_2 = xr.DataArray([1], dims="x2") + + with xr.set_options(arithmetic_broadcast=True): + expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2")) + actual_xda = xda_1 / xda_2 + assert_identical(actual_xda, expected_xda) + + with xr.set_options(arithmetic_broadcast=False): + with pytest.raises( + ValueError, + match=re.escape( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ), + ): + xda_1 / xda_2 + + @pytest.mark.parametrize("arithmetic_broadcast", [True, False]) + def test_broadcast_on_vs_off_global_option_same_dims( + self, arithmetic_broadcast: bool + ) -> None: + # Ensure that no error is raised when arithmetic broadcasting is disabled, + # when broadcasting is not needed. The two DataArrays have the same + # dimensions of the same size. + xda_1 = xr.DataArray([1], dims="x") + xda_2 = xr.DataArray([1], dims="x") + expected_xda = xr.DataArray([2.0], dims=("x",)) + + with xr.set_options(arithmetic_broadcast=arithmetic_broadcast): + assert_identical(xda_1 + xda_2, expected_xda) + assert_identical(xda_1 + np.array([1.0]), expected_xda) + assert_identical(np.array([1.0]) + xda_1, expected_xda) + + def test_broadcast_arrays(self) -> None: + x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") + y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") + x2, y2 = broadcast(x, y) + expected_coords = [("a", [-1, -2]), ("b", [3, 4])] + expected_x2 = DataArray([[1, 1], [2, 2]], expected_coords, name="x") + expected_y2 = DataArray([[1, 2], [1, 2]], expected_coords, name="y") + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + x = DataArray(np.random.randn(2, 3), dims=["a", "b"]) + y = DataArray(np.random.randn(3, 2), dims=["b", "a"]) + x2, y2 = broadcast(x, y) + expected_x2 = x + expected_y2 = y.T + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_broadcast_arrays_misaligned(self) -> None: + # broadcast on misaligned coords must auto-align + x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])]) + y = DataArray([1, 2], coords=[("a", [-1, 20])]) + expected_x2 = DataArray( + [[3, 4], [1, 2], [np.nan, np.nan]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + expected_y2 = DataArray( + [[np.nan, np.nan], [1, 1], [2, 2]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + x2, y2 = broadcast(x, y) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_broadcast_arrays_nocopy(self) -> None: + # Test that input data is not copied over in case + # no alteration is needed + x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") + y = DataArray(3, name="y") + expected_x2 = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") + expected_y2 = DataArray([3, 3], coords=[("a", [-1, -2])], name="y") + + x2, y2 = broadcast(x, y) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + assert source_ndarray(x2.data) is source_ndarray(x.data) + + # single-element broadcast (trivial case) + (x2,) = broadcast(x) + assert_identical(x, x2) + assert source_ndarray(x2.data) is source_ndarray(x.data) + + def test_broadcast_arrays_exclude(self) -> None: + x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])]) + y = DataArray([1, 2], coords=[("a", [-1, 20])]) + z = DataArray(5, coords={"b": 5}) + + x2, y2, z2 = broadcast(x, y, z, exclude=["b"]) + expected_x2 = DataArray( + [[3, 4], [1, 2], [np.nan, np.nan]], + coords=[("a", [-2, -1, 20]), ("b", [3, 4])], + ) + expected_y2 = DataArray([np.nan, 1, 2], coords=[("a", [-2, -1, 20])]) + expected_z2 = DataArray( + [5, 5, 5], dims=["a"], coords={"a": [-2, -1, 20], "b": 5} + ) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + assert_identical(expected_z2, z2) + + def test_broadcast_coordinates(self) -> None: + # regression test for GH649 + ds = Dataset({"a": (["x", "y"], np.ones((5, 6)))}) + x_bc, y_bc, a_bc = broadcast(ds.x, ds.y, ds.a) + assert_identical(ds.a, a_bc) + + X, Y = np.meshgrid(np.arange(5), np.arange(6), indexing="ij") + exp_x = DataArray(X, dims=["x", "y"], name="x") + exp_y = DataArray(Y, dims=["x", "y"], name="y") + assert_identical(exp_x, x_bc) + assert_identical(exp_y, y_bc) + + def test_to_pandas(self) -> None: + # 0d + actual = DataArray(42).to_pandas() + expected = np.array(42) + assert_array_equal(actual, expected) + + # 1d + values = np.random.randn(3) + index = pd.Index(["a", "b", "c"], name="x") + da = DataArray(values, coords=[index]) + actual = da.to_pandas() + assert_array_equal(actual.values, values) + assert_array_equal(actual.index, index) + assert_array_equal(actual.index.name, "x") + + # 2d + values = np.random.randn(3, 2) + da = DataArray( + values, coords=[("x", ["a", "b", "c"]), ("y", [0, 1])], name="foo" + ) + actual = da.to_pandas() + assert_array_equal(actual.values, values) + assert_array_equal(actual.index, ["a", "b", "c"]) + assert_array_equal(actual.columns, [0, 1]) + + # roundtrips + for shape in [(3,), (3, 4)]: + dims = list("abc")[: len(shape)] + da = DataArray(np.random.randn(*shape), dims=dims) + roundtripped = DataArray(da.to_pandas()).drop_vars(dims) + assert_identical(da, roundtripped) + + with pytest.raises(ValueError, match=r"Cannot convert"): + DataArray(np.random.randn(1, 2, 3, 4, 5)).to_pandas() + + def test_to_dataframe(self) -> None: + # regression test for #260 + arr_np = np.random.randn(3, 4) + + arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") + expected = arr.to_series() + actual = arr.to_dataframe()["foo"] + assert_array_equal(expected.values, actual.values) + assert_array_equal(expected.name, actual.name) + assert_array_equal(expected.index.values, actual.index.values) + + actual = arr.to_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + + # regression test for coords with different dimensions + arr.coords["C"] = ("B", [-1, -2, -3]) + expected = arr.to_series().to_frame() + expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected = expected[["C", "foo"]] + actual = arr.to_dataframe() + assert_array_equal(expected.values, actual.values) + assert_array_equal(expected.columns.values, actual.columns.values) + assert_array_equal(expected.index.values, actual.index.values) + + with pytest.raises(ValueError, match="does not match the set of dimensions"): + arr.to_dataframe(dim_order=["B", "A", "C"]) + + with pytest.raises(ValueError, match=r"cannot convert a scalar"): + arr.sel(A="c", B=2).to_dataframe() + + arr.name = None # unnamed + with pytest.raises(ValueError, match=r"unnamed"): + arr.to_dataframe() + + def test_to_dataframe_multiindex(self) -> None: + # regression test for #3008 + arr_np = np.random.randn(4, 3) + + mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"]) + + arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo") + + actual = arr.to_dataframe() + assert_array_equal(actual["foo"].values, arr_np.flatten()) + assert_array_equal(actual.index.names, list("ABC")) + assert_array_equal(actual.index.levels[0], [1, 2]) + assert_array_equal(actual.index.levels[1], ["a", "b"]) + assert_array_equal(actual.index.levels[2], [5, 6, 7]) + + def test_to_dataframe_0length(self) -> None: + # regression test for #3008 + arr_np = np.random.randn(4, 0) + + mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"]) + + arr = DataArray(arr_np, [("MI", mindex), ("C", [])], name="foo") + + actual = arr.to_dataframe() + assert len(actual) == 0 + assert_array_equal(actual.index.names, list("ABC")) + + @requires_dask_expr + @requires_dask + @pytest.mark.xfail(reason="dask-expr is broken") + def test_to_dask_dataframe(self) -> None: + arr_np = np.arange(3 * 4).reshape(3, 4) + arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") + expected = arr.to_series() + actual = arr.to_dask_dataframe()["foo"] + + assert_array_equal(actual.values, expected.values) + + actual = arr.to_dask_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + + # regression test for coords with different dimensions + + arr.coords["C"] = ("B", [-1, -2, -3]) + expected = arr.to_series().to_frame() + expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected = expected[["C", "foo"]] + actual = arr.to_dask_dataframe()[["C", "foo"]] + + assert_array_equal(expected.values, actual.values) + assert_array_equal(expected.columns.values, actual.columns.values) + + with pytest.raises(ValueError, match="does not match the set of dimensions"): + arr.to_dask_dataframe(dim_order=["B", "A", "C"]) + + arr.name = None + with pytest.raises( + ValueError, + match="Cannot convert an unnamed DataArray", + ): + arr.to_dask_dataframe() + + def test_to_pandas_name_matches_coordinate(self) -> None: + # coordinate with same name as array + arr = DataArray([1, 2, 3], dims="x", name="x") + series = arr.to_series() + assert_array_equal([1, 2, 3], series.values) + assert_array_equal([0, 1, 2], series.index.values) + assert "x" == series.name + assert "x" == series.index.name + + frame = arr.to_dataframe() + expected = series.to_frame() + assert expected.equals(frame) + + def test_to_and_from_series(self) -> None: + expected = self.dv.to_dataframe()["foo"] + actual = self.dv.to_series() + assert_array_equal(expected.values, actual.values) + assert_array_equal(expected.index.values, actual.index.values) + assert "foo" == actual.name + # test roundtrip + assert_identical(self.dv, DataArray.from_series(actual).drop_vars(["x", "y"])) + # test name is None + actual.name = None + expected_da = self.dv.rename(None) + assert_identical( + expected_da, DataArray.from_series(actual).drop_vars(["x", "y"]) + ) + + def test_from_series_multiindex(self) -> None: + # GH:3951 + df = pd.DataFrame({"B": [1, 2, 3], "A": [4, 5, 6]}) + df = df.rename_axis("num").rename_axis("alpha", axis=1) + actual = df.stack("alpha").to_xarray() + assert (actual.sel(alpha="B") == [1, 2, 3]).all() + assert (actual.sel(alpha="A") == [4, 5, 6]).all() + + @requires_sparse + def test_from_series_sparse(self) -> None: + import sparse + + series = pd.Series([1, 2], index=[("a", 1), ("b", 2)]) + + actual_sparse = DataArray.from_series(series, sparse=True) + actual_dense = DataArray.from_series(series, sparse=False) + + assert isinstance(actual_sparse.data, sparse.COO) + actual_sparse.data = actual_sparse.data.todense() + assert_identical(actual_sparse, actual_dense) + + @requires_sparse + def test_from_multiindex_series_sparse(self) -> None: + # regression test for GH4019 + import sparse + + idx = pd.MultiIndex.from_product([np.arange(3), np.arange(5)], names=["a", "b"]) + series = pd.Series(np.random.RandomState(0).random(len(idx)), index=idx).sample( + n=5, random_state=3 + ) + + dense = DataArray.from_series(series, sparse=False) + expected_coords = sparse.COO.from_numpy(dense.data, np.nan).coords + + actual_sparse = xr.DataArray.from_series(series, sparse=True) + actual_coords = actual_sparse.data.coords + + np.testing.assert_equal(actual_coords, expected_coords) + + def test_nbytes_does_not_load_data(self) -> None: + array = InaccessibleArray(np.zeros((3, 3), dtype="uint8")) + da = xr.DataArray(array, dims=["x", "y"]) + + # If xarray tries to instantiate the InaccessibleArray to compute + # nbytes, the following will raise an error. + # However, it should still be able to accurately give us information + # about the number of bytes from the metadata + assert da.nbytes == 9 + # Here we confirm that this does not depend on array having the + # nbytes property, since it isn't really required by the array + # interface. nbytes is more a property of arrays that have been + # cast to numpy arrays. + assert not hasattr(array, "nbytes") + + def test_to_and_from_empty_series(self) -> None: + # GH697 + expected = pd.Series([], dtype=np.float64) + da = DataArray.from_series(expected) + assert len(da) == 0 + actual = da.to_series() + assert len(actual) == 0 + assert expected.equals(actual) + + def test_series_categorical_index(self) -> None: + # regression test for GH700 + if not hasattr(pd, "CategoricalIndex"): + pytest.skip("requires pandas with CategoricalIndex") + + s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc"))) + arr = DataArray(s) + assert "'a'" in repr(arr) # should not error + + @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.parametrize("data", ["list", "array", True]) + @pytest.mark.parametrize("encoding", [True, False]) + def test_to_and_from_dict( + self, encoding: bool, data: bool | Literal["list", "array"], use_dask: bool + ) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + encoding_data = {"bar": "spam"} + array = DataArray( + np.random.randn(2, 3), {"x": ["a", "b"]}, ["x", "y"], name="foo" + ) + array.encoding = encoding_data + + return_data = array.to_numpy() + coords_data = np.array(["a", "b"]) + if data == "list" or data is True: + return_data = return_data.tolist() + coords_data = coords_data.tolist() + + expected: dict[str, Any] = { + "name": "foo", + "dims": ("x", "y"), + "data": return_data, + "attrs": {}, + "coords": {"x": {"dims": ("x",), "data": coords_data, "attrs": {}}}, + } + if encoding: + expected["encoding"] = encoding_data + + if has_dask: + da = array.chunk() + else: + da = array + + if data == "array" or data is False: + with raise_if_dask_computes(): + actual = da.to_dict(encoding=encoding, data=data) + else: + actual = da.to_dict(encoding=encoding, data=data) + + # check that they are identical + np.testing.assert_equal(expected, actual) + + # check roundtrip + assert_identical(da, DataArray.from_dict(actual)) + + # a more bare bones representation still roundtrips + d = { + "name": "foo", + "dims": ("x", "y"), + "data": da.values.tolist(), + "coords": {"x": {"dims": "x", "data": ["a", "b"]}}, + } + assert_identical(da, DataArray.from_dict(d)) + + # and the most bare bones representation still roundtrips + d = {"name": "foo", "dims": ("x", "y"), "data": da.values} + assert_identical(da.drop_vars("x"), DataArray.from_dict(d)) + + # missing a dims in the coords + d = { + "dims": ("x", "y"), + "data": da.values, + "coords": {"x": {"data": ["a", "b"]}}, + } + with pytest.raises( + ValueError, + match=r"cannot convert dict when coords are missing the key 'dims'", + ): + DataArray.from_dict(d) + + # this one is missing some necessary information + d = {"dims": "t"} + with pytest.raises( + ValueError, match=r"cannot convert dict without the key 'data'" + ): + DataArray.from_dict(d) + + # check the data=False option + expected_no_data = expected.copy() + del expected_no_data["data"] + del expected_no_data["coords"]["x"]["data"] + endiantype = "U1" + expected_no_data["coords"]["x"].update({"dtype": endiantype, "shape": (2,)}) + expected_no_data.update({"dtype": "float64", "shape": (2, 3)}) + actual_no_data = da.to_dict(data=False, encoding=encoding) + assert expected_no_data == actual_no_data + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_to_and_from_dict_with_time_dim(self) -> None: + x = np.random.randn(10, 3) + t = pd.date_range("20130101", periods=10) + lat = [77.7, 83.2, 76] + da = DataArray(x, {"t": t, "lat": lat}, dims=["t", "lat"]) + roundtripped = DataArray.from_dict(da.to_dict()) + assert_identical(da, roundtripped) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_to_and_from_dict_with_nan_nat(self) -> None: + y = np.random.randn(10, 3) + y[2] = np.nan + t = pd.Series(pd.date_range("20130101", periods=10)) + t[2] = np.nan + lat = [77.7, 83.2, 76] + da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"]) + roundtripped = DataArray.from_dict(da.to_dict()) + assert_identical(da, roundtripped) + + def test_to_dict_with_numpy_attrs(self) -> None: + # this doesn't need to roundtrip + x = np.random.randn(10, 3) + t = list("abcdefghij") + lat = [77.7, 83.2, 76] + attrs = { + "created": np.float64(1998), + "coords": np.array([37, -110.1, 100]), + "maintainer": "bar", + } + da = DataArray(x, {"t": t, "lat": lat}, dims=["t", "lat"], attrs=attrs) + expected_attrs = { + "created": attrs["created"].item(), # type: ignore[attr-defined] + "coords": attrs["coords"].tolist(), # type: ignore[attr-defined] + "maintainer": "bar", + } + actual = da.to_dict() + + # check that they are identical + assert expected_attrs == actual["attrs"] + + def test_to_masked_array(self) -> None: + rs = np.random.RandomState(44) + x = rs.random_sample(size=(10, 20)) + x_masked = np.ma.masked_where(x < 0.5, x) + da = DataArray(x_masked) + + # Test round trip + x_masked_2 = da.to_masked_array() + da_2 = DataArray(x_masked_2) + assert_array_equal(x_masked, x_masked_2) + assert_equal(da, da_2) + + da_masked_array = da.to_masked_array(copy=True) + assert isinstance(da_masked_array, np.ma.MaskedArray) + # Test masks + assert_array_equal(da_masked_array.mask, x_masked.mask) + # Test that mask is unpacked correctly + assert_array_equal(da.values, x_masked.filled(np.nan)) + # Test that the underlying data (including nans) hasn't changed + assert_array_equal(da_masked_array, x_masked.filled(np.nan)) + + # Test that copy=False gives access to values + masked_array = da.to_masked_array(copy=False) + masked_array[0, 0] = 10.0 + assert masked_array[0, 0] == 10.0 + assert da[0, 0].values == 10.0 + assert masked_array.base is da.values + assert isinstance(masked_array, np.ma.MaskedArray) + + # Test with some odd arrays + for v in [4, np.nan, True, "4", "four"]: + da = DataArray(v) + ma = da.to_masked_array() + assert isinstance(ma, np.ma.MaskedArray) + + # Fix GH issue 684 - masked arrays mask should be an array not a scalar + N = 4 + v = range(N) + da = DataArray(v) + ma = da.to_masked_array() + assert len(ma.mask) == N + + def test_to_dataset_whole(self) -> None: + unnamed = DataArray([1, 2], dims="x") + with pytest.raises(ValueError, match=r"unable to convert unnamed"): + unnamed.to_dataset() + + actual = unnamed.to_dataset(name="foo") + expected = Dataset({"foo": ("x", [1, 2])}) + assert_identical(expected, actual) + + named = DataArray([1, 2], dims="x", name="foo", attrs={"y": "testattr"}) + actual = named.to_dataset() + expected = Dataset({"foo": ("x", [1, 2], {"y": "testattr"})}) + assert_identical(expected, actual) + + # Test promoting attrs + actual = named.to_dataset(promote_attrs=True) + expected = Dataset( + {"foo": ("x", [1, 2], {"y": "testattr"})}, attrs={"y": "testattr"} + ) + assert_identical(expected, actual) + + with pytest.raises(TypeError): + actual = named.to_dataset("bar") + + def test_to_dataset_split(self) -> None: + array = DataArray( + [[1, 2], [3, 4], [5, 6]], + coords=[("x", list("abc")), ("y", [0.0, 0.1])], + attrs={"a": 1}, + ) + expected = Dataset( + {"a": ("y", [1, 2]), "b": ("y", [3, 4]), "c": ("y", [5, 6])}, + coords={"y": [0.0, 0.1]}, + attrs={"a": 1}, + ) + actual = array.to_dataset("x") + assert_identical(expected, actual) + + with pytest.raises(TypeError): + array.to_dataset("x", name="foo") + + roundtripped = actual.to_dataarray(dim="x") + assert_identical(array, roundtripped) + + array = DataArray([1, 2, 3], dims="x") + expected = Dataset({0: 1, 1: 2, 2: 3}) + actual = array.to_dataset("x") + assert_identical(expected, actual) + + def test_to_dataset_retains_keys(self) -> None: + # use dates as convenient non-str objects. Not a specific date test + import datetime + + dates = [datetime.date(2000, 1, d) for d in range(1, 4)] + + array = DataArray([1, 2, 3], coords=[("x", dates)], attrs={"a": 1}) + + # convert to dateset and back again + result = array.to_dataset("x").to_dataarray(dim="x") + + assert_equal(array, result) + + def test_to_dataset_coord_value_is_dim(self) -> None: + # github issue #7823 + + array = DataArray( + np.zeros((3, 3)), + coords={ + # 'a' is both a coordinate value and the name of a coordinate + "x": ["a", "b", "c"], + "a": [1, 2, 3], + }, + ) + + with pytest.raises( + ValueError, + match=( + re.escape("dimension 'x' would produce the variables ('a',)") + + ".*" + + re.escape("DataArray.rename(a=...) or DataArray.assign_coords(x=...)") + ), + ): + array.to_dataset("x") + + # test error message formatting when there are multiple ambiguous + # values/coordinates + array2 = DataArray( + np.zeros((3, 3, 2)), + coords={ + "x": ["a", "b", "c"], + "a": [1, 2, 3], + "b": [0.0, 0.1], + }, + ) + + with pytest.raises( + ValueError, + match=( + re.escape("dimension 'x' would produce the variables ('a', 'b')") + + ".*" + + re.escape( + "DataArray.rename(a=..., b=...) or DataArray.assign_coords(x=...)" + ) + ), + ): + array2.to_dataset("x") + + def test__title_for_slice(self) -> None: + array = DataArray( + np.ones((4, 3, 2)), + dims=["a", "b", "c"], + coords={"a": range(4), "b": range(3), "c": range(2)}, + ) + assert "" == array._title_for_slice() + assert "c = 0" == array.isel(c=0)._title_for_slice() + title = array.isel(b=1, c=0)._title_for_slice() + assert "b = 1, c = 0" == title or "c = 0, b = 1" == title + + a2 = DataArray(np.ones((4, 1)), dims=["a", "b"]) + assert "" == a2._title_for_slice() + + def test__title_for_slice_truncate(self) -> None: + array = DataArray(np.ones(4)) + array.coords["a"] = "a" * 100 + array.coords["b"] = "b" * 100 + + nchar = 80 + title = array._title_for_slice(truncate=nchar) + + assert nchar == len(title) + assert title.endswith("...") + + def test_dataarray_diff_n1(self) -> None: + da = DataArray(np.random.randn(3, 4), dims=["x", "y"]) + actual = da.diff("y") + expected = DataArray(np.diff(da.values, axis=1), dims=["x", "y"]) + assert_equal(expected, actual) + + def test_coordinate_diff(self) -> None: + # regression test for GH634 + arr = DataArray(range(0, 20, 2), dims=["lon"], coords=[range(10)]) + lon = arr.coords["lon"] + expected = DataArray([1] * 9, dims=["lon"], coords=[range(1, 10)], name="lon") + actual = lon.diff("lon") + assert_equal(expected, actual) + + @pytest.mark.parametrize("offset", [-5, 0, 1, 2]) + @pytest.mark.parametrize("fill_value, dtype", [(2, int), (dtypes.NA, float)]) + def test_shift(self, offset, fill_value, dtype) -> None: + arr = DataArray([1, 2, 3], dims="x") + actual = arr.shift(x=1, fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = DataArray([fill_value, 1, 2], dims="x") + assert_identical(expected, actual) + assert actual.dtype == dtype + + arr = DataArray([1, 2, 3], [("x", ["a", "b", "c"])]) + expected = DataArray(arr.to_pandas().shift(offset)) + actual = arr.shift(x=offset) + assert_identical(expected, actual) + + def test_roll_coords(self) -> None: + arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") + actual = arr.roll(x=1, roll_coords=True) + expected = DataArray([3, 1, 2], coords=[("x", [2, 0, 1])]) + assert_identical(expected, actual) + + def test_roll_no_coords(self) -> None: + arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") + actual = arr.roll(x=1) + expected = DataArray([3, 1, 2], coords=[("x", [0, 1, 2])]) + assert_identical(expected, actual) + + def test_copy_with_data(self) -> None: + orig = DataArray( + np.random.random(size=(2, 2)), + dims=("x", "y"), + attrs={"attr1": "value1"}, + coords={"x": [4, 3]}, + name="helloworld", + ) + new_data = np.arange(4).reshape(2, 2) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + @pytest.mark.xfail(raises=AssertionError) + @pytest.mark.parametrize( + "deep, expected_orig", + [ + [ + True, + xr.DataArray( + xr.IndexVariable("a", np.array([1, 2])), + coords={"a": [1, 2]}, + dims=["a"], + ), + ], + [ + False, + xr.DataArray( + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ), + ], + ], + ) + def test_copy_coords(self, deep, expected_orig) -> None: + """The test fails for the shallow copy, and apparently only on Windows + for some reason. In windows coords seem to be immutable unless it's one + dataarray deep copied from another.""" + da = xr.DataArray( + np.ones([2, 2, 2]), + coords={"a": [1, 2], "b": ["x", "y"], "c": [0, 1]}, + dims=["a", "b", "c"], + ) + da_cp = da.copy(deep) + new_a = np.array([999, 2]) + da_cp.coords["a"] = da_cp["a"].copy(data=new_a) + + expected_cp = xr.DataArray( + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ) + assert_identical(da_cp["a"], expected_cp) + + assert_identical(da["a"], expected_orig) + + def test_real_and_imag(self) -> None: + array = DataArray(1 + 2j) + assert_identical(array.real, DataArray(1)) + assert_identical(array.imag, DataArray(2)) + + def test_setattr_raises(self) -> None: + array = DataArray(0, coords={"scalar": 1}, attrs={"foo": "bar"}) + with pytest.raises(AttributeError, match=r"cannot set attr"): + array.scalar = 2 + with pytest.raises(AttributeError, match=r"cannot set attr"): + array.foo = 2 + with pytest.raises(AttributeError, match=r"cannot set attr"): + array.other = 2 + + def test_full_like(self) -> None: + # For more thorough tests, see test_variable.py + da = DataArray( + np.random.random(size=(2, 2)), + dims=("x", "y"), + attrs={"attr1": "value1"}, + coords={"x": [4, 3]}, + name="helloworld", + ) + + actual = full_like(da, 2) + expect = da.copy(deep=True) + expect.values = np.array([[2.0, 2.0], [2.0, 2.0]]) + assert_identical(expect, actual) + + # override dtype + actual = full_like(da, fill_value=True, dtype=bool) + expect.values = np.array([[True, True], [True, True]]) + assert expect.dtype == bool + assert_identical(expect, actual) + + with pytest.raises(ValueError, match="'dtype' cannot be dict-like"): + full_like(da, fill_value=True, dtype={"x": bool}) + + def test_dot(self) -> None: + x = np.linspace(-3, 3, 6) + y = np.linspace(-3, 3, 5) + z = range(4) + da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + da = DataArray(da_vals, coords=[x, y, z], dims=["x", "y", "z"]) + + dm_vals1 = range(4) + dm1 = DataArray(dm_vals1, coords=[z], dims=["z"]) + + # nd dot 1d + actual1 = da.dot(dm1) + expected_vals1 = np.tensordot(da_vals, dm_vals1, (2, 0)) + expected1 = DataArray(expected_vals1, coords=[x, y], dims=["x", "y"]) + assert_equal(expected1, actual1) + + # all shared dims + actual2 = da.dot(da) + expected_vals2 = np.tensordot(da_vals, da_vals, axes=([0, 1, 2], [0, 1, 2])) + expected2 = DataArray(expected_vals2) + assert_equal(expected2, actual2) + + # multiple shared dims + dm_vals3 = np.arange(20 * 5 * 4).reshape((20, 5, 4)) + j = np.linspace(-3, 3, 20) + dm3 = DataArray(dm_vals3, coords=[j, y, z], dims=["j", "y", "z"]) + actual3 = da.dot(dm3) + expected_vals3 = np.tensordot(da_vals, dm_vals3, axes=([1, 2], [1, 2])) + expected3 = DataArray(expected_vals3, coords=[x, j], dims=["x", "j"]) + assert_equal(expected3, actual3) + + # Ellipsis: all dims are shared + actual4 = da.dot(da, dim=...) + expected4 = da.dot(da) + assert_equal(expected4, actual4) + + # Ellipsis: not all dims are shared + actual5 = da.dot(dm3, dim=...) + expected5 = da.dot(dm3, dim=("j", "x", "y", "z")) + assert_equal(expected5, actual5) + + with pytest.raises(NotImplementedError): + da.dot(dm3.to_dataset(name="dm")) + with pytest.raises(TypeError): + da.dot(dm3.values) # type: ignore + + def test_dot_align_coords(self) -> None: + # GH 3694 + + x = np.linspace(-3, 3, 6) + y = np.linspace(-3, 3, 5) + z_a = range(4) + da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"]) + + z_m = range(2, 6) + dm_vals1 = range(4) + dm1 = DataArray(dm_vals1, coords=[z_m], dims=["z"]) + + with xr.set_options(arithmetic_join="exact"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): + da.dot(dm1) + + da_aligned, dm_aligned = xr.align(da, dm1, join="inner") + + # nd dot 1d + actual1 = da.dot(dm1) + expected_vals1 = np.tensordot(da_aligned.values, dm_aligned.values, (2, 0)) + expected1 = DataArray(expected_vals1, coords=[x, da_aligned.y], dims=["x", "y"]) + assert_equal(expected1, actual1) + + # multiple shared dims + dm_vals2 = np.arange(20 * 5 * 4).reshape((20, 5, 4)) + j = np.linspace(-3, 3, 20) + dm2 = DataArray(dm_vals2, coords=[j, y, z_m], dims=["j", "y", "z"]) + da_aligned, dm_aligned = xr.align(da, dm2, join="inner") + actual2 = da.dot(dm2) + expected_vals2 = np.tensordot( + da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2]) + ) + expected2 = DataArray(expected_vals2, coords=[x, j], dims=["x", "j"]) + assert_equal(expected2, actual2) + + def test_matmul(self) -> None: + # copied from above (could make a fixture) + x = np.linspace(-3, 3, 6) + y = np.linspace(-3, 3, 5) + z = range(4) + da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + da = DataArray(da_vals, coords=[x, y, z], dims=["x", "y", "z"]) + + result = da @ da + expected = da.dot(da) + assert_identical(result, expected) + + def test_matmul_align_coords(self) -> None: + # GH 3694 + + x_a = np.arange(6) + x_b = np.arange(2, 8) + da_vals = np.arange(6) + da_a = DataArray(da_vals, coords=[x_a], dims=["x"]) + da_b = DataArray(da_vals, coords=[x_b], dims=["x"]) + + # only test arithmetic_join="inner" (=default) + result = da_a @ da_b + expected = da_a.dot(da_b) + assert_identical(result, expected) + + with xr.set_options(arithmetic_join="exact"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): + da_a @ da_b + + def test_binary_op_propagate_indexes(self) -> None: + # regression test for GH2227 + self.dv["x"] = np.arange(self.dv.sizes["x"]) + expected = self.dv.xindexes["x"] + + actual = (self.dv * 10).xindexes["x"] + assert expected is actual + + actual = (self.dv > 10).xindexes["x"] + assert expected is actual + + # use mda for bitshift test as it's type int + actual = (self.mda << 2).xindexes["x"] + expected = self.mda.xindexes["x"] + assert expected is actual + + def test_binary_op_join_setting(self) -> None: + dim = "x" + align_type: Final = "outer" + coords_l, coords_r = [0, 1, 2], [1, 2, 3] + missing_3 = xr.DataArray(coords_l, [(dim, coords_l)]) + missing_0 = xr.DataArray(coords_r, [(dim, coords_r)]) + with xr.set_options(arithmetic_join=align_type): + actual = missing_0 + missing_3 + missing_0_aligned, missing_3_aligned = xr.align( + missing_0, missing_3, join=align_type + ) + expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])]) + assert_equal(actual, expected) + + def test_combine_first(self) -> None: + ar0 = DataArray([[0, 0], [0, 0]], [("x", ["a", "b"]), ("y", [-1, 0])]) + ar1 = DataArray([[1, 1], [1, 1]], [("x", ["b", "c"]), ("y", [0, 1])]) + ar2 = DataArray([2], [("x", ["d"])]) + + actual = ar0.combine_first(ar1) + expected = DataArray( + [[0, 0, np.nan], [0, 0, 1], [np.nan, 1, 1]], + [("x", ["a", "b", "c"]), ("y", [-1, 0, 1])], + ) + assert_equal(actual, expected) + + actual = ar1.combine_first(ar0) + expected = DataArray( + [[0, 0, np.nan], [0, 1, 1], [np.nan, 1, 1]], + [("x", ["a", "b", "c"]), ("y", [-1, 0, 1])], + ) + assert_equal(actual, expected) + + actual = ar0.combine_first(ar2) + expected = DataArray( + [[0, 0], [0, 0], [2, 2]], [("x", ["a", "b", "d"]), ("y", [-1, 0])] + ) + assert_equal(actual, expected) + + def test_sortby(self) -> None: + da = DataArray( + [[1, 2], [3, 4], [5, 6]], [("x", ["c", "b", "a"]), ("y", [1, 0])] + ) + + sorted1d = DataArray( + [[5, 6], [3, 4], [1, 2]], [("x", ["a", "b", "c"]), ("y", [1, 0])] + ) + + sorted2d = DataArray( + [[6, 5], [4, 3], [2, 1]], [("x", ["a", "b", "c"]), ("y", [0, 1])] + ) + + expected = sorted1d + dax = DataArray([100, 99, 98], [("x", ["c", "b", "a"])]) + actual = da.sortby(dax) + assert_equal(actual, expected) + + # test descending order sort + actual = da.sortby(dax, ascending=False) + assert_equal(actual, da) + + # test alignment (fills in nan for 'c') + dax_short = DataArray([98, 97], [("x", ["b", "a"])]) + actual = da.sortby(dax_short) + assert_equal(actual, expected) + + # test multi-dim sort by 1D dataarray values + expected = sorted2d + dax = DataArray([100, 99, 98], [("x", ["c", "b", "a"])]) + day = DataArray([90, 80], [("y", [1, 0])]) + actual = da.sortby([day, dax]) + assert_equal(actual, expected) + + expected = sorted1d + actual = da.sortby("x") + assert_equal(actual, expected) + + expected = sorted2d + actual = da.sortby(["x", "y"]) + assert_equal(actual, expected) + + @requires_bottleneck + def test_rank(self) -> None: + # floats + ar = DataArray([[3, 4, np.nan, 1]]) + expect_0 = DataArray([[1, 1, np.nan, 1]]) + expect_1 = DataArray([[2, 3, np.nan, 1]]) + assert_equal(ar.rank("dim_0"), expect_0) + assert_equal(ar.rank("dim_1"), expect_1) + # int + x = DataArray([3, 2, 1]) + assert_equal(x.rank("dim_0"), x) + # str + y = DataArray(["c", "b", "a"]) + assert_equal(y.rank("dim_0"), x) + + x = DataArray([3.0, 1.0, np.nan, 2.0, 4.0], dims=("z",)) + y = DataArray([0.75, 0.25, np.nan, 0.5, 1.0], dims=("z",)) + assert_equal(y.rank("z", pct=True), y) + + @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.parametrize("use_datetime", [True, False]) + @pytest.mark.filterwarnings("ignore:overflow encountered in multiply") + def test_polyfit(self, use_dask, use_datetime) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + xcoord = xr.DataArray( + pd.date_range("1970-01-01", freq="D", periods=10), dims=("x",), name="x" + ) + x = xr.core.missing.get_clean_interp_index(xcoord, "x") + if not use_datetime: + xcoord = x + + da_raw = DataArray( + np.stack((10 + 1e-15 * x + 2e-28 * x**2, 30 + 2e-14 * x + 1e-29 * x**2)), + dims=("d", "x"), + coords={"x": xcoord, "d": [0, 1]}, + ) + + if use_dask: + da = da_raw.chunk({"d": 1}) + else: + da = da_raw + + out = da.polyfit("x", 2) + expected = DataArray( + [[2e-28, 1e-15, 10], [1e-29, 2e-14, 30]], + dims=("d", "degree"), + coords={"degree": [2, 1, 0], "d": [0, 1]}, + ).T + assert_allclose(out.polyfit_coefficients, expected, rtol=1e-3) + + # Full output and deficient rank + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RankWarning) + out = da.polyfit("x", 12, full=True) + assert out.polyfit_residuals.isnull().all() + + # With NaN + da_raw[0, 1:3] = np.nan + if use_dask: + da = da_raw.chunk({"d": 1}) + else: + da = da_raw + out = da.polyfit("x", 2, skipna=True, cov=True) + assert_allclose(out.polyfit_coefficients, expected, rtol=1e-3) + assert "polyfit_covariance" in out + + # Skipna + Full output + out = da.polyfit("x", 2, skipna=True, full=True) + assert_allclose(out.polyfit_coefficients, expected, rtol=1e-3) + assert out.x_matrix_rank == 3 + np.testing.assert_almost_equal(out.polyfit_residuals, [0, 0]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RankWarning) + out = da.polyfit("x", 8, full=True) + np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False]) + + def test_pad_constant(self) -> None: + ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5)) + actual = ar.pad(dim_0=(1, 3)) + expected = DataArray( + np.pad( + np.arange(3 * 4 * 5).reshape(3, 4, 5).astype(np.float32), + mode="constant", + pad_width=((1, 3), (0, 0), (0, 0)), + constant_values=np.nan, + ) + ) + assert actual.shape == (7, 4, 5) + assert_identical(actual, expected) + + ar = xr.DataArray([9], dims="x") + + actual = ar.pad(x=1) + expected = xr.DataArray([np.nan, 9, np.nan], dims="x") + assert_identical(actual, expected) + + actual = ar.pad(x=1, constant_values=1.23456) + expected = xr.DataArray([1, 9, 1], dims="x") + assert_identical(actual, expected) + + with pytest.raises(ValueError, match="cannot convert float NaN to integer"): + ar.pad(x=1, constant_values=np.nan) + + def test_pad_coords(self) -> None: + ar = DataArray( + np.arange(3 * 4 * 5).reshape(3, 4, 5), + [("x", np.arange(3)), ("y", np.arange(4)), ("z", np.arange(5))], + ) + actual = ar.pad(x=(1, 3), constant_values=1) + expected = DataArray( + np.pad( + np.arange(3 * 4 * 5).reshape(3, 4, 5), + mode="constant", + pad_width=((1, 3), (0, 0), (0, 0)), + constant_values=1, + ), + [ + ( + "x", + np.pad( + np.arange(3).astype(np.float32), + mode="constant", + pad_width=(1, 3), + constant_values=np.nan, + ), + ), + ("y", np.arange(4)), + ("z", np.arange(5)), + ], + ) + assert_identical(actual, expected) + + @pytest.mark.parametrize("mode", ("minimum", "maximum", "mean", "median")) + @pytest.mark.parametrize( + "stat_length", (None, 3, (1, 3), {"dim_0": (2, 1), "dim_2": (4, 2)}) + ) + def test_pad_stat_length(self, mode, stat_length) -> None: + ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5)) + actual = ar.pad(dim_0=(1, 3), dim_2=(2, 2), mode=mode, stat_length=stat_length) + if isinstance(stat_length, dict): + stat_length = (stat_length["dim_0"], (4, 4), stat_length["dim_2"]) + expected = DataArray( + np.pad( + np.arange(3 * 4 * 5).reshape(3, 4, 5), + pad_width=((1, 3), (0, 0), (2, 2)), + mode=mode, + stat_length=stat_length, + ) + ) + assert actual.shape == (7, 4, 9) + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "end_values", (None, 3, (3, 5), {"dim_0": (2, 1), "dim_2": (4, 2)}) + ) + def test_pad_linear_ramp(self, end_values) -> None: + ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5)) + actual = ar.pad( + dim_0=(1, 3), dim_2=(2, 2), mode="linear_ramp", end_values=end_values + ) + if end_values is None: + end_values = 0 + elif isinstance(end_values, dict): + end_values = (end_values["dim_0"], (4, 4), end_values["dim_2"]) + expected = DataArray( + np.pad( + np.arange(3 * 4 * 5).reshape(3, 4, 5), + pad_width=((1, 3), (0, 0), (2, 2)), + mode="linear_ramp", + end_values=end_values, + ) + ) + assert actual.shape == (7, 4, 9) + assert_identical(actual, expected) + + @pytest.mark.parametrize("mode", ("reflect", "symmetric")) + @pytest.mark.parametrize("reflect_type", (None, "even", "odd")) + def test_pad_reflect(self, mode, reflect_type) -> None: + ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5)) + actual = ar.pad( + dim_0=(1, 3), dim_2=(2, 2), mode=mode, reflect_type=reflect_type + ) + np_kwargs = { + "array": np.arange(3 * 4 * 5).reshape(3, 4, 5), + "pad_width": ((1, 3), (0, 0), (2, 2)), + "mode": mode, + } + # numpy does not support reflect_type=None + if reflect_type is not None: + np_kwargs["reflect_type"] = reflect_type + expected = DataArray(np.pad(**np_kwargs)) + + assert actual.shape == (7, 4, 9) + assert_identical(actual, expected) + + @pytest.mark.parametrize( + ["keep_attrs", "attrs", "expected"], + [ + pytest.param(None, {"a": 1, "b": 2}, {"a": 1, "b": 2}, id="default"), + pytest.param(False, {"a": 1, "b": 2}, {}, id="False"), + pytest.param(True, {"a": 1, "b": 2}, {"a": 1, "b": 2}, id="True"), + ], + ) + def test_pad_keep_attrs(self, keep_attrs, attrs, expected) -> None: + arr = xr.DataArray( + [1, 2], dims="x", coords={"c": ("x", [-1, 1], attrs)}, attrs=attrs + ) + expected = xr.DataArray( + [0, 1, 2, 0], + dims="x", + coords={"c": ("x", [np.nan, -1, 1, np.nan], expected)}, + attrs=expected, + ) + + keep_attrs_ = "default" if keep_attrs is None else keep_attrs + + with set_options(keep_attrs=keep_attrs_): + actual = arr.pad({"x": (1, 1)}, mode="constant", constant_values=0) + xr.testing.assert_identical(actual, expected) + + actual = arr.pad( + {"x": (1, 1)}, mode="constant", constant_values=0, keep_attrs=keep_attrs + ) + xr.testing.assert_identical(actual, expected) + + @pytest.mark.parametrize("parser", ["pandas", "python"]) + @pytest.mark.parametrize( + "engine", ["python", None, pytest.param("numexpr", marks=[requires_numexpr])] + ) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=[requires_dask])] + ) + def test_query( + self, backend, engine: QueryEngineOptions, parser: QueryParserOptions + ) -> None: + """Test querying a dataset.""" + + # setup test data + np.random.seed(42) + a = np.arange(0, 10, 1) + b = np.random.randint(0, 100, size=10) + c = np.linspace(0, 1, 20) + d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype( + object + ) + aa = DataArray(data=a, dims=["x"], name="a", coords={"a2": ("x", a)}) + bb = DataArray(data=b, dims=["x"], name="b", coords={"b2": ("x", b)}) + cc = DataArray(data=c, dims=["y"], name="c", coords={"c2": ("y", c)}) + dd = DataArray(data=d, dims=["z"], name="d", coords={"d2": ("z", d)}) + + if backend == "dask": + import dask.array as da + + aa = aa.copy(data=da.from_array(a, chunks=3)) + bb = bb.copy(data=da.from_array(b, chunks=3)) + cc = cc.copy(data=da.from_array(c, chunks=7)) + dd = dd.copy(data=da.from_array(d, chunks=12)) + + # query single dim, single variable + with raise_if_dask_computes(): + actual = aa.query(x="a2 > 5", engine=engine, parser=parser) + expect = aa.isel(x=(a > 5)) + assert_identical(expect, actual) + + # query single dim, single variable, via dict + with raise_if_dask_computes(): + actual = aa.query(dict(x="a2 > 5"), engine=engine, parser=parser) + expect = aa.isel(dict(x=(a > 5))) + assert_identical(expect, actual) + + # query single dim, single variable + with raise_if_dask_computes(): + actual = bb.query(x="b2 > 50", engine=engine, parser=parser) + expect = bb.isel(x=(b > 50)) + assert_identical(expect, actual) + + # query single dim, single variable + with raise_if_dask_computes(): + actual = cc.query(y="c2 < .5", engine=engine, parser=parser) + expect = cc.isel(y=(c < 0.5)) + assert_identical(expect, actual) + + # query single dim, single string variable + if parser == "pandas": + # N.B., this query currently only works with the pandas parser + # xref https://github.com/pandas-dev/pandas/issues/40436 + with raise_if_dask_computes(): + actual = dd.query(z='d2 == "bar"', engine=engine, parser=parser) + expect = dd.isel(z=(d == "bar")) + assert_identical(expect, actual) + + # test error handling + with pytest.raises(ValueError): + aa.query("a > 5") # type: ignore # must be dict or kwargs + with pytest.raises(ValueError): + aa.query(x=(a > 5)) # must be query string + with pytest.raises(UndefinedVariableError): + aa.query(x="spam > 50") # name not present + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit(self, use_dask) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + def exp_decay(t, n0, tau=1): + return n0 * np.exp(-t / tau) + + t = np.arange(0, 5, 0.5) + da = DataArray( + np.stack([exp_decay(t, 3, 3), exp_decay(t, 5, 4), np.nan * t], axis=-1), + dims=("t", "x"), + coords={"t": t, "x": [0, 1, 2]}, + ) + da[0, 0] = np.nan + + expected = DataArray( + [[3, 3], [5, 4], [np.nan, np.nan]], + dims=("x", "param"), + coords={"x": [0, 1, 2], "param": ["n0", "tau"]}, + ) + + if use_dask: + da = da.chunk({"x": 1}) + + fit = da.curvefit( + coords=[da.t], func=exp_decay, p0={"n0": 4}, bounds={"tau": (2, 6)} + ) + assert_allclose(fit.curvefit_coefficients, expected, rtol=1e-3) + + da = da.compute() + fit = da.curvefit(coords="t", func=np.power, reduce_dims="x", param_names=["a"]) + assert "a" in fit.param + assert "x" not in fit.dims + + def test_curvefit_helpers(self) -> None: + def exp_decay(t, n0, tau=1): + return n0 * np.exp(-t / tau) + + params, func_args = xr.core.dataset._get_func_args(exp_decay, []) + assert params == ["n0", "tau"] + param_defaults, bounds_defaults = xr.core.dataset._initialize_curvefit_params( + params, {"n0": 4}, {"tau": [5, np.inf]}, func_args + ) + assert param_defaults == {"n0": 4, "tau": 6} + assert bounds_defaults == {"n0": (-np.inf, np.inf), "tau": (5, np.inf)} + + # DataArray as bound + param_defaults, bounds_defaults = xr.core.dataset._initialize_curvefit_params( + params=params, + p0={"n0": 4}, + bounds={"tau": [DataArray([3, 4], coords=[("x", [1, 2])]), np.inf]}, + func_args=func_args, + ) + assert param_defaults["n0"] == 4 + assert ( + param_defaults["tau"] == xr.DataArray([4, 5], coords=[("x", [1, 2])]) + ).all() + assert bounds_defaults["n0"] == (-np.inf, np.inf) + assert ( + bounds_defaults["tau"][0] == DataArray([3, 4], coords=[("x", [1, 2])]) + ).all() + assert bounds_defaults["tau"][1] == np.inf + + param_names = ["a"] + params, func_args = xr.core.dataset._get_func_args(np.power, param_names) + assert params == param_names + with pytest.raises(ValueError): + xr.core.dataset._get_func_args(np.power, []) + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_multidimensional_guess(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + def sine(t, a, f, p): + return a * np.sin(2 * np.pi * (f * t + p)) + + t = np.arange(0, 2, 0.02) + da = DataArray( + np.stack([sine(t, 1.0, 2, 0), sine(t, 1.0, 2, 0)]), + coords={"x": [0, 1], "t": t}, + ) + + # Fitting to a sine curve produces a different result depending on the + # initial guess: either the phase is zero and the amplitude is positive + # or the phase is 0.5 * 2pi and the amplitude is negative. + + expected = DataArray( + [[1, 2, 0], [-1, 2, 0.5]], + coords={"x": [0, 1], "param": ["a", "f", "p"]}, + ) + + # Different initial guesses for different values of x + a_guess = DataArray([1, -1], coords=[da.x]) + p_guess = DataArray([0, 0.5], coords=[da.x]) + + if use_dask: + da = da.chunk({"x": 1}) + + fit = da.curvefit( + coords=[da.t], + func=sine, + p0={"a": a_guess, "p": p_guess, "f": 2}, + ) + assert_allclose(fit.curvefit_coefficients, expected) + + with pytest.raises( + ValueError, + match=r"Initial guess for 'a' has unexpected dimensions .* should only have " + "dimensions that are in data dimensions", + ): + # initial guess with additional dimensions should be an error + da.curvefit( + coords=[da.t], + func=sine, + p0={"a": DataArray([1, 2], coords={"foo": [1, 2]})}, + ) + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_multidimensional_bounds(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + def sine(t, a, f, p): + return a * np.sin(2 * np.pi * (f * t + p)) + + t = np.arange(0, 2, 0.02) + da = xr.DataArray( + np.stack([sine(t, 1.0, 2, 0), sine(t, 1.0, 2, 0)]), + coords={"x": [0, 1], "t": t}, + ) + + # Fit a sine with different bounds: positive amplitude should result in a fit with + # phase 0 and negative amplitude should result in phase 0.5 * 2pi. + + expected = DataArray( + [[1, 2, 0], [-1, 2, 0.5]], + coords={"x": [0, 1], "param": ["a", "f", "p"]}, + ) + + if use_dask: + da = da.chunk({"x": 1}) + + fit = da.curvefit( + coords=[da.t], + func=sine, + p0={"f": 2, "p": 0.25}, # this guess is needed to get the expected result + bounds={ + "a": ( + DataArray([0, -2], coords=[da.x]), + DataArray([2, 0], coords=[da.x]), + ), + }, + ) + assert_allclose(fit.curvefit_coefficients, expected) + + # Scalar lower bound with array upper bound + fit2 = da.curvefit( + coords=[da.t], + func=sine, + p0={"f": 2, "p": 0.25}, # this guess is needed to get the expected result + bounds={ + "a": (-2, DataArray([2, 0], coords=[da.x])), + }, + ) + assert_allclose(fit2.curvefit_coefficients, expected) + + with pytest.raises( + ValueError, + match=r"Upper bound for 'a' has unexpected dimensions .* should only have " + "dimensions that are in data dimensions", + ): + # bounds with additional dimensions should be an error + da.curvefit( + coords=[da.t], + func=sine, + bounds={"a": (0, DataArray([1], coords={"foo": [1]}))}, + ) + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_ignore_errors(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + # nonsense function to make the optimization fail + def line(x, a, b): + if a > 10: + return 0 + return a * x + b + + da = DataArray( + [[1, 3, 5], [0, 20, 40]], + coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]}, + ) + + if use_dask: + da = da.chunk({"i": 1}) + + expected = DataArray( + [[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]} + ) + + with pytest.raises(RuntimeError, match="calls to function has reached maxfev"): + da.curvefit( + coords="x", + func=line, + # limit maximum number of calls so the optimization fails + kwargs=dict(maxfev=5), + ).compute() # have to compute to raise the error + + fit = da.curvefit( + coords="x", + func=line, + errors="ignore", + # limit maximum number of calls so the optimization fails + kwargs=dict(maxfev=5), + ).compute() + + assert_allclose(fit.curvefit_coefficients, expected) + + +class TestReduce: + @pytest.fixture(autouse=True) + def setup(self): + self.attrs = {"attr1": "value1", "attr2": 2929} + + +@pytest.mark.parametrize( + ["x", "minindex", "maxindex", "nanindex"], + [ + pytest.param(np.array([0, 1, 2, 0, -2, -4, 2]), 5, 2, None, id="int"), + pytest.param( + np.array([0.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0]), 5, 2, None, id="float" + ), + pytest.param( + np.array([1.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0]), 5, 2, 1, id="nan" + ), + pytest.param( + np.array([1.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0]).astype("object"), + 5, + 2, + 1, + marks=pytest.mark.filterwarnings( + "ignore:invalid value encountered in reduce:RuntimeWarning" + ), + id="obj", + ), + pytest.param(np.array([np.nan, np.nan]), np.nan, np.nan, 0, id="allnan"), + pytest.param( + np.array( + ["2015-12-31", "2020-01-02", "2020-01-01", "2016-01-01"], + dtype="datetime64[ns]", + ), + 0, + 1, + None, + id="datetime", + ), + ], +) +class TestReduce1D(TestReduce): + def test_min( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + if np.isnan(minindex): + minindex = 0 + + expected0 = ar.isel(x=minindex, drop=True) + result0 = ar.min(keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.min() + expected1 = expected0.copy() + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.min(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = ar.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected1 + + assert_identical(result2, expected2) + + def test_max( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + if np.isnan(minindex): + maxindex = 0 + + expected0 = ar.isel(x=maxindex, drop=True) + result0 = ar.max(keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.max() + expected1 = expected0.copy() + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.max(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = ar.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected1 + + assert_identical(result2, expected2) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmin( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(minindex): + with pytest.raises(ValueError): + ar.argmin() + return + + expected0 = indarr[minindex] + result0 = ar.argmin() + assert_identical(result0, expected0) + + result1 = ar.argmin(keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result1, expected1) + + result2 = ar.argmin(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = indarr.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected0 + + assert_identical(result2, expected2) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmax( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(maxindex): + with pytest.raises(ValueError): + ar.argmax() + return + + expected0 = indarr[maxindex] + result0 = ar.argmax() + assert_identical(result0, expected0) + + result1 = ar.argmax(keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result1, expected1) + + result2 = ar.argmax(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = indarr.isel(x=nanindex, drop=True) + expected2.attrs = {} + else: + expected2 = expected0 + + assert_identical(result2, expected2) + + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + use_dask: bool, + ) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + + with pytest.raises( + KeyError, + match=r"'spam' not found in array dimensions", + ): + ar0.idxmin(dim="spam") + + # Scalar Dataarray + with pytest.raises(ValueError): + xr.DataArray(5).idxmin() + + coordarr0 = xr.DataArray(ar0.coords["x"].data, dims=["x"]) + coordarr1 = coordarr0.copy() + + hasna = np.isnan(minindex) + if np.isnan(minindex): + minindex = 0 + + if hasna: + coordarr1[...] = 1 + fill_value_0 = np.nan + else: + fill_value_0 = 1 + + expected0 = ( + (coordarr1 * fill_value_0).isel(x=minindex, drop=True).astype("float") + ) + expected0.name = "x" + + # Default fill value (NaN) + result0 = ar0.idxmin() + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + result1 = ar0.idxmin(fill_value=np.nan) + assert_identical(result1, expected0) + + # keep_attrs + result2 = ar0.idxmin(keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + if nanindex is not None and ar0.dtype.kind != "O": + expected3 = coordarr0.isel(x=nanindex, drop=True).astype("float") + expected3.name = "x" + expected3.attrs = {} + else: + expected3 = expected0.copy() + + result3 = ar0.idxmin(skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + result4 = ar0.idxmin(skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + if hasna: + fill_value_5 = -1.1 + else: + fill_value_5 = 1 + + expected5 = (coordarr1 * fill_value_5).isel(x=minindex, drop=True) + expected5.name = "x" + + result5 = ar0.idxmin(fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + if hasna: + fill_value_6 = -1 + else: + fill_value_6 = 1 + + expected6 = (coordarr1 * fill_value_6).isel(x=minindex, drop=True) + expected6.name = "x" + + result6 = ar0.idxmin(fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + if hasna: + fill_value_7 = -1j + else: + fill_value_7 = 1 + + expected7 = (coordarr1 * fill_value_7).isel(x=minindex, drop=True) + expected7.name = "x" + + result7 = ar0.idxmin(fill_value=-1j) + assert_identical(result7, expected7) + + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + use_dask: bool, + ) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + + with pytest.raises( + KeyError, + match=r"'spam' not found in array dimensions", + ): + ar0.idxmax(dim="spam") + + # Scalar Dataarray + with pytest.raises(ValueError): + xr.DataArray(5).idxmax() + + coordarr0 = xr.DataArray(ar0.coords["x"].data, dims=["x"]) + coordarr1 = coordarr0.copy() + + hasna = np.isnan(maxindex) + if np.isnan(maxindex): + maxindex = 0 + + if hasna: + coordarr1[...] = 1 + fill_value_0 = np.nan + else: + fill_value_0 = 1 + + expected0 = ( + (coordarr1 * fill_value_0).isel(x=maxindex, drop=True).astype("float") + ) + expected0.name = "x" + + # Default fill value (NaN) + result0 = ar0.idxmax() + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + result1 = ar0.idxmax(fill_value=np.nan) + assert_identical(result1, expected0) + + # keep_attrs + result2 = ar0.idxmax(keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + if nanindex is not None and ar0.dtype.kind != "O": + expected3 = coordarr0.isel(x=nanindex, drop=True).astype("float") + expected3.name = "x" + expected3.attrs = {} + else: + expected3 = expected0.copy() + + result3 = ar0.idxmax(skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + result4 = ar0.idxmax(skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + if hasna: + fill_value_5 = -1.1 + else: + fill_value_5 = 1 + + expected5 = (coordarr1 * fill_value_5).isel(x=maxindex, drop=True) + expected5.name = "x" + + result5 = ar0.idxmax(fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + if hasna: + fill_value_6 = -1 + else: + fill_value_6 = 1 + + expected6 = (coordarr1 * fill_value_6).isel(x=maxindex, drop=True) + expected6.name = "x" + + result6 = ar0.idxmax(fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + if hasna: + fill_value_7 = -1j + else: + fill_value_7 = 1 + + expected7 = (coordarr1 * fill_value_7).isel(x=maxindex, drop=True) + expected7.name = "x" + + result7 = ar0.idxmax(fill_value=-1j) + assert_identical(result7, expected7) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmin_dim( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(minindex): + with pytest.raises(ValueError): + ar.argmin() + return + + expected0 = {"x": indarr[minindex]} + result0 = ar.argmin(...) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmin(..., keep_attrs=True) + expected1 = deepcopy(expected0) + for da in expected1.values(): + da.attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.argmin(..., skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = {"x": indarr.isel(x=nanindex, drop=True)} + expected2["x"].attrs = {} + else: + expected2 = expected0 + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmax_dim( + self, + x: np.ndarray, + minindex: int | float, + maxindex: int | float, + nanindex: int | None, + ) -> None: + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(maxindex): + with pytest.raises(ValueError): + ar.argmax() + return + + expected0 = {"x": indarr[maxindex]} + result0 = ar.argmax(...) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmax(..., keep_attrs=True) + expected1 = deepcopy(expected0) + for da in expected1.values(): + da.attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.argmax(..., skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = {"x": indarr.isel(x=nanindex, drop=True)} + expected2["x"].attrs = {} + else: + expected2 = expected0 + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + +@pytest.mark.parametrize( + ["x", "minindex", "maxindex", "nanindex"], + [ + pytest.param( + np.array( + [ + [0, 1, 2, 0, -2, -4, 2], + [1, 1, 1, 1, 1, 1, 1], + [0, 0, -10, 5, 20, 0, 0], + ] + ), + [5, 0, 2], + [2, 0, 4], + [None, None, None], + id="int", + ), + pytest.param( + np.array( + [ + [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], + [-4.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0], + [np.nan] * 7, + ] + ), + [5, 0, np.nan], + [0, 2, np.nan], + [None, 1, 0], + id="nan", + ), + pytest.param( + np.array( + [ + [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], + [-4.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0], + [np.nan] * 7, + ] + ).astype("object"), + [5, 0, np.nan], + [0, 2, np.nan], + [None, 1, 0], + marks=pytest.mark.filterwarnings( + "ignore:invalid value encountered in reduce:RuntimeWarning:" + ), + id="obj", + ), + pytest.param( + np.array( + [ + ["2015-12-31", "2020-01-02", "2020-01-01", "2016-01-01"], + ["2020-01-02", "2020-01-02", "2020-01-02", "2020-01-02"], + ["1900-01-01", "1-02-03", "1900-01-02", "1-02-03"], + ], + dtype="datetime64[ns]", + ), + [0, 0, 1], + [1, 0, 2], + [None, None, None], + id="datetime", + ), + ], +) +class TestReduce2D(TestReduce): + def test_min( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + ) -> None: + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + minindex = [x if not np.isnan(x) else 0 for x in minindex] + expected0list = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex) + ] + expected0 = xr.concat(expected0list, dim="y") + + result0 = ar.min(dim="x", keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.min(dim="x") + expected1 = expected0 + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.min(axis=1) + assert_identical(result2, expected1) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2list = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex) + ] + expected2 = xr.concat(expected2list, dim="y") + expected2.attrs = {} + + result3 = ar.min(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_max( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + ) -> None: + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + maxindex = [x if not np.isnan(x) else 0 for x in maxindex] + expected0list = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) + ] + expected0 = xr.concat(expected0list, dim="y") + + result0 = ar.max(dim="x", keep_attrs=True) + assert_identical(result0, expected0) + + result1 = ar.max(dim="x") + expected1 = expected0.copy() + expected1.attrs = {} + assert_identical(result1, expected1) + + result2 = ar.max(axis=1) + assert_identical(result2, expected1) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2list = [ + ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) + ] + expected2 = xr.concat(expected2list, dim="y") + expected2.attrs = {} + + result3 = ar.max(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_argmin( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + ) -> None: + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarrnp = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarrnp, dims=ar.dims, coords=ar.coords) + + if np.isnan(minindex).any(): + with pytest.raises(ValueError): + ar.argmin(dim="x") + return + + expected0list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected0 = xr.concat(expected0list, dim="y") + + result0 = ar.argmin(dim="x") + assert_identical(result0, expected0) + + result1 = ar.argmin(axis=1) + assert_identical(result1, expected0) + + result2 = ar.argmin(dim="x", keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result2, expected1) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected2 = xr.concat(expected2list, dim="y") + expected2.attrs = {} + + result3 = ar.argmin(dim="x", skipna=False) + + assert_identical(result3, expected2) + + def test_argmax( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + ) -> None: + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr_np = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr_np, dims=ar.dims, coords=ar.coords) + + if np.isnan(maxindex).any(): + with pytest.raises(ValueError): + ar.argmax(dim="x") + return + + expected0list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected0 = xr.concat(expected0list, dim="y") + + result0 = ar.argmax(dim="x") + assert_identical(result0, expected0) + + result1 = ar.argmax(axis=1) + assert_identical(result1, expected0) + + result2 = ar.argmax(dim="x", keep_attrs=True) + expected1 = expected0.copy() + expected1.attrs = self.attrs + assert_identical(result2, expected1) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected2 = xr.concat(expected2list, dim="y") + expected2.attrs = {} + + result3 = ar.argmax(dim="x", skipna=False) + + assert_identical(result3, expected2) + + @pytest.mark.parametrize( + "use_dask", [pytest.param(True, id="dask"), pytest.param(False, id="nodask")] + ) + def test_idxmin( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + use_dask: bool, + ) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + + ar0_raw = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + + assert_identical(ar0, ar0) + + # No dimension specified + with pytest.raises(ValueError): + ar0.idxmin() + + # dim doesn't exist + with pytest.raises(KeyError): + ar0.idxmin(dim="Y") + + assert_identical(ar0, ar0) + + coordarr0 = xr.DataArray( + np.tile(ar0.coords["x"], [x.shape[0], 1]), dims=ar0.dims, coords=ar0.coords + ) + + hasna = [np.isnan(x) for x in minindex] + coordarr1 = coordarr0.copy() + coordarr1[hasna, :] = 1 + minindex0 = [x if not np.isnan(x) else 0 for x in minindex] + + nan_mult_0 = np.array([np.nan if x else 1 for x in hasna])[:, None] + expected0list = [ + (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected0 = xr.concat(expected0list, dim="y") + expected0.name = "x" + + # Default fill value (NaN) + with raise_if_dask_computes(max_computes=max_computes): + result0 = ar0.idxmin(dim="x") + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + with raise_if_dask_computes(max_computes=max_computes): + result1 = ar0.idxmin(dim="x", fill_value=np.nan) + assert_identical(result1, expected0) + + # keep_attrs + with raise_if_dask_computes(max_computes=max_computes): + result2 = ar0.idxmin(dim="x", keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + minindex3 = [ + x if y is None or ar0.dtype.kind == "O" else y + for x, y in zip(minindex0, nanindex) + ] + expected3list = [ + coordarr0.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex3) + ] + expected3 = xr.concat(expected3list, dim="y") + expected3.name = "x" + expected3.attrs = {} + + with raise_if_dask_computes(max_computes=max_computes): + result3 = ar0.idxmin(dim="x", skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + with raise_if_dask_computes(max_computes=max_computes): + result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + nan_mult_5 = np.array([-1.1 if x else 1 for x in hasna])[:, None] + expected5list = [ + (coordarr1 * nan_mult_5).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected5 = xr.concat(expected5list, dim="y") + expected5.name = "x" + + with raise_if_dask_computes(max_computes=max_computes): + result5 = ar0.idxmin(dim="x", fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + nan_mult_6 = np.array([-1 if x else 1 for x in hasna])[:, None] + expected6list = [ + (coordarr1 * nan_mult_6).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected6 = xr.concat(expected6list, dim="y") + expected6.name = "x" + + with raise_if_dask_computes(max_computes=max_computes): + result6 = ar0.idxmin(dim="x", fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + nan_mult_7 = np.array([-5j if x else 1 for x in hasna])[:, None] + expected7list = [ + (coordarr1 * nan_mult_7).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex0) + ] + expected7 = xr.concat(expected7list, dim="y") + expected7.name = "x" + + with raise_if_dask_computes(max_computes=max_computes): + result7 = ar0.idxmin(dim="x", fill_value=-5j) + assert_identical(result7, expected7) + + @pytest.mark.parametrize( + "use_dask", [pytest.param(True, id="dask"), pytest.param(False, id="nodask")] + ) + def test_idxmax( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + use_dask: bool, + ) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + + ar0_raw = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + + # No dimension specified + with pytest.raises(ValueError): + ar0.idxmax() + + # dim doesn't exist + with pytest.raises(KeyError): + ar0.idxmax(dim="Y") + + ar1 = ar0.copy() + del ar1.coords["y"] + with pytest.raises(KeyError): + ar1.idxmax(dim="y") + + coordarr0 = xr.DataArray( + np.tile(ar0.coords["x"], [x.shape[0], 1]), dims=ar0.dims, coords=ar0.coords + ) + + hasna = [np.isnan(x) for x in maxindex] + coordarr1 = coordarr0.copy() + coordarr1[hasna, :] = 1 + maxindex0 = [x if not np.isnan(x) else 0 for x in maxindex] + + nan_mult_0 = np.array([np.nan if x else 1 for x in hasna])[:, None] + expected0list = [ + (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected0 = xr.concat(expected0list, dim="y") + expected0.name = "x" + + # Default fill value (NaN) + with raise_if_dask_computes(max_computes=max_computes): + result0 = ar0.idxmax(dim="x") + assert_identical(result0, expected0) + + # Manually specify NaN fill_value + with raise_if_dask_computes(max_computes=max_computes): + result1 = ar0.idxmax(dim="x", fill_value=np.nan) + assert_identical(result1, expected0) + + # keep_attrs + with raise_if_dask_computes(max_computes=max_computes): + result2 = ar0.idxmax(dim="x", keep_attrs=True) + expected2 = expected0.copy() + expected2.attrs = self.attrs + assert_identical(result2, expected2) + + # skipna=False + maxindex3 = [ + x if y is None or ar0.dtype.kind == "O" else y + for x, y in zip(maxindex0, nanindex) + ] + expected3list = [ + coordarr0.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex3) + ] + expected3 = xr.concat(expected3list, dim="y") + expected3.name = "x" + expected3.attrs = {} + + with raise_if_dask_computes(max_computes=max_computes): + result3 = ar0.idxmax(dim="x", skipna=False) + assert_identical(result3, expected3) + + # fill_value should be ignored with skipna=False + with raise_if_dask_computes(max_computes=max_computes): + result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) + assert_identical(result4, expected3) + + # Float fill_value + nan_mult_5 = np.array([-1.1 if x else 1 for x in hasna])[:, None] + expected5list = [ + (coordarr1 * nan_mult_5).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected5 = xr.concat(expected5list, dim="y") + expected5.name = "x" + + with raise_if_dask_computes(max_computes=max_computes): + result5 = ar0.idxmax(dim="x", fill_value=-1.1) + assert_identical(result5, expected5) + + # Integer fill_value + nan_mult_6 = np.array([-1 if x else 1 for x in hasna])[:, None] + expected6list = [ + (coordarr1 * nan_mult_6).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected6 = xr.concat(expected6list, dim="y") + expected6.name = "x" + + with raise_if_dask_computes(max_computes=max_computes): + result6 = ar0.idxmax(dim="x", fill_value=-1) + assert_identical(result6, expected6) + + # Complex fill_value + nan_mult_7 = np.array([-5j if x else 1 for x in hasna])[:, None] + expected7list = [ + (coordarr1 * nan_mult_7).isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex0) + ] + expected7 = xr.concat(expected7list, dim="y") + expected7.name = "x" + + with raise_if_dask_computes(max_computes=max_computes): + result7 = ar0.idxmax(dim="x", fill_value=-5j) + assert_identical(result7, expected7) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmin_dim( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + ) -> None: + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarrnp = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarrnp, dims=ar.dims, coords=ar.coords) + + if np.isnan(minindex).any(): + with pytest.raises(ValueError): + ar.argmin(dim="x") + return + + expected0list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected0 = {"x": xr.concat(expected0list, dim="y")} + + result0 = ar.argmin(dim=["x"]) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmin(dim=["x"], keep_attrs=True) + expected1 = deepcopy(expected0) + expected1["x"].attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected2 = {"x": xr.concat(expected2list, dim="y")} + expected2["x"].attrs = {} + + result2 = ar.argmin(dim=["x"], skipna=False) + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.argmin(...) + # TODO: remove cast once argmin typing is overloaded + min_xind = cast(DataArray, ar.isel(expected0).argmin()) + expected3 = { + "y": DataArray(min_xind), + "x": DataArray(minindex[min_xind.item()]), + } + + for key in expected3: + assert_identical(result3[key], expected3[key]) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmax_dim( + self, + x: np.ndarray, + minindex: list[int | float], + maxindex: list[int | float], + nanindex: list[int | None], + ) -> None: + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarrnp = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarrnp, dims=ar.dims, coords=ar.coords) + + if np.isnan(maxindex).any(): + with pytest.raises(ValueError): + ar.argmax(dim="x") + return + + expected0list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected0 = {"x": xr.concat(expected0list, dim="y")} + + result0 = ar.argmax(dim=["x"]) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmax(dim=["x"], keep_attrs=True) + expected1 = deepcopy(expected0) + expected1["x"].attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2list = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected2 = {"x": xr.concat(expected2list, dim="y")} + expected2["x"].attrs = {} + + result2 = ar.argmax(dim=["x"], skipna=False) + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.argmax(...) + # TODO: remove cast once argmax typing is overloaded + max_xind = cast(DataArray, ar.isel(expected0).argmax()) + expected3 = { + "y": DataArray(max_xind), + "x": DataArray(maxindex[max_xind.item()]), + } + + for key in expected3: + assert_identical(result3[key], expected3[key]) + + +@pytest.mark.parametrize( + "x, minindices_x, minindices_y, minindices_z, minindices_xy, " + "minindices_xz, minindices_yz, minindices_xyz, maxindices_x, " + "maxindices_y, maxindices_z, maxindices_xy, maxindices_xz, maxindices_yz, " + "maxindices_xyz, nanindices_x, nanindices_y, nanindices_z, nanindices_xy, " + "nanindices_xz, nanindices_yz, nanindices_xyz", + [ + pytest.param( + np.array( + [ + [[0, 1, 2, 0], [-2, -4, 2, 0]], + [[1, 1, 1, 1], [1, 1, 1, 1]], + [[0, 0, -10, 5], [20, 0, 0, 0]], + ] + ), + {"x": np.array([[0, 2, 2, 0], [0, 0, 2, 0]])}, + {"y": np.array([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1]])}, + {"z": np.array([[0, 1], [0, 0], [2, 1]])}, + {"x": np.array([0, 0, 2, 0]), "y": np.array([1, 1, 0, 0])}, + {"x": np.array([2, 0]), "z": np.array([2, 1])}, + {"y": np.array([1, 0, 0]), "z": np.array([1, 0, 2])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(2)}, + {"x": np.array([[1, 0, 0, 2], [2, 1, 0, 1]])}, + {"y": np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 1, 0]])}, + {"z": np.array([[2, 2], [0, 0], [3, 0]])}, + {"x": np.array([2, 0, 0, 2]), "y": np.array([1, 0, 0, 0])}, + {"x": np.array([2, 2]), "z": np.array([3, 0])}, + {"y": np.array([0, 0, 1]), "z": np.array([2, 0, 0])}, + {"x": np.array(2), "y": np.array(1), "z": np.array(0)}, + {"x": np.array([[None, None, None, None], [None, None, None, None]])}, + { + "y": np.array( + [ + [None, None, None, None], + [None, None, None, None], + [None, None, None, None], + ] + ) + }, + {"z": np.array([[None, None], [None, None], [None, None]])}, + { + "x": np.array([None, None, None, None]), + "y": np.array([None, None, None, None]), + }, + {"x": np.array([None, None]), "z": np.array([None, None])}, + {"y": np.array([None, None, None]), "z": np.array([None, None, None])}, + {"x": np.array(None), "y": np.array(None), "z": np.array(None)}, + id="int", + ), + pytest.param( + np.array( + [ + [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], + [[-4.0, np.nan, 2.0, np.nan], [-2.0, -4.0, 2.0, 0.0]], + [[np.nan] * 4, [np.nan] * 4], + ] + ), + {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[1, 1, 0, 0], [0, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] + ) + }, + {"z": np.array([[3, 1], [0, 1], [np.nan, np.nan]])}, + {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, + {"x": np.array([1, 0]), "z": np.array([0, 1])}, + {"y": np.array([1, 0, np.nan]), "z": np.array([1, 0, np.nan])}, + {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, + {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[0, 0, 0, 0], [1, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] + ) + }, + {"z": np.array([[0, 2], [2, 2], [np.nan, np.nan]])}, + {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([0, 0]), "z": np.array([2, 2])}, + {"y": np.array([0, 0, np.nan]), "z": np.array([0, 2, np.nan])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, + {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, + { + "y": np.array( + [[None, None, None, None], [None, 0, None, 0], [0, 0, 0, 0]] + ) + }, + {"z": np.array([[None, None], [1, None], [0, 0]])}, + {"x": np.array([2, 1, 2, 1]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([1, 2]), "z": np.array([1, 0])}, + {"y": np.array([None, 0, 0]), "z": np.array([None, 1, 0])}, + {"x": np.array(1), "y": np.array(0), "z": np.array(1)}, + id="nan", + ), + pytest.param( + np.array( + [ + [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], + [[-4.0, np.nan, 2.0, np.nan], [-2.0, -4.0, 2.0, 0.0]], + [[np.nan] * 4, [np.nan] * 4], + ] + ).astype("object"), + {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[1, 1, 0, 0], [0, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] + ) + }, + {"z": np.array([[3, 1], [0, 1], [np.nan, np.nan]])}, + {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, + {"x": np.array([1, 0]), "z": np.array([0, 1])}, + {"y": np.array([1, 0, np.nan]), "z": np.array([1, 0, np.nan])}, + {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, + {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[0, 0, 0, 0], [1, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] + ) + }, + {"z": np.array([[0, 2], [2, 2], [np.nan, np.nan]])}, + {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([0, 0]), "z": np.array([2, 2])}, + {"y": np.array([0, 0, np.nan]), "z": np.array([0, 2, np.nan])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, + {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, + { + "y": np.array( + [[None, None, None, None], [None, 0, None, 0], [0, 0, 0, 0]] + ) + }, + {"z": np.array([[None, None], [1, None], [0, 0]])}, + {"x": np.array([2, 1, 2, 1]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([1, 2]), "z": np.array([1, 0])}, + {"y": np.array([None, 0, 0]), "z": np.array([None, 1, 0])}, + {"x": np.array(1), "y": np.array(0), "z": np.array(1)}, + id="obj", + ), + pytest.param( + np.array( + [ + [["2015-12-31", "2020-01-02"], ["2020-01-01", "2016-01-01"]], + [["2020-01-02", "2020-01-02"], ["2020-01-02", "2020-01-02"]], + [["1900-01-01", "1-02-03"], ["1900-01-02", "1-02-03"]], + ], + dtype="datetime64[ns]", + ), + {"x": np.array([[2, 2], [2, 2]])}, + {"y": np.array([[0, 1], [0, 0], [0, 0]])}, + {"z": np.array([[0, 1], [0, 0], [1, 1]])}, + {"x": np.array([2, 2]), "y": np.array([0, 0])}, + {"x": np.array([2, 2]), "z": np.array([1, 1])}, + {"y": np.array([0, 0, 0]), "z": np.array([0, 0, 1])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[1, 0], [1, 1]])}, + {"y": np.array([[1, 0], [0, 0], [1, 0]])}, + {"z": np.array([[1, 0], [0, 0], [0, 0]])}, + {"x": np.array([1, 0]), "y": np.array([0, 0])}, + {"x": np.array([0, 1]), "z": np.array([1, 0])}, + {"y": np.array([0, 0, 1]), "z": np.array([1, 0, 0])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[None, None], [None, None]])}, + {"y": np.array([[None, None], [None, None], [None, None]])}, + {"z": np.array([[None, None], [None, None], [None, None]])}, + {"x": np.array([None, None]), "y": np.array([None, None])}, + {"x": np.array([None, None]), "z": np.array([None, None])}, + {"y": np.array([None, None, None]), "z": np.array([None, None, None])}, + {"x": np.array(None), "y": np.array(None), "z": np.array(None)}, + id="datetime", + ), + ], +) +class TestReduce3D(TestReduce): + def test_argmin_dim( + self, + x: np.ndarray, + minindices_x: dict[str, np.ndarray], + minindices_y: dict[str, np.ndarray], + minindices_z: dict[str, np.ndarray], + minindices_xy: dict[str, np.ndarray], + minindices_xz: dict[str, np.ndarray], + minindices_yz: dict[str, np.ndarray], + minindices_xyz: dict[str, np.ndarray], + maxindices_x: dict[str, np.ndarray], + maxindices_y: dict[str, np.ndarray], + maxindices_z: dict[str, np.ndarray], + maxindices_xy: dict[str, np.ndarray], + maxindices_xz: dict[str, np.ndarray], + maxindices_yz: dict[str, np.ndarray], + maxindices_xyz: dict[str, np.ndarray], + nanindices_x: dict[str, np.ndarray], + nanindices_y: dict[str, np.ndarray], + nanindices_z: dict[str, np.ndarray], + nanindices_xy: dict[str, np.ndarray], + nanindices_xz: dict[str, np.ndarray], + nanindices_yz: dict[str, np.ndarray], + nanindices_xyz: dict[str, np.ndarray], + ) -> None: + ar = xr.DataArray( + x, + dims=["x", "y", "z"], + coords={ + "x": np.arange(x.shape[0]) * 4, + "y": 1 - np.arange(x.shape[1]), + "z": 2 + 3 * np.arange(x.shape[2]), + }, + attrs=self.attrs, + ) + + for inds in [ + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + ]: + if np.array([np.isnan(i) for i in inds.values()]).any(): + with pytest.raises(ValueError): + ar.argmin(dim=[d for d in inds]) + return + + result0 = ar.argmin(dim=["x"]) + assert isinstance(result0, dict) + expected0 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in minindices_x.items() + } + for key in expected0: + assert_identical(result0[key].drop_vars(["y", "z"]), expected0[key]) + + result1 = ar.argmin(dim=["y"]) + assert isinstance(result1, dict) + expected1 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in minindices_y.items() + } + for key in expected1: + assert_identical(result1[key].drop_vars(["x", "z"]), expected1[key]) + + result2 = ar.argmin(dim=["z"]) + assert isinstance(result2, dict) + expected2 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in minindices_z.items() + } + for key in expected2: + assert_identical(result2[key].drop_vars(["x", "y"]), expected2[key]) + + result3 = ar.argmin(dim=("x", "y")) + assert isinstance(result3, dict) + expected3 = { + key: xr.DataArray(value, dims=("z")) for key, value in minindices_xy.items() + } + for key in expected3: + assert_identical(result3[key].drop_vars("z"), expected3[key]) + + result4 = ar.argmin(dim=("x", "z")) + assert isinstance(result4, dict) + expected4 = { + key: xr.DataArray(value, dims=("y")) for key, value in minindices_xz.items() + } + for key in expected4: + assert_identical(result4[key].drop_vars("y"), expected4[key]) + + result5 = ar.argmin(dim=("y", "z")) + assert isinstance(result5, dict) + expected5 = { + key: xr.DataArray(value, dims=("x")) for key, value in minindices_yz.items() + } + for key in expected5: + assert_identical(result5[key].drop_vars("x"), expected5[key]) + + result6 = ar.argmin(...) + assert isinstance(result6, dict) + expected6 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} + for key in expected6: + assert_identical(result6[key], expected6[key]) + + minindices_x = { + key: xr.where( + nanindices_x[key] == None, # noqa: E711 + minindices_x[key], + nanindices_x[key], + ) + for key in minindices_x + } + expected7 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in minindices_x.items() + } + + result7 = ar.argmin(dim=["x"], skipna=False) + assert isinstance(result7, dict) + for key in expected7: + assert_identical(result7[key].drop_vars(["y", "z"]), expected7[key]) + + minindices_y = { + key: xr.where( + nanindices_y[key] == None, # noqa: E711 + minindices_y[key], + nanindices_y[key], + ) + for key in minindices_y + } + expected8 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in minindices_y.items() + } + + result8 = ar.argmin(dim=["y"], skipna=False) + assert isinstance(result8, dict) + for key in expected8: + assert_identical(result8[key].drop_vars(["x", "z"]), expected8[key]) + + minindices_z = { + key: xr.where( + nanindices_z[key] == None, # noqa: E711 + minindices_z[key], + nanindices_z[key], + ) + for key in minindices_z + } + expected9 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in minindices_z.items() + } + + result9 = ar.argmin(dim=["z"], skipna=False) + assert isinstance(result9, dict) + for key in expected9: + assert_identical(result9[key].drop_vars(["x", "y"]), expected9[key]) + + minindices_xy = { + key: xr.where( + nanindices_xy[key] == None, # noqa: E711 + minindices_xy[key], + nanindices_xy[key], + ) + for key in minindices_xy + } + expected10 = { + key: xr.DataArray(value, dims="z") for key, value in minindices_xy.items() + } + + result10 = ar.argmin(dim=("x", "y"), skipna=False) + assert isinstance(result10, dict) + for key in expected10: + assert_identical(result10[key].drop_vars("z"), expected10[key]) + + minindices_xz = { + key: xr.where( + nanindices_xz[key] == None, # noqa: E711 + minindices_xz[key], + nanindices_xz[key], + ) + for key in minindices_xz + } + expected11 = { + key: xr.DataArray(value, dims="y") for key, value in minindices_xz.items() + } + + result11 = ar.argmin(dim=("x", "z"), skipna=False) + assert isinstance(result11, dict) + for key in expected11: + assert_identical(result11[key].drop_vars("y"), expected11[key]) + + minindices_yz = { + key: xr.where( + nanindices_yz[key] == None, # noqa: E711 + minindices_yz[key], + nanindices_yz[key], + ) + for key in minindices_yz + } + expected12 = { + key: xr.DataArray(value, dims="x") for key, value in minindices_yz.items() + } + + result12 = ar.argmin(dim=("y", "z"), skipna=False) + assert isinstance(result12, dict) + for key in expected12: + assert_identical(result12[key].drop_vars("x"), expected12[key]) + + minindices_xyz = { + key: xr.where( + nanindices_xyz[key] == None, # noqa: E711 + minindices_xyz[key], + nanindices_xyz[key], + ) + for key in minindices_xyz + } + expected13 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} + + result13 = ar.argmin(..., skipna=False) + assert isinstance(result13, dict) + for key in expected13: + assert_identical(result13[key], expected13[key]) + + def test_argmax_dim( + self, + x: np.ndarray, + minindices_x: dict[str, np.ndarray], + minindices_y: dict[str, np.ndarray], + minindices_z: dict[str, np.ndarray], + minindices_xy: dict[str, np.ndarray], + minindices_xz: dict[str, np.ndarray], + minindices_yz: dict[str, np.ndarray], + minindices_xyz: dict[str, np.ndarray], + maxindices_x: dict[str, np.ndarray], + maxindices_y: dict[str, np.ndarray], + maxindices_z: dict[str, np.ndarray], + maxindices_xy: dict[str, np.ndarray], + maxindices_xz: dict[str, np.ndarray], + maxindices_yz: dict[str, np.ndarray], + maxindices_xyz: dict[str, np.ndarray], + nanindices_x: dict[str, np.ndarray], + nanindices_y: dict[str, np.ndarray], + nanindices_z: dict[str, np.ndarray], + nanindices_xy: dict[str, np.ndarray], + nanindices_xz: dict[str, np.ndarray], + nanindices_yz: dict[str, np.ndarray], + nanindices_xyz: dict[str, np.ndarray], + ) -> None: + ar = xr.DataArray( + x, + dims=["x", "y", "z"], + coords={ + "x": np.arange(x.shape[0]) * 4, + "y": 1 - np.arange(x.shape[1]), + "z": 2 + 3 * np.arange(x.shape[2]), + }, + attrs=self.attrs, + ) + + for inds in [ + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + ]: + if np.array([np.isnan(i) for i in inds.values()]).any(): + with pytest.raises(ValueError): + ar.argmax(dim=[d for d in inds]) + return + + result0 = ar.argmax(dim=["x"]) + assert isinstance(result0, dict) + expected0 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in maxindices_x.items() + } + for key in expected0: + assert_identical(result0[key].drop_vars(["y", "z"]), expected0[key]) + + result1 = ar.argmax(dim=["y"]) + assert isinstance(result1, dict) + expected1 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in maxindices_y.items() + } + for key in expected1: + assert_identical(result1[key].drop_vars(["x", "z"]), expected1[key]) + + result2 = ar.argmax(dim=["z"]) + assert isinstance(result2, dict) + expected2 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in maxindices_z.items() + } + for key in expected2: + assert_identical(result2[key].drop_vars(["x", "y"]), expected2[key]) + + result3 = ar.argmax(dim=("x", "y")) + assert isinstance(result3, dict) + expected3 = { + key: xr.DataArray(value, dims=("z")) for key, value in maxindices_xy.items() + } + for key in expected3: + assert_identical(result3[key].drop_vars("z"), expected3[key]) + + result4 = ar.argmax(dim=("x", "z")) + assert isinstance(result4, dict) + expected4 = { + key: xr.DataArray(value, dims=("y")) for key, value in maxindices_xz.items() + } + for key in expected4: + assert_identical(result4[key].drop_vars("y"), expected4[key]) + + result5 = ar.argmax(dim=("y", "z")) + assert isinstance(result5, dict) + expected5 = { + key: xr.DataArray(value, dims=("x")) for key, value in maxindices_yz.items() + } + for key in expected5: + assert_identical(result5[key].drop_vars("x"), expected5[key]) + + result6 = ar.argmax(...) + assert isinstance(result6, dict) + expected6 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} + for key in expected6: + assert_identical(result6[key], expected6[key]) + + maxindices_x = { + key: xr.where( + nanindices_x[key] == None, # noqa: E711 + maxindices_x[key], + nanindices_x[key], + ) + for key in maxindices_x + } + expected7 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in maxindices_x.items() + } + + result7 = ar.argmax(dim=["x"], skipna=False) + assert isinstance(result7, dict) + for key in expected7: + assert_identical(result7[key].drop_vars(["y", "z"]), expected7[key]) + + maxindices_y = { + key: xr.where( + nanindices_y[key] == None, # noqa: E711 + maxindices_y[key], + nanindices_y[key], + ) + for key in maxindices_y + } + expected8 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in maxindices_y.items() + } + + result8 = ar.argmax(dim=["y"], skipna=False) + assert isinstance(result8, dict) + for key in expected8: + assert_identical(result8[key].drop_vars(["x", "z"]), expected8[key]) + + maxindices_z = { + key: xr.where( + nanindices_z[key] == None, # noqa: E711 + maxindices_z[key], + nanindices_z[key], + ) + for key in maxindices_z + } + expected9 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in maxindices_z.items() + } + + result9 = ar.argmax(dim=["z"], skipna=False) + assert isinstance(result9, dict) + for key in expected9: + assert_identical(result9[key].drop_vars(["x", "y"]), expected9[key]) + + maxindices_xy = { + key: xr.where( + nanindices_xy[key] == None, # noqa: E711 + maxindices_xy[key], + nanindices_xy[key], + ) + for key in maxindices_xy + } + expected10 = { + key: xr.DataArray(value, dims="z") for key, value in maxindices_xy.items() + } + + result10 = ar.argmax(dim=("x", "y"), skipna=False) + assert isinstance(result10, dict) + for key in expected10: + assert_identical(result10[key].drop_vars("z"), expected10[key]) + + maxindices_xz = { + key: xr.where( + nanindices_xz[key] == None, # noqa: E711 + maxindices_xz[key], + nanindices_xz[key], + ) + for key in maxindices_xz + } + expected11 = { + key: xr.DataArray(value, dims="y") for key, value in maxindices_xz.items() + } + + result11 = ar.argmax(dim=("x", "z"), skipna=False) + assert isinstance(result11, dict) + for key in expected11: + assert_identical(result11[key].drop_vars("y"), expected11[key]) + + maxindices_yz = { + key: xr.where( + nanindices_yz[key] == None, # noqa: E711 + maxindices_yz[key], + nanindices_yz[key], + ) + for key in maxindices_yz + } + expected12 = { + key: xr.DataArray(value, dims="x") for key, value in maxindices_yz.items() + } + + result12 = ar.argmax(dim=("y", "z"), skipna=False) + assert isinstance(result12, dict) + for key in expected12: + assert_identical(result12[key].drop_vars("x"), expected12[key]) + + maxindices_xyz = { + key: xr.where( + nanindices_xyz[key] == None, # noqa: E711 + maxindices_xyz[key], + nanindices_xyz[key], + ) + for key in maxindices_xyz + } + expected13 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} + + result13 = ar.argmax(..., skipna=False) + assert isinstance(result13, dict) + for key in expected13: + assert_identical(result13[key], expected13[key]) + + +class TestReduceND(TestReduce): + @pytest.mark.parametrize("op", ["idxmin", "idxmax"]) + @pytest.mark.parametrize("ndim", [3, 5]) + def test_idxminmax_dask(self, op: str, ndim: int) -> None: + if not has_dask: + pytest.skip("requires dask") + + ar0_raw = xr.DataArray( + np.random.random_sample(size=[10] * ndim), + dims=[i for i in "abcdefghij"[: ndim - 1]] + ["x"], + coords={"x": np.arange(10)}, + attrs=self.attrs, + ) + + ar0_dsk = ar0_raw.chunk({}) + # Assert idx is the same with dask and without + assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x")) + + +@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True) +def test_isin(da) -> None: + expected = DataArray( + np.asarray([[0, 0, 0], [1, 0, 0]]), + dims=list("yx"), + coords={"x": list("abc"), "y": list("de")}, + ).astype("bool") + + result = da.isin([3]).sel(y=list("de"), z=0) + assert_equal(result, expected) + + expected = DataArray( + np.asarray([[0, 0, 1], [1, 0, 0]]), + dims=list("yx"), + coords={"x": list("abc"), "y": list("de")}, + ).astype("bool") + result = da.isin([2, 3]).sel(y=list("de"), z=0) + assert_equal(result, expected) + + +def test_raise_no_warning_for_nan_in_binary_ops() -> None: + with assert_no_warnings(): + xr.DataArray([1, 2, np.nan]) > 0 + + +@pytest.mark.filterwarnings("error") +def test_no_warning_for_all_nan() -> None: + _ = xr.DataArray([np.nan, np.nan]).mean() + + +def test_name_in_masking() -> None: + name = "RingoStarr" + da = xr.DataArray(range(10), coords=[("x", range(10))], name=name) + assert da.where(da > 5).name == name + assert da.where((da > 5).rename("YokoOno")).name == name + assert da.where(da > 5, drop=True).name == name + assert da.where((da > 5).rename("YokoOno"), drop=True).name == name + + +class TestIrisConversion: + @requires_iris + def test_to_and_from_iris(self) -> None: + import cf_units # iris requirement + import iris + + # to iris + coord_dict: dict[Hashable, Any] = {} + coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"}) + coord_dict["time"] = ("time", pd.date_range("2000-01-01", periods=3)) + coord_dict["height"] = 10 + coord_dict["distance2"] = ("distance", [0, 1], {"foo": "bar"}) + coord_dict["time2"] = (("distance", "time"), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray( + np.arange(6, dtype="float").reshape(2, 3), + coord_dict, + name="Temperature", + attrs={ + "baz": 123, + "units": "Kelvin", + "standard_name": "fire_temperature", + "long_name": "Fire Temperature", + }, + dims=("distance", "time"), + ) + + # Set a bad value to test the masking logic + original.data[0, 2] = np.nan + + original.attrs["cell_methods"] = "height: mean (comment: A cell method)" + actual = original.to_iris() + assert_array_equal(actual.data, original.data) + assert actual.var_name == original.name + assert tuple(d.var_name for d in actual.dim_coords) == original.dims + assert actual.cell_methods == ( + iris.coords.CellMethod( + method="mean", + coords=("height",), + intervals=(), + comments=("A cell method",), + ), + ) + + for coord, orginal_key in zip((actual.coords()), original.coords): + original_coord = original.coords[orginal_key] + assert coord.var_name == original_coord.name + assert_array_equal( + coord.points, CFDatetimeCoder().encode(original_coord.variable).values + ) + assert actual.coord_dims(coord) == original.get_axis_num( + original.coords[coord.var_name].dims + ) + + assert ( + actual.coord("distance2").attributes["foo"] + == original.coords["distance2"].attrs["foo"] + ) + assert actual.coord("distance").units == cf_units.Unit( + original.coords["distance"].units + ) + assert actual.attributes["baz"] == original.attrs["baz"] + assert actual.standard_name == original.attrs["standard_name"] + + roundtripped = DataArray.from_iris(actual) + assert_identical(original, roundtripped) + + actual.remove_coord("time") + auto_time_dimension = DataArray.from_iris(actual) + assert auto_time_dimension.dims == ("distance", "dim_1") + + @requires_iris + @requires_dask + def test_to_and_from_iris_dask(self) -> None: + import cf_units # iris requirement + import dask.array as da + import iris + + coord_dict: dict[Hashable, Any] = {} + coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"}) + coord_dict["time"] = ("time", pd.date_range("2000-01-01", periods=3)) + coord_dict["height"] = 10 + coord_dict["distance2"] = ("distance", [0, 1], {"foo": "bar"}) + coord_dict["time2"] = (("distance", "time"), [[0, 1, 2], [2, 3, 4]]) + + original = DataArray( + da.from_array(np.arange(-1, 5, dtype="float").reshape(2, 3), 3), + coord_dict, + name="Temperature", + attrs=dict( + baz=123, + units="Kelvin", + standard_name="fire_temperature", + long_name="Fire Temperature", + ), + dims=("distance", "time"), + ) + + # Set a bad value to test the masking logic + original.data = da.ma.masked_less(original.data, 0) + + original.attrs["cell_methods"] = "height: mean (comment: A cell method)" + actual = original.to_iris() + + # Be careful not to trigger the loading of the iris data + actual_data = ( + actual.core_data() if hasattr(actual, "core_data") else actual.data + ) + assert_array_equal(actual_data, original.data) + assert actual.var_name == original.name + assert tuple(d.var_name for d in actual.dim_coords) == original.dims + assert actual.cell_methods == ( + iris.coords.CellMethod( + method="mean", + coords=("height",), + intervals=(), + comments=("A cell method",), + ), + ) + + for coord, orginal_key in zip((actual.coords()), original.coords): + original_coord = original.coords[orginal_key] + assert coord.var_name == original_coord.name + assert_array_equal( + coord.points, CFDatetimeCoder().encode(original_coord.variable).values + ) + assert actual.coord_dims(coord) == original.get_axis_num( + original.coords[coord.var_name].dims + ) + + assert ( + actual.coord("distance2").attributes["foo"] + == original.coords["distance2"].attrs["foo"] + ) + assert actual.coord("distance").units == cf_units.Unit( + original.coords["distance"].units + ) + assert actual.attributes["baz"] == original.attrs["baz"] + assert actual.standard_name == original.attrs["standard_name"] + + roundtripped = DataArray.from_iris(actual) + assert_identical(original, roundtripped) + + # If the Iris version supports it then we should have a dask array + # at each stage of the conversion + if hasattr(actual, "core_data"): + assert isinstance(original.data, type(actual.core_data())) + assert isinstance(original.data, type(roundtripped.data)) + + actual.remove_coord("time") + auto_time_dimension = DataArray.from_iris(actual) + assert auto_time_dimension.dims == ("distance", "dim_1") + + @requires_iris + @pytest.mark.parametrize( + "var_name, std_name, long_name, name, attrs", + [ + ( + "var_name", + "height", + "Height", + "var_name", + {"standard_name": "height", "long_name": "Height"}, + ), + ( + None, + "height", + "Height", + "height", + {"standard_name": "height", "long_name": "Height"}, + ), + (None, None, "Height", "Height", {"long_name": "Height"}), + (None, None, None, None, {}), + ], + ) + def test_da_name_from_cube( + self, std_name, long_name, var_name, name, attrs + ) -> None: + from iris.cube import Cube + + cube = Cube([], var_name=var_name, standard_name=std_name, long_name=long_name) + result = xr.DataArray.from_iris(cube) + expected = xr.DataArray([], name=name, attrs=attrs) + xr.testing.assert_identical(result, expected) + + @requires_iris + @pytest.mark.parametrize( + "var_name, std_name, long_name, name, attrs", + [ + ( + "var_name", + "height", + "Height", + "var_name", + {"standard_name": "height", "long_name": "Height"}, + ), + ( + None, + "height", + "Height", + "height", + {"standard_name": "height", "long_name": "Height"}, + ), + (None, None, "Height", "Height", {"long_name": "Height"}), + (None, None, None, "unknown", {}), + ], + ) + def test_da_coord_name_from_cube( + self, std_name, long_name, var_name, name, attrs + ) -> None: + from iris.coords import DimCoord + from iris.cube import Cube + + latitude = DimCoord( + [-90, 0, 90], standard_name=std_name, var_name=var_name, long_name=long_name + ) + data = [0, 0, 0] + cube = Cube(data, dim_coords_and_dims=[(latitude, 0)]) + result = xr.DataArray.from_iris(cube) + expected = xr.DataArray(data, coords=[(name, [-90, 0, 90], attrs)]) + xr.testing.assert_identical(result, expected) + + @requires_iris + def test_prevent_duplicate_coord_names(self) -> None: + from iris.coords import DimCoord + from iris.cube import Cube + + # Iris enforces unique coordinate names. Because we use a different + # name resolution order a valid iris Cube with coords that have the + # same var_name would lead to duplicate dimension names in the + # DataArray + longitude = DimCoord([0, 360], standard_name="longitude", var_name="duplicate") + latitude = DimCoord( + [-90, 0, 90], standard_name="latitude", var_name="duplicate" + ) + data = [[0, 0, 0], [0, 0, 0]] + cube = Cube(data, dim_coords_and_dims=[(longitude, 0), (latitude, 1)]) + with pytest.raises(ValueError): + xr.DataArray.from_iris(cube) + + @requires_iris + @pytest.mark.parametrize( + "coord_values", + [["IA", "IL", "IN"], [0, 2, 1]], # non-numeric values # non-monotonic values + ) + def test_fallback_to_iris_AuxCoord(self, coord_values) -> None: + from iris.coords import AuxCoord + from iris.cube import Cube + + data = [0, 0, 0] + da = xr.DataArray(data, coords=[coord_values], dims=["space"]) + result = xr.DataArray.to_iris(da) + expected = Cube( + data, aux_coords_and_dims=[(AuxCoord(coord_values, var_name="space"), 0)] + ) + assert result == expected + + +def test_no_dict() -> None: + d = DataArray() + with pytest.raises(AttributeError): + d.__dict__ + + +def test_subclass_slots() -> None: + """Test that DataArray subclasses must explicitly define ``__slots__``. + + .. note:: + As of 0.13.0, this is actually mitigated into a FutureWarning for any class + defined outside of the xarray package. + """ + with pytest.raises(AttributeError) as e: + + class MyArray(DataArray): + pass + + assert str(e.value) == "MyArray must explicitly define __slots__" + + +def test_weakref() -> None: + """Classes with __slots__ are incompatible with the weakref module unless they + explicitly state __weakref__ among their slots + """ + from weakref import ref + + a = DataArray(1) + r = ref(a) + assert r() is a + + +def test_delete_coords() -> None: + """Make sure that deleting a coordinate doesn't corrupt the DataArray. + See issue #3899. + + Also test that deleting succeeds and produces the expected output. + """ + a0 = DataArray( + np.array([[1, 2, 3], [4, 5, 6]]), + dims=["y", "x"], + coords={"x": ["a", "b", "c"], "y": [-1, 1]}, + ) + assert_identical(a0, a0) + + a1 = a0.copy() + del a1.coords["y"] + + # This test will detect certain sorts of corruption in the DataArray + assert_identical(a0, a0) + + assert a0.dims == ("y", "x") + assert a1.dims == ("y", "x") + assert set(a0.coords.keys()) == {"x", "y"} + assert set(a1.coords.keys()) == {"x"} + + +def test_deepcopy_nested_attrs() -> None: + """Check attrs deep copy, see :issue:`2835`""" + da1 = xr.DataArray([[1, 2], [3, 4]], dims=("x", "y"), coords={"x": [10, 20]}) + da1.attrs["flat"] = "0" + da1.attrs["nested"] = {"level1a": "1", "level1b": "1"} + + da2 = da1.copy(deep=True) + + da2.attrs["new"] = "2" + da2.attrs.update({"new2": "2"}) + da2.attrs["flat"] = "2" + da2.attrs["nested"]["level1a"] = "2" + da2.attrs["nested"].update({"level1b": "2"}) + + # Coarse test + assert not da1.identical(da2) + + # Check attrs levels + assert da1.attrs["flat"] != da2.attrs["flat"] + assert da1.attrs["nested"] != da2.attrs["nested"] + assert "new" not in da1.attrs + assert "new2" not in da1.attrs + + +def test_deepcopy_obj_array() -> None: + x0 = DataArray(np.array([object()])) + x1 = deepcopy(x0) + assert x0.values[0] is not x1.values[0] + + +def test_deepcopy_recursive() -> None: + # GH:issue:7111 + + # direct recursion + da = xr.DataArray([1, 2], dims=["x"]) + da.attrs["other"] = da + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + da.copy(deep=True) + + # indirect recursion + da2 = xr.DataArray([5, 6], dims=["y"]) + da.attrs["other"] = da2 + da2.attrs["other"] = da + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + da.copy(deep=True) + da2.copy(deep=True) + + +def test_clip(da: DataArray) -> None: + with raise_if_dask_computes(): + result = da.clip(min=0.5) + assert result.min() >= 0.5 + + result = da.clip(max=0.5) + assert result.max() <= 0.5 + + result = da.clip(min=0.25, max=0.75) + assert result.min() >= 0.25 + assert result.max() <= 0.75 + + with raise_if_dask_computes(): + result = da.clip(min=da.mean("x"), max=da.mean("a")) + assert result.dims == da.dims + assert_array_equal( + result.data, + np.clip(da.data, da.mean("x").data[:, :, np.newaxis], da.mean("a").data), + ) + + with_nans = da.isel(time=[0, 1]).reindex_like(da) + with raise_if_dask_computes(): + result = da.clip(min=da.mean("x"), max=da.mean("a")) + result = da.clip(with_nans) + # The values should be the same where there were NaNs. + assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1])) + + # Unclear whether we want this work, OK to adjust the test when we have decided. + with pytest.raises(ValueError, match="cannot reindex or align along dimension.*"): + result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1])) + + +class TestDropDuplicates: + @pytest.mark.parametrize("keep", ["first", "last", False]) + def test_drop_duplicates_1d(self, keep) -> None: + da = xr.DataArray( + [0, 5, 6, 7], dims="time", coords={"time": [0, 0, 1, 2]}, name="test" + ) + + if keep == "first": + data = [0, 6, 7] + time = [0, 1, 2] + elif keep == "last": + data = [5, 6, 7] + time = [0, 1, 2] + else: + data = [6, 7] + time = [1, 2] + + expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test") + result = da.drop_duplicates("time", keep=keep) + assert_equal(expected, result) + + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions ('space',) not found in data dimensions ('time',)" + ), + ): + da.drop_duplicates("space", keep=keep) + + def test_drop_duplicates_2d(self) -> None: + da = xr.DataArray( + [[0, 5, 6, 7], [2, 1, 3, 4]], + dims=["space", "time"], + coords={"space": [10, 10], "time": [0, 0, 1, 2]}, + name="test", + ) + + expected = xr.DataArray( + [[0, 6, 7]], + dims=["space", "time"], + coords={"time": ("time", [0, 1, 2]), "space": ("space", [10])}, + name="test", + ) + + result = da.drop_duplicates(["time", "space"], keep="first") + assert_equal(expected, result) + + result = da.drop_duplicates(..., keep="first") + assert_equal(expected, result) + + +class TestNumpyCoercion: + # TODO once flexible indexes refactor complete also test coercion of dimension coords + def test_from_numpy(self) -> None: + da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])}) + + assert_identical(da.as_numpy(), da) + np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3])) + np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6])) + + @requires_dask + def test_from_dask(self) -> None: + da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])}) + da_chunked = da.chunk(1) + + assert_identical(da_chunked.as_numpy(), da.compute()) + np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3])) + np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6])) + + @requires_pint + def test_from_pint(self) -> None: + from pint import Quantity + + arr = np.array([1, 2, 3]) + da = xr.DataArray( + Quantity(arr, units="Pa"), + dims="x", + coords={"lat": ("x", Quantity(arr + 3, units="m"))}, + ) + + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)}) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + np.testing.assert_equal(da["lat"].to_numpy(), arr + 3) + + @requires_sparse + def test_from_sparse(self) -> None: + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO.from_numpy(arr) + da = xr.DataArray( + sparr, dims=["x", "y"], coords={"elev": (("x", "y"), sparr + 3)} + ) + + expected = xr.DataArray( + arr, dims=["x", "y"], coords={"elev": (("x", "y"), arr + 3)} + ) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + + @requires_cupy + def test_from_cupy(self) -> None: + import cupy as cp + + arr = np.array([1, 2, 3]) + da = xr.DataArray( + cp.array(arr), dims="x", coords={"lat": ("x", cp.array(arr + 3))} + ) + + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)}) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + + @requires_dask + @requires_pint + def test_from_pint_wrapping_dask(self) -> None: + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(arr) + da = xr.DataArray( + Quantity(d, units="Pa"), + dims="x", + coords={"lat": ("x", Quantity(d, units="m") * 2)}, + ) + + result = da.as_numpy() + result.name = None # remove dask-assigned name + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr * 2)}) + assert_identical(result, expected) + np.testing.assert_equal(da.to_numpy(), arr) + + +class TestStackEllipsis: + # https://github.com/pydata/xarray/issues/6051 + def test_result_as_expected(self) -> None: + da = DataArray([[1, 2], [1, 2]], dims=("x", "y")) + result = da.stack(flat=[...]) + expected = da.stack(flat=da.dims) + assert_identical(result, expected) + + def test_error_on_ellipsis_without_list(self) -> None: + da = DataArray([[1, 2], [1, 2]], dims=("x", "y")) + with pytest.raises(ValueError): + da.stack(flat=...) + + +def test_nD_coord_dataarray() -> None: + # should succeed + da = DataArray( + np.ones((2, 4)), + dims=("x", "y"), + coords={ + "x": (("x", "y"), np.arange(8).reshape((2, 4))), + "y": ("y", np.arange(4)), + }, + ) + _assert_internal_invariants(da, check_default_indexes=True) + + da2 = DataArray(np.ones(4), dims=("y"), coords={"y": ("y", np.arange(4))}) + da3 = DataArray(np.ones(4), dims=("z")) + + _, actual = xr.align(da, da2) + assert_identical(da2, actual) + + expected = da.drop_vars("x") + _, actual = xr.broadcast(da, da2) + assert_identical(expected, actual) + + actual, _ = xr.broadcast(da, da3) + expected = da.expand_dims(z=4, axis=-1) + assert_identical(actual, expected) + + da4 = DataArray(np.ones((2, 4)), coords={"x": 0}, dims=["x", "y"]) + _assert_internal_invariants(da4, check_default_indexes=True) + assert "x" not in da4.xindexes + assert "x" in da4.coords + + +def test_lazy_data_variable_not_loaded(): + # GH8753 + array = InaccessibleArray(np.array([1, 2, 3])) + v = Variable(data=array, dims="x") + # No data needs to be accessed, so no error should be raised + da = xr.DataArray(v) + # No data needs to be accessed, so no error should be raised + xr.DataArray(da) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_dataset.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_dataset.py new file mode 100644 index 0000000..81b27da --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_dataset.py @@ -0,0 +1,7435 @@ +from __future__ import annotations + +import pickle +import re +import sys +import warnings +from collections.abc import Hashable +from copy import copy, deepcopy +from io import StringIO +from textwrap import dedent +from typing import Any, Literal + +import numpy as np +import pandas as pd +import pytest +from pandas.core.indexes.datetimes import DatetimeIndex + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + +import xarray as xr +from xarray import ( + DataArray, + Dataset, + IndexVariable, + MergeError, + Variable, + align, + backends, + broadcast, + open_dataset, + set_options, +) +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.core import dtypes, indexing, utils +from xarray.core.common import duck_array_ops, full_like +from xarray.core.coordinates import Coordinates, DatasetCoordinates +from xarray.core.indexes import Index, PandasIndex +from xarray.core.utils import is_scalar +from xarray.namedarray.pycompat import array_type, integer_types +from xarray.testing import _assert_internal_invariants +from xarray.tests import ( + DuckArrayWrapper, + InaccessibleArray, + UnexpectedDataAccess, + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + assert_no_warnings, + assert_writeable, + create_test_data, + has_cftime, + has_dask, + raise_if_dask_computes, + requires_bottleneck, + requires_cftime, + requires_cupy, + requires_dask, + requires_numexpr, + requires_pint, + requires_scipy, + requires_sparse, + source_ndarray, +) + +try: + from pandas.errors import UndefinedVariableError +except ImportError: + # TODO: remove once we stop supporting pandas<1.4.3 + from pandas.core.computation.ops import UndefinedVariableError + + +try: + import dask.array as da +except ImportError: + pass + +# from numpy version 2.0 trapz is deprecated and renamed to trapezoid +# remove once numpy 2.0 is the oldest supported version +try: + from numpy import trapezoid # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import trapz as trapezoid + +sparse_array_type = array_type("sparse") + +pytestmark = [ + pytest.mark.filterwarnings("error:Mean of empty slice"), + pytest.mark.filterwarnings("error:All-NaN (slice|axis) encountered"), +] + + +def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: + rs = np.random.RandomState(seed) + + lat = [2, 1, 0] + lon = [0, 1, 2] + nt1 = 3 + nt2 = 2 + time1 = pd.date_range("2000-01-01", periods=nt1) + time2 = pd.date_range("2000-02-01", periods=nt2) + string_var = np.array(["a", "bc", "def"], dtype=object) + string_var_to_append = np.array(["asdf", "asdfg"], dtype=object) + string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2") + string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2") + unicode_var = np.array(["áó", "áó", "áó"]) + datetime_var = np.array( + ["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]" + ) + datetime_var_to_append = np.array( + ["2019-01-04", "2019-01-05"], dtype="datetime64[s]" + ) + bool_var = np.array([True, False, True], dtype=bool) + bool_var_to_append = np.array([False, True], dtype=bool) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Converting non-nanosecond") + ds = xr.Dataset( + data_vars={ + "da": xr.DataArray( + rs.rand(3, 3, nt1), + coords=[lat, lon, time1], + dims=["lat", "lon", "time"], + ), + "string_var": ("time", string_var), + "string_var_fixed_length": ("time", string_var_fixed_length), + "unicode_var": ("time", unicode_var), + "datetime_var": ("time", datetime_var), + "bool_var": ("time", bool_var), + } + ) + + ds_to_append = xr.Dataset( + data_vars={ + "da": xr.DataArray( + rs.rand(3, 3, nt2), + coords=[lat, lon, time2], + dims=["lat", "lon", "time"], + ), + "string_var": ("time", string_var_to_append), + "string_var_fixed_length": ("time", string_var_fixed_length_to_append), + "unicode_var": ("time", unicode_var[:nt2]), + "datetime_var": ("time", datetime_var_to_append), + "bool_var": ("time", bool_var_to_append), + } + ) + + ds_with_new_var = xr.Dataset( + data_vars={ + "new_var": xr.DataArray( + rs.rand(3, 3, nt1 + nt2), + coords=[lat, lon, time1.append(time2)], + dims=["lat", "lon", "time"], + ) + } + ) + + assert_writeable(ds) + assert_writeable(ds_to_append) + assert_writeable(ds_with_new_var) + return ds, ds_to_append, ds_with_new_var + + +def create_append_string_length_mismatch_test_data(dtype) -> tuple[Dataset, Dataset]: + def make_datasets(data, data_to_append) -> tuple[Dataset, Dataset]: + ds = xr.Dataset( + {"temperature": (["time"], data)}, + coords={"time": [0, 1, 2]}, + ) + ds_to_append = xr.Dataset( + {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]} + ) + assert_writeable(ds) + assert_writeable(ds_to_append) + return ds, ds_to_append + + u2_strings = ["ab", "cd", "ef"] + u5_strings = ["abc", "def", "ghijk"] + + s2_strings = np.array(["aa", "bb", "cc"], dtype="|S2") + s3_strings = np.array(["aaa", "bbb", "ccc"], dtype="|S3") + + if dtype == "U": + return make_datasets(u2_strings, u5_strings) + elif dtype == "S": + return make_datasets(s2_strings, s3_strings) + else: + raise ValueError(f"unsupported dtype {dtype}.") + + +def create_test_multiindex() -> Dataset: + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + return Dataset({}, Coordinates.from_pandas_multiindex(mindex, "x")) + + +def create_test_stacked_array() -> tuple[DataArray, DataArray]: + x = DataArray(pd.Index(np.r_[:10], name="x")) + y = DataArray(pd.Index(np.r_[:20], name="y")) + a = x * y + b = x * y * y + return a, b + + +class InaccessibleVariableDataStore(backends.InMemoryDataStore): + """ + Store that does not allow any data access. + """ + + def __init__(self): + super().__init__() + self._indexvars = set() + + def store(self, variables, *args, **kwargs) -> None: + super().store(variables, *args, **kwargs) + for k, v in variables.items(): + if isinstance(v, IndexVariable): + self._indexvars.add(k) + + def get_variables(self): + def lazy_inaccessible(k, v): + if k in self._indexvars: + return v + data = indexing.LazilyIndexedArray(InaccessibleArray(v.values)) + return Variable(v.dims, data, v.attrs) + + return {k: lazy_inaccessible(k, v) for k, v in self._variables.items()} + + +class DuckBackendArrayWrapper(backends.common.BackendArray): + """Mimic a BackendArray wrapper around DuckArrayWrapper""" + + def __init__(self, array): + self.array = DuckArrayWrapper(array) + self.shape = array.shape + self.dtype = array.dtype + + def get_array(self): + return self.array + + def __getitem__(self, key): + return self.array[key.tuple] + + +class AccessibleAsDuckArrayDataStore(backends.InMemoryDataStore): + """ + Store that returns a duck array, not convertible to numpy array, + on read. Modeled after nVIDIA's kvikio. + """ + + def __init__(self): + super().__init__() + self._indexvars = set() + + def store(self, variables, *args, **kwargs) -> None: + super().store(variables, *args, **kwargs) + for k, v in variables.items(): + if isinstance(v, IndexVariable): + self._indexvars.add(k) + + def get_variables(self) -> dict[Any, xr.Variable]: + def lazy_accessible(k, v) -> xr.Variable: + if k in self._indexvars: + return v + data = indexing.LazilyIndexedArray(DuckBackendArrayWrapper(v.values)) + return Variable(v.dims, data, v.attrs) + + return {k: lazy_accessible(k, v) for k, v in self._variables.items()} + + +class TestDataset: + def test_repr(self) -> None: + data = create_test_data(seed=123) + data.attrs["foo"] = "bar" + # need to insert str dtype at runtime to handle different endianness + expected = dedent( + """\ + Size: 2kB + Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8) + Coordinates: + * dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 + * dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' + * time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20 + numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3 + Dimensions without coordinates: dim1 + Data variables: + var1 (dim1, dim2) float64 576B -1.086 0.9973 0.283 ... 0.4684 -0.8312 + var2 (dim1, dim2) float64 576B 1.162 -1.097 -2.123 ... 1.267 0.3328 + var3 (dim3, dim1) float64 640B 0.5565 -0.2121 0.4563 ... -0.2452 -0.3616 + Attributes: + foo: bar""".format( + data["dim3"].dtype + ) + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) + print(actual) + assert expected == actual + + with set_options(display_width=100): + max_len = max(map(len, repr(data).split("\n"))) + assert 90 < max_len < 100 + + expected = dedent( + """\ + Size: 0B + Dimensions: () + Data variables: + *empty*""" + ) + actual = "\n".join(x.rstrip() for x in repr(Dataset()).split("\n")) + print(actual) + assert expected == actual + + # verify that ... doesn't appear for scalar coordinates + data = Dataset({"foo": ("x", np.ones(10))}).mean() + expected = dedent( + """\ + Size: 8B + Dimensions: () + Data variables: + foo float64 8B 1.0""" + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) + print(actual) + assert expected == actual + + # verify long attributes are truncated + data = Dataset(attrs={"foo": "bar" * 1000}) + assert len(repr(data)) < 1000 + + def test_repr_multiindex(self) -> None: + data = create_test_multiindex() + expected = dedent( + """\ + Size: 96B + Dimensions: (x: 4) + Coordinates: + * x (x) object 32B MultiIndex + * level_1 (x) object 32B 'a' 'a' 'b' 'b' + * level_2 (x) int64 32B 1 2 1 2 + Data variables: + *empty*""" + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) + print(actual) + assert expected == actual + + # verify that long level names are not truncated + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("a_quite_long_level_name", "level_2") + ) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + data = Dataset({}, midx_coords) + expected = dedent( + """\ + Size: 96B + Dimensions: (x: 4) + Coordinates: + * x (x) object 32B MultiIndex + * a_quite_long_level_name (x) object 32B 'a' 'a' 'b' 'b' + * level_2 (x) int64 32B 1 2 1 2 + Data variables: + *empty*""" + ) + actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) + print(actual) + assert expected == actual + + def test_repr_period_index(self) -> None: + data = create_test_data(seed=456) + data.coords["time"] = pd.period_range("2000-01-01", periods=20, freq="D") + + # check that creating the repr doesn't raise an error #GH645 + repr(data) + + def test_unicode_data(self) -> None: + # regression test for GH834 + data = Dataset({"foø": ["ba®"]}, attrs={"å": "∑"}) + repr(data) # should not raise + + byteorder = "<" if sys.byteorder == "little" else ">" + expected = dedent( + """\ + Size: 12B + Dimensions: (foø: 1) + Coordinates: + * foø (foø) %cU3 12B %r + Data variables: + *empty* + Attributes: + å: ∑""" + % (byteorder, "ba®") + ) + actual = str(data) + assert expected == actual + + def test_repr_nep18(self) -> None: + class Array: + def __init__(self): + self.shape = (2,) + self.ndim = 1 + self.dtype = np.dtype(np.float64) + + def __array_function__(self, *args, **kwargs): + return NotImplemented + + def __array_ufunc__(self, *args, **kwargs): + return NotImplemented + + def __repr__(self): + return "Custom\nArray" + + dataset = Dataset({"foo": ("x", Array())}) + expected = dedent( + """\ + Size: 16B + Dimensions: (x: 2) + Dimensions without coordinates: x + Data variables: + foo (x) float64 16B Custom Array""" + ) + assert expected == repr(dataset) + + def test_info(self) -> None: + ds = create_test_data(seed=123) + ds = ds.drop_vars("dim3") # string type prints differently in PY2 vs PY3 + ds.attrs["unicode_attr"] = "ba®" + ds.attrs["string_attr"] = "bar" + + buf = StringIO() + ds.info(buf=buf) + + expected = dedent( + """\ + xarray.Dataset { + dimensions: + \tdim2 = 9 ; + \ttime = 20 ; + \tdim1 = 8 ; + \tdim3 = 10 ; + + variables: + \tfloat64 dim2(dim2) ; + \tdatetime64[ns] time(time) ; + \tfloat64 var1(dim1, dim2) ; + \t\tvar1:foo = variable ; + \tfloat64 var2(dim1, dim2) ; + \t\tvar2:foo = variable ; + \tfloat64 var3(dim3, dim1) ; + \t\tvar3:foo = variable ; + \tint64 numbers(dim3) ; + + // global attributes: + \t:unicode_attr = ba® ; + \t:string_attr = bar ; + }""" + ) + actual = buf.getvalue() + assert expected == actual + buf.close() + + def test_constructor(self) -> None: + x1 = ("x", 2 * np.arange(100)) + x2 = ("x", np.arange(1000)) + z = (["x", "y"], np.arange(1000).reshape(100, 10)) + + with pytest.raises(ValueError, match=r"conflicting sizes"): + Dataset({"a": x1, "b": x2}) + with pytest.raises(TypeError, match=r"tuple of form"): + Dataset({"x": (1, 2, 3, 4, 5, 6, 7)}) + with pytest.raises(ValueError, match=r"already exists as a scalar"): + Dataset({"x": 0, "y": ("x", [1, 2, 3])}) + + # nD coordinate variable "x" sharing name with dimension + actual = Dataset({"a": x1, "x": z}) + assert "x" not in actual.xindexes + _assert_internal_invariants(actual, check_default_indexes=True) + + # verify handling of DataArrays + expected = Dataset({"x": x1, "z": z}) + actual = Dataset({"z": expected["z"]}) + assert_identical(expected, actual) + + def test_constructor_1d(self) -> None: + expected = Dataset({"x": (["x"], 5.0 + np.arange(5))}) + actual = Dataset({"x": 5.0 + np.arange(5)}) + assert_identical(expected, actual) + + actual = Dataset({"x": [5, 6, 7, 8, 9]}) + assert_identical(expected, actual) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_constructor_0d(self) -> None: + expected = Dataset({"x": ([], 1)}) + for arg in [1, np.array(1), expected["x"]]: + actual = Dataset({"x": arg}) + assert_identical(expected, actual) + + class Arbitrary: + pass + + d = pd.Timestamp("2000-01-01T12") + args = [ + True, + None, + 3.4, + np.nan, + "hello", + b"raw", + np.datetime64("2000-01-01"), + d, + d.to_pydatetime(), + Arbitrary(), + ] + for arg in args: + print(arg) + expected = Dataset({"x": ([], arg)}) + actual = Dataset({"x": arg}) + assert_identical(expected, actual) + + def test_constructor_auto_align(self) -> None: + a = DataArray([1, 2], [("x", [0, 1])]) + b = DataArray([3, 4], [("x", [1, 2])]) + + # verify align uses outer join + expected = Dataset( + {"a": ("x", [1, 2, np.nan]), "b": ("x", [np.nan, 3, 4])}, {"x": [0, 1, 2]} + ) + actual = Dataset({"a": a, "b": b}) + assert_identical(expected, actual) + + # regression test for GH346 + assert isinstance(actual.variables["x"], IndexVariable) + + # variable with different dimensions + c = ("y", [3, 4]) + expected2 = expected.merge({"c": c}) + actual = Dataset({"a": a, "b": b, "c": c}) + assert_identical(expected2, actual) + + # variable that is only aligned against the aligned variables + d = ("x", [3, 2, 1]) + expected3 = expected.merge({"d": d}) + actual = Dataset({"a": a, "b": b, "d": d}) + assert_identical(expected3, actual) + + e = ("x", [0, 0]) + with pytest.raises(ValueError, match=r"conflicting sizes"): + Dataset({"a": a, "b": b, "e": e}) + + def test_constructor_pandas_sequence(self) -> None: + ds = self.make_example_math_dataset() + pandas_objs = { + var_name: ds[var_name].to_pandas() for var_name in ["foo", "bar"] + } + ds_based_on_pandas = Dataset(pandas_objs, ds.coords, attrs=ds.attrs) + del ds_based_on_pandas["x"] + assert_equal(ds, ds_based_on_pandas) + + # reindex pandas obj, check align works + rearranged_index = reversed(pandas_objs["foo"].index) + pandas_objs["foo"] = pandas_objs["foo"].reindex(rearranged_index) + ds_based_on_pandas = Dataset(pandas_objs, ds.coords, attrs=ds.attrs) + del ds_based_on_pandas["x"] + assert_equal(ds, ds_based_on_pandas) + + def test_constructor_pandas_single(self) -> None: + das = [ + DataArray(np.random.rand(4), dims=["a"]), # series + DataArray(np.random.rand(4, 3), dims=["a", "b"]), # df + ] + + for a in das: + pandas_obj = a.to_pandas() + ds_based_on_pandas = Dataset(pandas_obj) # type: ignore # TODO: improve typing of __init__ + for dim in ds_based_on_pandas.data_vars: + assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) + + def test_constructor_compat(self) -> None: + data = {"x": DataArray(0, coords={"y": 1}), "y": ("z", [1, 1, 1])} + expected = Dataset({"x": 0}, {"y": ("z", [1, 1, 1])}) + actual = Dataset(data) + assert_identical(expected, actual) + + data = {"y": ("z", [1, 1, 1]), "x": DataArray(0, coords={"y": 1})} + actual = Dataset(data) + assert_identical(expected, actual) + + original = Dataset( + {"a": (("x", "y"), np.ones((2, 3)))}, + {"c": (("x", "y"), np.zeros((2, 3))), "x": [0, 1]}, + ) + expected = Dataset( + {"a": ("x", np.ones(2)), "b": ("y", np.ones(3))}, + {"c": (("x", "y"), np.zeros((2, 3))), "x": [0, 1]}, + ) + + actual = Dataset( + {"a": original["a"][:, 0], "b": original["a"][0].drop_vars("x")} + ) + assert_identical(expected, actual) + + data = {"x": DataArray(0, coords={"y": 3}), "y": ("z", [1, 1, 1])} + with pytest.raises(MergeError): + Dataset(data) + + data = {"x": DataArray(0, coords={"y": 1}), "y": [1, 1]} + actual = Dataset(data) + expected = Dataset({"x": 0}, {"y": [1, 1]}) + assert_identical(expected, actual) + + def test_constructor_with_coords(self) -> None: + with pytest.raises(ValueError, match=r"found in both data_vars and"): + Dataset({"a": ("x", [1])}, {"a": ("x", [1])}) + + ds = Dataset({}, {"a": ("x", [1])}) + assert not ds.data_vars + assert list(ds.coords.keys()) == ["a"] + + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset({}, {"x": mindex, "y": mindex}) + Dataset({}, {"x": mindex, "level_1": range(4)}) + + def test_constructor_no_default_index(self) -> None: + # explicitly passing a Coordinates object skips the creation of default index + ds = Dataset(coords=Coordinates({"x": [1, 2, 3]}, indexes={})) + assert "x" in ds + assert "x" not in ds.xindexes + + def test_constructor_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + ds = Dataset(coords=coords) + assert_identical(ds, coords.to_dataset()) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset(data_vars={"x": midx}) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset(coords={"x": midx}) + + def test_constructor_custom_index(self) -> None: + class CustomIndex(Index): ... + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + ds = Dataset(coords=coords) + assert isinstance(ds.xindexes["x"], CustomIndex) + + # test coordinate variables copied + assert ds.variables["x"] is not coords.variables["x"] + + @pytest.mark.filterwarnings("ignore:return type") + def test_properties(self) -> None: + ds = create_test_data() + + # dims / sizes + # These exact types aren't public API, but this makes sure we don't + # change them inadvertently: + assert isinstance(ds.dims, utils.Frozen) + # TODO change after deprecation cycle in GH #8500 is complete + assert isinstance(ds.dims.mapping, dict) + assert type(ds.dims.mapping) is dict # noqa: E721 + with pytest.warns( + FutureWarning, + match=" To access a mapping from dimension names to lengths, please use `Dataset.sizes`", + ): + assert ds.dims == ds.sizes + assert ds.sizes == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20} + + # dtypes + assert isinstance(ds.dtypes, utils.Frozen) + assert isinstance(ds.dtypes.mapping, dict) + assert ds.dtypes == { + "var1": np.dtype("float64"), + "var2": np.dtype("float64"), + "var3": np.dtype("float64"), + } + + # data_vars + assert list(ds) == list(ds.data_vars) + assert list(ds.keys()) == list(ds.data_vars) + assert "aasldfjalskdfj" not in ds.variables + assert "dim1" in repr(ds.variables) + assert len(ds) == 3 + assert bool(ds) + + assert list(ds.data_vars) == ["var1", "var2", "var3"] + assert list(ds.data_vars.keys()) == ["var1", "var2", "var3"] + assert "var1" in ds.data_vars + assert "dim1" not in ds.data_vars + assert "numbers" not in ds.data_vars + assert len(ds.data_vars) == 3 + + # xindexes + assert set(ds.xindexes) == {"dim2", "dim3", "time"} + assert len(ds.xindexes) == 3 + assert "dim2" in repr(ds.xindexes) + assert all([isinstance(idx, Index) for idx in ds.xindexes.values()]) + + # indexes + assert set(ds.indexes) == {"dim2", "dim3", "time"} + assert len(ds.indexes) == 3 + assert "dim2" in repr(ds.indexes) + assert all([isinstance(idx, pd.Index) for idx in ds.indexes.values()]) + + # coords + assert list(ds.coords) == ["dim2", "dim3", "time", "numbers"] + assert "dim2" in ds.coords + assert "numbers" in ds.coords + assert "var1" not in ds.coords + assert "dim1" not in ds.coords + assert len(ds.coords) == 4 + + # nbytes + assert ( + Dataset({"x": np.int64(1), "y": np.array([1, 2], dtype=np.float32)}).nbytes + == 16 + ) + + def test_warn_ds_dims_deprecation(self) -> None: + # TODO remove after deprecation cycle in GH #8500 is complete + ds = create_test_data() + + with pytest.warns(FutureWarning, match="return type"): + ds.dims["dim1"] + + with pytest.warns(FutureWarning, match="return type"): + ds.dims.keys() + + with pytest.warns(FutureWarning, match="return type"): + ds.dims.values() + + with pytest.warns(FutureWarning, match="return type"): + ds.dims.items() + + with assert_no_warnings(): + len(ds.dims) + ds.dims.__iter__() + "dim1" in ds.dims + + def test_asarray(self) -> None: + ds = Dataset({"x": 0}) + with pytest.raises(TypeError, match=r"cannot directly convert"): + np.asarray(ds) + + def test_get_index(self) -> None: + ds = Dataset({"foo": (("x", "y"), np.zeros((2, 3)))}, coords={"x": ["a", "b"]}) + assert ds.get_index("x").equals(pd.Index(["a", "b"])) + assert ds.get_index("y").equals(pd.Index([0, 1, 2])) + with pytest.raises(KeyError): + ds.get_index("z") + + def test_attr_access(self) -> None: + ds = Dataset( + {"tmin": ("x", [42], {"units": "Celsius"})}, attrs={"title": "My test data"} + ) + assert_identical(ds.tmin, ds["tmin"]) + assert_identical(ds.tmin.x, ds.x) + + assert ds.title == ds.attrs["title"] + assert ds.tmin.units == ds["tmin"].attrs["units"] + + assert {"tmin", "title"} <= set(dir(ds)) + assert "units" in set(dir(ds.tmin)) + + # should defer to variable of same name + ds.attrs["tmin"] = -999 + assert ds.attrs["tmin"] == -999 + assert_identical(ds.tmin, ds["tmin"]) + + def test_variable(self) -> None: + a = Dataset() + d = np.random.random((10, 3)) + a["foo"] = (("time", "x"), d) + assert "foo" in a.variables + assert "foo" in a + a["bar"] = (("time", "x"), d) + # order of creation is preserved + assert list(a.variables) == ["foo", "bar"] + assert_array_equal(a["foo"].values, d) + # try to add variable with dim (10,3) with data that's (3,10) + with pytest.raises(ValueError): + a["qux"] = (("time", "x"), d.T) + + def test_modify_inplace(self) -> None: + a = Dataset() + vec = np.random.random((10,)) + attributes = {"foo": "bar"} + a["x"] = ("x", vec, attributes) + assert "x" in a.coords + assert isinstance(a.coords["x"].to_index(), pd.Index) + assert_identical(a.coords["x"].variable, a.variables["x"]) + b = Dataset() + b["x"] = ("x", vec, attributes) + assert_identical(a["x"], b["x"]) + assert a.sizes == b.sizes + # this should work + a["x"] = ("x", vec[:5]) + a["z"] = ("x", np.arange(5)) + with pytest.raises(ValueError): + # now it shouldn't, since there is a conflicting length + a["x"] = ("x", vec[:4]) + arr = np.random.random((10, 1)) + scal = np.array(0) + with pytest.raises(ValueError): + a["y"] = ("y", arr) + with pytest.raises(ValueError): + a["y"] = ("y", scal) + assert "y" not in a.dims + + def test_coords_properties(self) -> None: + # use int64 for repr consistency on windows + data = Dataset( + { + "x": ("x", np.array([-1, -2], "int64")), + "y": ("y", np.array([0, 1, 2], "int64")), + "foo": (["x", "y"], np.random.randn(2, 3)), + }, + {"a": ("x", np.array([4, 5], "int64")), "b": np.int64(-10)}, + ) + + coords = data.coords + assert isinstance(coords, DatasetCoordinates) + + # len + assert len(coords) == 4 + + # iter + assert list(coords) == ["x", "y", "a", "b"] + + assert_identical(coords["x"].variable, data["x"].variable) + assert_identical(coords["y"].variable, data["y"].variable) + + assert "x" in coords + assert "a" in coords + assert 0 not in coords + assert "foo" not in coords + + with pytest.raises(KeyError): + coords["foo"] + with pytest.raises(KeyError): + coords[0] + + # repr + expected = dedent( + """\ + Coordinates: + * x (x) int64 16B -1 -2 + * y (y) int64 24B 0 1 2 + a (x) int64 16B 4 5 + b int64 8B -10""" + ) + actual = repr(coords) + assert expected == actual + + # dims + assert coords.sizes == {"x": 2, "y": 3} + + # dtypes + assert coords.dtypes == { + "x": np.dtype("int64"), + "y": np.dtype("int64"), + "a": np.dtype("int64"), + "b": np.dtype("int64"), + } + + def test_coords_modify(self) -> None: + data = Dataset( + { + "x": ("x", [-1, -2]), + "y": ("y", [0, 1, 2]), + "foo": (["x", "y"], np.random.randn(2, 3)), + }, + {"a": ("x", [4, 5]), "b": -10}, + ) + + actual = data.copy(deep=True) + actual.coords["x"] = ("x", ["a", "b"]) + assert_array_equal(actual["x"], ["a", "b"]) + + actual = data.copy(deep=True) + actual.coords["z"] = ("z", ["a", "b"]) + assert_array_equal(actual["z"], ["a", "b"]) + + actual = data.copy(deep=True) + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): + actual.coords["x"] = ("x", [-1]) + assert_identical(actual, data) # should not be modified + + actual = data.copy() + del actual.coords["b"] + expected = data.reset_coords("b", drop=True) + assert_identical(expected, actual) + + with pytest.raises(KeyError): + del data.coords["not_found"] + + with pytest.raises(KeyError): + del data.coords["foo"] + + actual = data.copy(deep=True) + actual.coords.update({"c": 11}) + expected = data.merge({"c": 11}).set_coords("c") + assert_identical(expected, actual) + + # regression test for GH3746 + del actual.coords["x"] + assert "x" not in actual.xindexes + + def test_update_index(self) -> None: + actual = Dataset(coords={"x": [1, 2, 3]}) + actual["x"] = ["a", "b", "c"] + assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) + + def test_coords_setitem_with_new_dimension(self) -> None: + actual = Dataset() + actual.coords["foo"] = ("x", [1, 2, 3]) + expected = Dataset(coords={"foo": ("x", [1, 2, 3])}) + assert_identical(expected, actual) + + def test_coords_setitem_multiindex(self) -> None: + data = create_test_multiindex() + with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): + data.coords["level_1"] = range(4) + + def test_coords_set(self) -> None: + one_coord = Dataset({"x": ("x", [0]), "yy": ("x", [1]), "zzz": ("x", [2])}) + two_coords = Dataset({"zzz": ("x", [2])}, {"x": ("x", [0]), "yy": ("x", [1])}) + all_coords = Dataset( + coords={"x": ("x", [0]), "yy": ("x", [1]), "zzz": ("x", [2])} + ) + + actual = one_coord.set_coords("x") + assert_identical(one_coord, actual) + actual = one_coord.set_coords(["x"]) + assert_identical(one_coord, actual) + + actual = one_coord.set_coords("yy") + assert_identical(two_coords, actual) + + actual = one_coord.set_coords(["yy", "zzz"]) + assert_identical(all_coords, actual) + + actual = one_coord.reset_coords() + assert_identical(one_coord, actual) + actual = two_coords.reset_coords() + assert_identical(one_coord, actual) + actual = all_coords.reset_coords() + assert_identical(one_coord, actual) + + actual = all_coords.reset_coords(["yy", "zzz"]) + assert_identical(one_coord, actual) + actual = all_coords.reset_coords("zzz") + assert_identical(two_coords, actual) + + with pytest.raises(ValueError, match=r"cannot remove index"): + one_coord.reset_coords("x") + + actual = all_coords.reset_coords("zzz", drop=True) + expected = all_coords.drop_vars("zzz") + assert_identical(expected, actual) + expected = two_coords.drop_vars("zzz") + assert_identical(expected, actual) + + def test_coords_to_dataset(self) -> None: + orig = Dataset({"foo": ("y", [-1, 0, 1])}, {"x": 10, "y": [2, 3, 4]}) + expected = Dataset(coords={"x": 10, "y": [2, 3, 4]}) + actual = orig.coords.to_dataset() + assert_identical(expected, actual) + + def test_coords_merge(self) -> None: + orig_coords = Dataset(coords={"a": ("x", [1, 2]), "x": [0, 1]}).coords + other_coords = Dataset(coords={"b": ("x", ["a", "b"]), "x": [0, 1]}).coords + expected = Dataset( + coords={"a": ("x", [1, 2]), "b": ("x", ["a", "b"]), "x": [0, 1]} + ) + actual = orig_coords.merge(other_coords) + assert_identical(expected, actual) + actual = other_coords.merge(orig_coords) + assert_identical(expected, actual) + + other_coords = Dataset(coords={"x": ("x", ["a"])}).coords + with pytest.raises(MergeError): + orig_coords.merge(other_coords) + other_coords = Dataset(coords={"x": ("x", ["a", "b"])}).coords + with pytest.raises(MergeError): + orig_coords.merge(other_coords) + other_coords = Dataset(coords={"x": ("x", ["a", "b", "c"])}).coords + with pytest.raises(MergeError): + orig_coords.merge(other_coords) + + other_coords = Dataset(coords={"a": ("x", [8, 9])}).coords + expected = Dataset(coords={"x": range(2)}) + actual = orig_coords.merge(other_coords) + assert_identical(expected, actual) + actual = other_coords.merge(orig_coords) + assert_identical(expected, actual) + + other_coords = Dataset(coords={"x": np.nan}).coords + actual = orig_coords.merge(other_coords) + assert_identical(orig_coords.to_dataset(), actual) + actual = other_coords.merge(orig_coords) + assert_identical(orig_coords.to_dataset(), actual) + + def test_coords_merge_mismatched_shape(self) -> None: + orig_coords = Dataset(coords={"a": ("x", [1, 1])}).coords + other_coords = Dataset(coords={"a": 1}).coords + expected = orig_coords.to_dataset() + actual = orig_coords.merge(other_coords) + assert_identical(expected, actual) + + other_coords = Dataset(coords={"a": ("y", [1])}).coords + expected = Dataset(coords={"a": (["x", "y"], [[1], [1]])}) + actual = orig_coords.merge(other_coords) + assert_identical(expected, actual) + + actual = other_coords.merge(orig_coords) + assert_identical(expected.transpose(), actual) + + orig_coords = Dataset(coords={"a": ("x", [np.nan])}).coords + other_coords = Dataset(coords={"a": np.nan}).coords + expected = orig_coords.to_dataset() + actual = orig_coords.merge(other_coords) + assert_identical(expected, actual) + + def test_data_vars_properties(self) -> None: + ds = Dataset() + ds["foo"] = (("x",), [1.0]) + ds["bar"] = 2.0 + + # iter + assert set(ds.data_vars) == {"foo", "bar"} + assert "foo" in ds.data_vars + assert "x" not in ds.data_vars + assert_identical(ds["foo"], ds.data_vars["foo"]) + + # repr + expected = dedent( + """\ + Data variables: + foo (x) float64 8B 1.0 + bar float64 8B 2.0""" + ) + actual = repr(ds.data_vars) + assert expected == actual + + # dtypes + assert ds.data_vars.dtypes == { + "foo": np.dtype("float64"), + "bar": np.dtype("float64"), + } + + # len + ds.coords["x"] = [1] + assert len(ds.data_vars) == 2 + + # https://github.com/pydata/xarray/issues/7588 + with pytest.raises( + AssertionError, match="something is wrong with Dataset._coord_names" + ): + ds._coord_names = {"w", "x", "y", "z"} + len(ds.data_vars) + + def test_equals_and_identical(self) -> None: + data = create_test_data(seed=42) + assert data.equals(data) + assert data.identical(data) + + data2 = create_test_data(seed=42) + data2.attrs["foobar"] = "baz" + assert data.equals(data2) + assert not data.identical(data2) + + del data2["time"] + assert not data.equals(data2) + + data = create_test_data(seed=42).rename({"var1": None}) + assert data.equals(data) + assert data.identical(data) + + data2 = data.reset_coords() + assert not data2.equals(data) + assert not data2.identical(data) + + def test_equals_failures(self) -> None: + data = create_test_data() + assert not data.equals("foo") # type: ignore[arg-type] + assert not data.identical(123) # type: ignore[arg-type] + assert not data.broadcast_equals({1: 2}) # type: ignore[arg-type] + + def test_broadcast_equals(self) -> None: + data1 = Dataset(coords={"x": 0}) + data2 = Dataset(coords={"x": [0]}) + assert data1.broadcast_equals(data2) + assert not data1.equals(data2) + assert not data1.identical(data2) + + def test_attrs(self) -> None: + data = create_test_data(seed=42) + data.attrs = {"foobar": "baz"} + assert data.attrs["foobar"], "baz" + assert isinstance(data.attrs, dict) + + def test_chunks_does_not_load_data(self) -> None: + # regression test for GH6538 + store = InaccessibleVariableDataStore() + create_test_data().dump_to_store(store) + ds = open_dataset(store) + assert ds.chunks == {} + + @requires_dask + def test_chunk(self) -> None: + data = create_test_data() + for v in data.variables.values(): + assert isinstance(v.data, np.ndarray) + assert data.chunks == {} + + reblocked = data.chunk() + for k, v in reblocked.variables.items(): + if k in reblocked.dims: + assert isinstance(v.data, np.ndarray) + else: + assert isinstance(v.data, da.Array) + + expected_chunks: dict[Hashable, tuple[int, ...]] = { + "dim1": (8,), + "dim2": (9,), + "dim3": (10,), + } + assert reblocked.chunks == expected_chunks + + # test kwargs form of chunks + assert data.chunk(expected_chunks).chunks == expected_chunks + + def get_dask_names(ds): + return {k: v.data.name for k, v in ds.items()} + + orig_dask_names = get_dask_names(reblocked) + + reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5}) + # time is not a dim in any of the data_vars, so it + # doesn't get chunked + expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)} + assert reblocked.chunks == expected_chunks + + # make sure dask names change when rechunking by different amounts + # regression test for GH3350 + new_dask_names = get_dask_names(reblocked) + for k, v in new_dask_names.items(): + assert v != orig_dask_names[k] + + reblocked = data.chunk(expected_chunks) + assert reblocked.chunks == expected_chunks + + # reblock on already blocked data + orig_dask_names = get_dask_names(reblocked) + reblocked = reblocked.chunk(expected_chunks) + new_dask_names = get_dask_names(reblocked) + assert reblocked.chunks == expected_chunks + assert_identical(reblocked, data) + # rechunking with same chunk sizes should not change names + for k, v in new_dask_names.items(): + assert v == orig_dask_names[k] + + with pytest.raises( + ValueError, + match=re.escape( + "chunks keys ('foo',) not found in data dimensions ('dim2', 'dim3', 'time', 'dim1')" + ), + ): + data.chunk({"foo": 10}) + + @requires_dask + def test_dask_is_lazy(self) -> None: + store = InaccessibleVariableDataStore() + create_test_data().dump_to_store(store) + ds = open_dataset(store).chunk() + + with pytest.raises(UnexpectedDataAccess): + ds.load() + with pytest.raises(UnexpectedDataAccess): + ds["var1"].values + + # these should not raise UnexpectedDataAccess: + ds.var1.data + ds.isel(time=10) + ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) + ds.transpose() + ds.mean() + ds.fillna(0) + ds.rename({"dim1": "foobar"}) + ds.set_coords("var1") + ds.drop_vars("var1") + + def test_isel(self) -> None: + data = create_test_data() + slicers: dict[Hashable, slice] = { + "dim1": slice(None, None, 2), + "dim2": slice(0, 2), + } + ret = data.isel(slicers) + + # Verify that only the specified dimension was altered + assert list(data.dims) == list(ret.dims) + for d in data.dims: + if d in slicers: + assert ret.sizes[d] == np.arange(data.sizes[d])[slicers[d]].size + else: + assert data.sizes[d] == ret.sizes[d] + # Verify that the data is what we expect + for v in data.variables: + assert data[v].dims == ret[v].dims + assert data[v].attrs == ret[v].attrs + slice_list = [slice(None)] * data[v].values.ndim + for d, s in slicers.items(): + if d in data[v].dims: + inds = np.nonzero(np.array(data[v].dims) == d)[0] + for ind in inds: + slice_list[ind] = s + expected = data[v].values[tuple(slice_list)] + actual = ret[v].values + np.testing.assert_array_equal(expected, actual) + + with pytest.raises(ValueError): + data.isel(not_a_dim=slice(0, 2)) + with pytest.raises( + ValueError, + match=r"Dimensions {'not_a_dim'} do not exist. Expected " + r"one or more of " + r"[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*", + ): + data.isel(not_a_dim=slice(0, 2)) + with pytest.warns( + UserWarning, + match=r"Dimensions {'not_a_dim'} do not exist. " + r"Expected one or more of " + r"[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*", + ): + data.isel(not_a_dim=slice(0, 2), missing_dims="warn") + assert_identical(data, data.isel(not_a_dim=slice(0, 2), missing_dims="ignore")) + + ret = data.isel(dim1=0) + assert {"time": 20, "dim2": 9, "dim3": 10} == ret.sizes + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.xindexes) == set(ret.xindexes) + + ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) + assert {"time": 2, "dim2": 5, "dim3": 10} == ret.sizes + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.xindexes) == set(ret.xindexes) + + ret = data.isel(time=0, dim1=0, dim2=slice(5)) + assert {"dim2": 5, "dim3": 10} == ret.sizes + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.xindexes) == set(list(ret.xindexes) + ["time"]) + + def test_isel_fancy(self) -> None: + # isel with fancy indexing. + data = create_test_data() + + pdim1 = [1, 2, 3] + pdim2 = [4, 5, 1] + pdim3 = [1, 2, 3] + actual = data.isel( + dim1=(("test_coord",), pdim1), + dim2=(("test_coord",), pdim2), + dim3=(("test_coord",), pdim3), + ) + assert "test_coord" in actual.dims + assert actual.coords["test_coord"].shape == (len(pdim1),) + + # Should work with DataArray + actual = data.isel( + dim1=DataArray(pdim1, dims="test_coord"), + dim2=(("test_coord",), pdim2), + dim3=(("test_coord",), pdim3), + ) + assert "test_coord" in actual.dims + assert actual.coords["test_coord"].shape == (len(pdim1),) + expected = data.isel( + dim1=(("test_coord",), pdim1), + dim2=(("test_coord",), pdim2), + dim3=(("test_coord",), pdim3), + ) + assert_identical(actual, expected) + + # DataArray with coordinate + idx1 = DataArray(pdim1, dims=["a"], coords={"a": np.random.randn(3)}) + idx2 = DataArray(pdim2, dims=["b"], coords={"b": np.random.randn(3)}) + idx3 = DataArray(pdim3, dims=["c"], coords={"c": np.random.randn(3)}) + # Should work with DataArray + actual = data.isel(dim1=idx1, dim2=idx2, dim3=idx3) + assert "a" in actual.dims + assert "b" in actual.dims + assert "c" in actual.dims + assert "time" in actual.coords + assert "dim2" in actual.coords + assert "dim3" in actual.coords + expected = data.isel( + dim1=(("a",), pdim1), dim2=(("b",), pdim2), dim3=(("c",), pdim3) + ) + expected = expected.assign_coords(a=idx1["a"], b=idx2["b"], c=idx3["c"]) + assert_identical(actual, expected) + + idx1 = DataArray(pdim1, dims=["a"], coords={"a": np.random.randn(3)}) + idx2 = DataArray(pdim2, dims=["a"]) + idx3 = DataArray(pdim3, dims=["a"]) + # Should work with DataArray + actual = data.isel(dim1=idx1, dim2=idx2, dim3=idx3) + assert "a" in actual.dims + assert "time" in actual.coords + assert "dim2" in actual.coords + assert "dim3" in actual.coords + expected = data.isel( + dim1=(("a",), pdim1), dim2=(("a",), pdim2), dim3=(("a",), pdim3) + ) + expected = expected.assign_coords(a=idx1["a"]) + assert_identical(actual, expected) + + actual = data.isel(dim1=(("points",), pdim1), dim2=(("points",), pdim2)) + assert "points" in actual.dims + assert "dim3" in actual.dims + assert "dim3" not in actual.data_vars + np.testing.assert_array_equal(data["dim2"][pdim2], actual["dim2"]) + + # test that the order of the indexers doesn't matter + assert_identical( + data.isel(dim1=(("points",), pdim1), dim2=(("points",), pdim2)), + data.isel(dim2=(("points",), pdim2), dim1=(("points",), pdim1)), + ) + # make sure we're raising errors in the right places + with pytest.raises(IndexError, match=r"Dimensions of indexers mismatch"): + data.isel(dim1=(("points",), [1, 2]), dim2=(("points",), [1, 2, 3])) + with pytest.raises(TypeError, match=r"cannot use a Dataset"): + data.isel(dim1=Dataset({"points": [1, 2]})) + + # test to be sure we keep around variables that were not indexed + ds = Dataset({"x": [1, 2, 3, 4], "y": 0}) + actual = ds.isel(x=(("points",), [0, 1, 2])) + assert_identical(ds["y"], actual["y"]) + + # tests using index or DataArray as indexers + stations = Dataset() + stations["station"] = (("station",), ["A", "B", "C"]) + stations["dim1s"] = (("station",), [1, 2, 3]) + stations["dim2s"] = (("station",), [4, 5, 1]) + + actual = data.isel(dim1=stations["dim1s"], dim2=stations["dim2s"]) + assert "station" in actual.coords + assert "station" in actual.dims + assert_identical(actual["station"].drop_vars(["dim2"]), stations["station"]) + + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): + data.isel( + dim1=DataArray( + [0, 1, 2], dims="station", coords={"station": [0, 1, 2]} + ), + dim2=DataArray( + [0, 1, 2], dims="station", coords={"station": [0, 1, 3]} + ), + ) + + # multi-dimensional selection + stations = Dataset() + stations["a"] = (("a",), ["A", "B", "C"]) + stations["b"] = (("b",), [0, 1]) + stations["dim1s"] = (("a", "b"), [[1, 2], [2, 3], [3, 4]]) + stations["dim2s"] = (("a",), [4, 5, 1]) + actual = data.isel(dim1=stations["dim1s"], dim2=stations["dim2s"]) + assert "a" in actual.coords + assert "a" in actual.dims + assert "b" in actual.coords + assert "b" in actual.dims + assert "dim2" in actual.coords + assert "a" in actual["dim2"].dims + + assert_identical(actual["a"].drop_vars(["dim2"]), stations["a"]) + assert_identical(actual["b"], stations["b"]) + expected_var1 = data["var1"].variable[ + stations["dim1s"].variable, stations["dim2s"].variable + ] + expected_var2 = data["var2"].variable[ + stations["dim1s"].variable, stations["dim2s"].variable + ] + expected_var3 = data["var3"].variable[slice(None), stations["dim1s"].variable] + assert_equal(actual["a"].drop_vars("dim2"), stations["a"]) + assert_array_equal(actual["var1"], expected_var1) + assert_array_equal(actual["var2"], expected_var2) + assert_array_equal(actual["var3"], expected_var3) + + # test that drop works + ds = xr.Dataset({"a": (("x",), [1, 2, 3])}, coords={"b": (("x",), [5, 6, 7])}) + + actual = ds.isel({"x": 1}, drop=False) + expected = xr.Dataset({"a": 2}, coords={"b": 6}) + assert_identical(actual, expected) + + actual = ds.isel({"x": 1}, drop=True) + expected = xr.Dataset({"a": 2}) + assert_identical(actual, expected) + + actual = ds.isel({"x": DataArray(1)}, drop=False) + expected = xr.Dataset({"a": 2}, coords={"b": 6}) + assert_identical(actual, expected) + + actual = ds.isel({"x": DataArray(1)}, drop=True) + expected = xr.Dataset({"a": 2}) + assert_identical(actual, expected) + + def test_isel_dataarray(self) -> None: + """Test for indexing by DataArray""" + data = create_test_data() + # indexing with DataArray with same-name coordinates. + indexing_da = DataArray( + np.arange(1, 4), dims=["dim1"], coords={"dim1": np.random.randn(3)} + ) + actual = data.isel(dim1=indexing_da) + assert_identical(indexing_da["dim1"], actual["dim1"]) + assert_identical(data["dim2"], actual["dim2"]) + + # Conflict in the dimension coordinate + indexing_da = DataArray( + np.arange(1, 4), dims=["dim2"], coords={"dim2": np.random.randn(3)} + ) + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): + data.isel(dim2=indexing_da) + # Also the case for DataArray + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): + data["var2"].isel(dim2=indexing_da) + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): + data["dim2"].isel(dim2=indexing_da) + + # same name coordinate which does not conflict + indexing_da = DataArray( + np.arange(1, 4), dims=["dim2"], coords={"dim2": data["dim2"].values[1:4]} + ) + actual = data.isel(dim2=indexing_da) + assert_identical(actual["dim2"], indexing_da["dim2"]) + + # Silently drop conflicted (non-dimensional) coordinate of indexer + indexing_da = DataArray( + np.arange(1, 4), + dims=["dim2"], + coords={ + "dim2": data["dim2"].values[1:4], + "numbers": ("dim2", np.arange(2, 5)), + }, + ) + actual = data.isel(dim2=indexing_da) + assert_identical(actual["numbers"], data["numbers"]) + + # boolean data array with coordinate with the same name + indexing_da = DataArray( + np.arange(1, 10), dims=["dim2"], coords={"dim2": data["dim2"].values} + ) + indexing_da = indexing_da < 3 + actual = data.isel(dim2=indexing_da) + assert_identical(actual["dim2"], data["dim2"][:2]) + + # boolean data array with non-dimensioncoordinate + indexing_da = DataArray( + np.arange(1, 10), + dims=["dim2"], + coords={ + "dim2": data["dim2"].values, + "non_dim": (("dim2",), np.random.randn(9)), + "non_dim2": 0, + }, + ) + indexing_da = indexing_da < 3 + actual = data.isel(dim2=indexing_da) + assert_identical( + actual["dim2"].drop_vars("non_dim").drop_vars("non_dim2"), data["dim2"][:2] + ) + assert_identical(actual["non_dim"], indexing_da["non_dim"][:2]) + assert_identical(actual["non_dim2"], indexing_da["non_dim2"]) + + # non-dimension coordinate will be also attached + indexing_da = DataArray( + np.arange(1, 4), + dims=["dim2"], + coords={"non_dim": (("dim2",), np.random.randn(3))}, + ) + actual = data.isel(dim2=indexing_da) + assert "non_dim" in actual + assert "non_dim" in actual.coords + + # Index by a scalar DataArray + indexing_da = DataArray(3, dims=[], coords={"station": 2}) + actual = data.isel(dim2=indexing_da) + assert "station" in actual + actual = data.isel(dim2=indexing_da["station"]) + assert "station" in actual + + # indexer generated from coordinates + indexing_ds = Dataset({}, coords={"dim2": [0, 1, 2]}) + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): + actual = data.isel(dim2=indexing_ds["dim2"]) + + def test_isel_fancy_convert_index_variable(self) -> None: + # select index variable "x" with a DataArray of dim "z" + # -> drop index and convert index variable to base variable + ds = xr.Dataset({"foo": ("x", [1, 2, 3])}, coords={"x": [0, 1, 2]}) + idxr = xr.DataArray([1], dims="z", name="x") + actual = ds.isel(x=idxr) + assert "x" not in actual.xindexes + assert not isinstance(actual.x.variable, IndexVariable) + + def test_sel(self) -> None: + data = create_test_data() + int_slicers = {"dim1": slice(None, None, 2), "dim2": slice(2), "dim3": slice(3)} + loc_slicers = { + "dim1": slice(None, None, 2), + "dim2": slice(0, 0.5), + "dim3": slice("a", "c"), + } + assert_equal(data.isel(int_slicers), data.sel(loc_slicers)) + data["time"] = ("time", pd.date_range("2000-01-01", periods=20)) + assert_equal(data.isel(time=0), data.sel(time="2000-01-01")) + assert_equal( + data.isel(time=slice(10)), data.sel(time=slice("2000-01-01", "2000-01-10")) + ) + assert_equal(data, data.sel(time=slice("1999", "2005"))) + times = pd.date_range("2000-01-01", periods=3) + assert_equal(data.isel(time=slice(3)), data.sel(time=times)) + assert_equal( + data.isel(time=slice(3)), data.sel(time=(data["time.dayofyear"] <= 3)) + ) + + td = pd.to_timedelta(np.arange(3), unit="days") + data = Dataset({"x": ("td", np.arange(3)), "td": td}) + assert_equal(data, data.sel(td=td)) + assert_equal(data, data.sel(td=slice("3 days"))) + assert_equal(data.isel(td=0), data.sel(td=pd.Timedelta("0 days"))) + assert_equal(data.isel(td=0), data.sel(td=pd.Timedelta("0h"))) + assert_equal(data.isel(td=slice(1, 3)), data.sel(td=slice("1 days", "2 days"))) + + def test_sel_dataarray(self) -> None: + data = create_test_data() + + ind = DataArray([0.0, 0.5, 1.0], dims=["dim2"]) + actual = data.sel(dim2=ind) + assert_equal(actual, data.isel(dim2=[0, 1, 2])) + + # with different dimension + ind = DataArray([0.0, 0.5, 1.0], dims=["new_dim"]) + actual = data.sel(dim2=ind) + expected = data.isel(dim2=Variable("new_dim", [0, 1, 2])) + assert "new_dim" in actual.dims + assert_equal(actual, expected) + + # Multi-dimensional + ind = DataArray([[0.0], [0.5], [1.0]], dims=["new_dim", "new_dim2"]) + actual = data.sel(dim2=ind) + expected = data.isel(dim2=Variable(("new_dim", "new_dim2"), [[0], [1], [2]])) + assert "new_dim" in actual.dims + assert "new_dim2" in actual.dims + assert_equal(actual, expected) + + # with coordinate + ind = DataArray( + [0.0, 0.5, 1.0], dims=["new_dim"], coords={"new_dim": ["a", "b", "c"]} + ) + actual = data.sel(dim2=ind) + expected = data.isel(dim2=[0, 1, 2]).rename({"dim2": "new_dim"}) + assert "new_dim" in actual.dims + assert "new_dim" in actual.coords + assert_equal( + actual.drop_vars("new_dim").drop_vars("dim2"), expected.drop_vars("new_dim") + ) + assert_equal(actual["new_dim"].drop_vars("dim2"), ind["new_dim"]) + + # with conflicted coordinate (silently ignored) + ind = DataArray( + [0.0, 0.5, 1.0], dims=["dim2"], coords={"dim2": ["a", "b", "c"]} + ) + actual = data.sel(dim2=ind) + expected = data.isel(dim2=[0, 1, 2]) + assert_equal(actual, expected) + + # with conflicted coordinate (silently ignored) + ind = DataArray( + [0.0, 0.5, 1.0], + dims=["new_dim"], + coords={"new_dim": ["a", "b", "c"], "dim2": 3}, + ) + actual = data.sel(dim2=ind) + assert_equal( + actual["new_dim"].drop_vars("dim2"), ind["new_dim"].drop_vars("dim2") + ) + expected = data.isel(dim2=[0, 1, 2]) + expected["dim2"] = (("new_dim"), expected["dim2"].values) + assert_equal(actual["dim2"].drop_vars("new_dim"), expected["dim2"]) + assert actual["var1"].dims == ("dim1", "new_dim") + + # with non-dimensional coordinate + ind = DataArray( + [0.0, 0.5, 1.0], + dims=["dim2"], + coords={ + "dim2": ["a", "b", "c"], + "numbers": ("dim2", [0, 1, 2]), + "new_dim": ("dim2", [1.1, 1.2, 1.3]), + }, + ) + actual = data.sel(dim2=ind) + expected = data.isel(dim2=[0, 1, 2]) + assert_equal(actual.drop_vars("new_dim"), expected) + assert np.allclose(actual["new_dim"].values, ind["new_dim"].values) + + def test_sel_dataarray_mindex(self) -> None: + midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + midx_coords["y"] = range(3) + + mds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(6, 3))}, coords=midx_coords + ) + + actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="x")) + actual_sel = mds.sel(x=DataArray(midx[:3], dims="x")) + assert actual_isel["x"].dims == ("x",) + assert actual_sel["x"].dims == ("x",) + assert_identical(actual_isel, actual_sel) + + actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="z")) + actual_sel = mds.sel(x=Variable("z", midx[:3])) + assert actual_isel["x"].dims == ("z",) + assert actual_sel["x"].dims == ("z",) + assert_identical(actual_isel, actual_sel) + + # with coordinate + actual_isel = mds.isel( + x=xr.DataArray(np.arange(3), dims="z", coords={"z": [0, 1, 2]}) + ) + actual_sel = mds.sel( + x=xr.DataArray(midx[:3], dims="z", coords={"z": [0, 1, 2]}) + ) + assert actual_isel["x"].dims == ("z",) + assert actual_sel["x"].dims == ("z",) + assert_identical(actual_isel, actual_sel) + + # Vectorized indexing with level-variables raises an error + with pytest.raises(ValueError, match=r"Vectorized selection is "): + mds.sel(one=["a", "b"]) + + with pytest.raises( + ValueError, + match=r"Vectorized selection is not available along coordinate 'x' with a multi-index", + ): + mds.sel( + x=xr.DataArray( + [np.array(midx[:2]), np.array(midx[-2:])], dims=["a", "b"] + ) + ) + + def test_sel_categorical(self) -> None: + ind = pd.Series(["foo", "bar"], dtype="category") + df = pd.DataFrame({"ind": ind, "values": [1, 2]}) + ds = df.set_index("ind").to_xarray() + actual = ds.sel(ind="bar") + expected = ds.isel(ind=1) + assert_identical(expected, actual) + + def test_sel_categorical_error(self) -> None: + ind = pd.Series(["foo", "bar"], dtype="category") + df = pd.DataFrame({"ind": ind, "values": [1, 2]}) + ds = df.set_index("ind").to_xarray() + with pytest.raises(ValueError): + ds.sel(ind="bar", method="nearest") + with pytest.raises(ValueError): + ds.sel(ind="bar", tolerance="nearest") + + def test_categorical_index(self) -> None: + cat = pd.CategoricalIndex( + ["foo", "bar", "foo"], + categories=["foo", "bar", "baz", "qux", "quux", "corge"], + ) + ds = xr.Dataset( + {"var": ("cat", np.arange(3))}, + coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 1])}, + ) + # test slice + actual1 = ds.sel(cat="foo") + expected1 = ds.isel(cat=[0, 2]) + assert_identical(expected1, actual1) + # make sure the conversion to the array works + actual2 = ds.sel(cat="foo")["cat"].values + assert (actual2 == np.array(["foo", "foo"])).all() + + ds = ds.set_index(index=["cat", "c"]) + actual3 = ds.unstack("index") + assert actual3["var"].shape == (2, 2) + + def test_categorical_reindex(self) -> None: + cat = pd.CategoricalIndex( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux", "quux", "corge"], + ) + ds = xr.Dataset( + {"var": ("cat", np.arange(3))}, + coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 2])}, + ) + actual = ds.reindex(cat=["foo"])["cat"].values + assert (actual == np.array(["foo"])).all() + + def test_categorical_multiindex(self) -> None: + i1 = pd.Series([0, 0]) + cat = pd.CategoricalDtype(categories=["foo", "baz", "bar"]) + i2 = pd.Series(["baz", "bar"], dtype=cat) + + df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2]}).set_index( + ["i1", "i2"] + ) + actual = df.to_xarray() + assert actual["values"].shape == (1, 2) + + def test_sel_drop(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) + expected = Dataset({"foo": 1}) + selected = data.sel(x=0, drop=True) + assert_identical(expected, selected) + + expected = Dataset({"foo": 1}, {"x": 0}) + selected = data.sel(x=0, drop=False) + assert_identical(expected, selected) + + data = Dataset({"foo": ("x", [1, 2, 3])}) + expected = Dataset({"foo": 1}) + selected = data.sel(x=0, drop=True) + assert_identical(expected, selected) + + def test_sel_drop_mindex(self) -> None: + midx = pd.MultiIndex.from_arrays([["a", "a"], [1, 2]], names=("foo", "bar")) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + data = Dataset(coords=midx_coords) + + actual = data.sel(foo="a", drop=True) + assert "foo" not in actual.coords + + actual = data.sel(foo="a", drop=False) + assert_equal(actual.foo, DataArray("a", coords={"foo": "a"})) + + def test_isel_drop(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) + expected = Dataset({"foo": 1}) + selected = data.isel(x=0, drop=True) + assert_identical(expected, selected) + + expected = Dataset({"foo": 1}, {"x": 0}) + selected = data.isel(x=0, drop=False) + assert_identical(expected, selected) + + def test_head(self) -> None: + data = create_test_data() + + expected = data.isel(time=slice(5), dim2=slice(6)) + actual = data.head(time=5, dim2=6) + assert_equal(expected, actual) + + expected = data.isel(time=slice(0)) + actual = data.head(time=0) + assert_equal(expected, actual) + + expected = data.isel({dim: slice(6) for dim in data.dims}) + actual = data.head(6) + assert_equal(expected, actual) + + expected = data.isel({dim: slice(5) for dim in data.dims}) + actual = data.head() + assert_equal(expected, actual) + + with pytest.raises(TypeError, match=r"either dict-like or a single int"): + data.head([3]) # type: ignore[arg-type] + with pytest.raises(TypeError, match=r"expected integer type"): + data.head(dim2=3.1) + with pytest.raises(ValueError, match=r"expected positive int"): + data.head(time=-3) + + def test_tail(self) -> None: + data = create_test_data() + + expected = data.isel(time=slice(-5, None), dim2=slice(-6, None)) + actual = data.tail(time=5, dim2=6) + assert_equal(expected, actual) + + expected = data.isel(dim1=slice(0)) + actual = data.tail(dim1=0) + assert_equal(expected, actual) + + expected = data.isel({dim: slice(-6, None) for dim in data.dims}) + actual = data.tail(6) + assert_equal(expected, actual) + + expected = data.isel({dim: slice(-5, None) for dim in data.dims}) + actual = data.tail() + assert_equal(expected, actual) + + with pytest.raises(TypeError, match=r"either dict-like or a single int"): + data.tail([3]) # type: ignore[arg-type] + with pytest.raises(TypeError, match=r"expected integer type"): + data.tail(dim2=3.1) + with pytest.raises(ValueError, match=r"expected positive int"): + data.tail(time=-3) + + def test_thin(self) -> None: + data = create_test_data() + + expected = data.isel(time=slice(None, None, 5), dim2=slice(None, None, 6)) + actual = data.thin(time=5, dim2=6) + assert_equal(expected, actual) + + expected = data.isel({dim: slice(None, None, 6) for dim in data.dims}) + actual = data.thin(6) + assert_equal(expected, actual) + + with pytest.raises(TypeError, match=r"either dict-like or a single int"): + data.thin([3]) # type: ignore[arg-type] + with pytest.raises(TypeError, match=r"expected integer type"): + data.thin(dim2=3.1) + with pytest.raises(ValueError, match=r"cannot be zero"): + data.thin(time=0) + with pytest.raises(ValueError, match=r"expected positive int"): + data.thin(time=-3) + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") + def test_sel_fancy(self) -> None: + data = create_test_data() + + # add in a range() index + data["dim1"] = data.dim1 + + pdim1 = [1, 2, 3] + pdim2 = [4, 5, 1] + pdim3 = [1, 2, 3] + expected = data.isel( + dim1=Variable(("test_coord",), pdim1), + dim2=Variable(("test_coord",), pdim2), + dim3=Variable(("test_coord"), pdim3), + ) + actual = data.sel( + dim1=Variable(("test_coord",), data.dim1[pdim1]), + dim2=Variable(("test_coord",), data.dim2[pdim2]), + dim3=Variable(("test_coord",), data.dim3[pdim3]), + ) + assert_identical(expected, actual) + + # DataArray Indexer + idx_t = DataArray( + data["time"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + idx_2 = DataArray( + data["dim2"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + idx_3 = DataArray( + data["dim3"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + actual = data.sel(time=idx_t, dim2=idx_2, dim3=idx_3) + expected = data.isel( + time=Variable(("a",), [3, 2, 1]), + dim2=Variable(("a",), [3, 2, 1]), + dim3=Variable(("a",), [3, 2, 1]), + ) + expected = expected.assign_coords(a=idx_t["a"]) + assert_identical(expected, actual) + + idx_t = DataArray( + data["time"][[3, 2, 1]].values, dims=["a"], coords={"a": ["a", "b", "c"]} + ) + idx_2 = DataArray( + data["dim2"][[2, 1, 3]].values, dims=["b"], coords={"b": [0, 1, 2]} + ) + idx_3 = DataArray( + data["dim3"][[1, 2, 1]].values, dims=["c"], coords={"c": [0.0, 1.1, 2.2]} + ) + actual = data.sel(time=idx_t, dim2=idx_2, dim3=idx_3) + expected = data.isel( + time=Variable(("a",), [3, 2, 1]), + dim2=Variable(("b",), [2, 1, 3]), + dim3=Variable(("c",), [1, 2, 1]), + ) + expected = expected.assign_coords(a=idx_t["a"], b=idx_2["b"], c=idx_3["c"]) + assert_identical(expected, actual) + + # test from sel_points + data = Dataset({"foo": (("x", "y"), np.arange(9).reshape(3, 3))}) + data.coords.update({"x": [0, 1, 2], "y": [0, 1, 2]}) + + expected = Dataset( + {"foo": ("points", [0, 4, 8])}, + coords={ + "x": Variable(("points",), [0, 1, 2]), + "y": Variable(("points",), [0, 1, 2]), + }, + ) + actual = data.sel( + x=Variable(("points",), [0, 1, 2]), y=Variable(("points",), [0, 1, 2]) + ) + assert_identical(expected, actual) + + expected.coords.update({"x": ("points", [0, 1, 2]), "y": ("points", [0, 1, 2])}) + actual = data.sel( + x=Variable(("points",), [0.1, 1.1, 2.5]), + y=Variable(("points",), [0, 1.2, 2.0]), + method="pad", + ) + assert_identical(expected, actual) + + idx_x = DataArray([0, 1, 2], dims=["a"], coords={"a": ["a", "b", "c"]}) + idx_y = DataArray([0, 2, 1], dims=["b"], coords={"b": [0, 3, 6]}) + expected_ary = data["foo"][[0, 1, 2], [0, 2, 1]] + actual = data.sel(x=idx_x, y=idx_y) + assert_array_equal(expected_ary, actual["foo"]) + assert_identical(actual["a"].drop_vars("x"), idx_x["a"]) + assert_identical(actual["b"].drop_vars("y"), idx_y["b"]) + + with pytest.raises(KeyError): + data.sel(x=[2.5], y=[2.0], method="pad", tolerance=1e-3) + + def test_sel_method(self) -> None: + data = create_test_data() + + expected = data.sel(dim2=1) + actual = data.sel(dim2=0.95, method="nearest") + assert_identical(expected, actual) + + actual = data.sel(dim2=0.95, method="nearest", tolerance=1) + assert_identical(expected, actual) + + with pytest.raises(KeyError): + actual = data.sel(dim2=np.pi, method="nearest", tolerance=0) + + expected = data.sel(dim2=[1.5]) + actual = data.sel(dim2=[1.45], method="backfill") + assert_identical(expected, actual) + + with pytest.raises(NotImplementedError, match=r"slice objects"): + data.sel(dim2=slice(1, 3), method="ffill") + + with pytest.raises(TypeError, match=r"``method``"): + # this should not pass silently + data.sel(dim2=1, method=data) # type: ignore[arg-type] + + # cannot pass method if there is no associated coordinate + with pytest.raises(ValueError, match=r"cannot supply"): + data.sel(dim1=0, method="nearest") + + def test_loc(self) -> None: + data = create_test_data() + expected = data.sel(dim3="a") + actual = data.loc[dict(dim3="a")] + assert_identical(expected, actual) + with pytest.raises(TypeError, match=r"can only lookup dict"): + data.loc["a"] # type: ignore[index] + + def test_selection_multiindex(self) -> None: + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + mdata = Dataset(data_vars={"var": ("x", range(8))}, coords=midx_coords) + + def test_sel( + lab_indexer, pos_indexer, replaced_idx=False, renamed_dim=None + ) -> None: + ds = mdata.sel(x=lab_indexer) + expected_ds = mdata.isel(x=pos_indexer) + if not replaced_idx: + assert_identical(ds, expected_ds) + else: + if renamed_dim: + assert ds["var"].dims[0] == renamed_dim + ds = ds.rename({renamed_dim: "x"}) + assert_identical(ds["var"].variable, expected_ds["var"].variable) + assert not ds["x"].equals(expected_ds["x"]) + + test_sel(("a", 1, -1), 0) + test_sel(("b", 2, -2), -1) + test_sel(("a", 1), [0, 1], replaced_idx=True, renamed_dim="three") + test_sel(("a",), range(4), replaced_idx=True) + test_sel("a", range(4), replaced_idx=True) + test_sel([("a", 1, -1), ("b", 2, -2)], [0, 7]) + test_sel(slice("a", "b"), range(8)) + test_sel(slice(("a", 1), ("b", 1)), range(6)) + test_sel({"one": "a", "two": 1, "three": -1}, 0) + test_sel({"one": "a", "two": 1}, [0, 1], replaced_idx=True, renamed_dim="three") + test_sel({"one": "a"}, range(4), replaced_idx=True) + + assert_identical(mdata.loc[{"x": {"one": "a"}}], mdata.sel(x={"one": "a"})) + assert_identical(mdata.loc[{"x": "a"}], mdata.sel(x="a")) + assert_identical(mdata.loc[{"x": ("a", 1)}], mdata.sel(x=("a", 1))) + assert_identical(mdata.loc[{"x": ("a", 1, -1)}], mdata.sel(x=("a", 1, -1))) + + assert_identical(mdata.sel(x={"one": "a", "two": 1}), mdata.sel(one="a", two=1)) + + def test_broadcast_like(self) -> None: + original1 = DataArray( + np.random.randn(5), [("x", range(5))], name="a" + ).to_dataset() + + original2 = DataArray(np.random.randn(6), [("y", range(6))], name="b") + + expected1, expected2 = broadcast(original1, original2) + + assert_identical( + original1.broadcast_like(original2), expected1.transpose("y", "x") + ) + + assert_identical(original2.broadcast_like(original1), expected2) + + def test_to_pandas(self) -> None: + # 0D -> series + actual = Dataset({"a": 1, "b": 2}).to_pandas() + expected = pd.Series([1, 2], ["a", "b"]) + assert_array_equal(actual, expected) + + # 1D -> dataframe + x = np.random.randn(10) + y = np.random.randn(10) + t = list("abcdefghij") + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + actual = ds.to_pandas() + expected = ds.to_dataframe() + assert expected.equals(actual), (expected, actual) + + # 2D -> error + x2d = np.random.randn(10, 10) + y2d = np.random.randn(10, 10) + with pytest.raises(ValueError, match=r"cannot convert Datasets"): + Dataset({"a": (["t", "r"], x2d), "b": (["t", "r"], y2d)}).to_pandas() + + def test_reindex_like(self) -> None: + data = create_test_data() + data["letters"] = ("dim3", 10 * ["a"]) + + expected = data.isel(dim1=slice(10), time=slice(13)) + actual = data.reindex_like(expected) + assert_identical(actual, expected) + + expected = data.copy(deep=True) + expected["dim3"] = ("dim3", list("cdefghijkl")) + expected["var3"][:-2] = expected["var3"][2:].values + expected["var3"][-2:] = np.nan + expected["letters"] = expected["letters"].astype(object) + expected["letters"][-2:] = np.nan + expected["numbers"] = expected["numbers"].astype(float) + expected["numbers"][:-2] = expected["numbers"][2:].values + expected["numbers"][-2:] = np.nan + actual = data.reindex_like(expected) + assert_identical(actual, expected) + + def test_reindex(self) -> None: + data = create_test_data() + assert_identical(data, data.reindex()) + + expected = data.assign_coords(dim1=data["dim1"]) + actual = data.reindex(dim1=data["dim1"]) + assert_identical(actual, expected) + + actual = data.reindex(dim1=data["dim1"].values) + assert_identical(actual, expected) + + actual = data.reindex(dim1=data["dim1"].to_index()) + assert_identical(actual, expected) + + with pytest.raises( + ValueError, match=r"cannot reindex or align along dimension" + ): + data.reindex(dim1=data["dim1"][:5]) + + expected = data.isel(dim2=slice(5)) + actual = data.reindex(dim2=data["dim2"][:5]) + assert_identical(actual, expected) + + # test dict-like argument + actual = data.reindex({"dim2": data["dim2"]}) + expected = data + assert_identical(actual, expected) + with pytest.raises(ValueError, match=r"cannot specify both"): + data.reindex({"x": 0}, x=0) + with pytest.raises(ValueError, match=r"dictionary"): + data.reindex("foo") # type: ignore[arg-type] + + # invalid dimension + # TODO: (benbovy - explicit indexes): uncomment? + # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # with pytest.raises(ValueError, match=r"indexer keys.*not correspond.*"): + # data.reindex(invalid=0) + + # out of order + expected = data.sel(dim2=data["dim2"][:5:-1]) + actual = data.reindex(dim2=data["dim2"][:5:-1]) + assert_identical(actual, expected) + + # multiple fill values + expected = data.reindex(dim2=[0.1, 2.1, 3.1, 4.1]).assign( + var1=lambda ds: ds.var1.copy(data=[[-10, -10, -10, -10]] * len(ds.dim1)), + var2=lambda ds: ds.var2.copy(data=[[-20, -20, -20, -20]] * len(ds.dim1)), + ) + actual = data.reindex( + dim2=[0.1, 2.1, 3.1, 4.1], fill_value={"var1": -10, "var2": -20} + ) + assert_identical(actual, expected) + # use the default value + expected = data.reindex(dim2=[0.1, 2.1, 3.1, 4.1]).assign( + var1=lambda ds: ds.var1.copy(data=[[-10, -10, -10, -10]] * len(ds.dim1)), + var2=lambda ds: ds.var2.copy( + data=[[np.nan, np.nan, np.nan, np.nan]] * len(ds.dim1) + ), + ) + actual = data.reindex(dim2=[0.1, 2.1, 3.1, 4.1], fill_value={"var1": -10}) + assert_identical(actual, expected) + + # regression test for #279 + expected = Dataset({"x": ("time", np.random.randn(5))}, {"time": range(5)}) + time2 = DataArray(np.arange(5), dims="time2") + with pytest.raises(ValueError): + actual = expected.reindex(time=time2) + + # another regression test + ds = Dataset( + {"foo": (["x", "y"], np.zeros((3, 4)))}, {"x": range(3), "y": range(4)} + ) + expected = Dataset( + {"foo": (["x", "y"], np.zeros((3, 2)))}, {"x": [0, 1, 3], "y": [0, 1]} + ) + expected["foo"][-1] = np.nan + actual = ds.reindex(x=[0, 1, 3], y=[0, 1]) + assert_identical(expected, actual) + + def test_reindex_attrs_encoding(self) -> None: + ds = Dataset( + {"data": ("x", [1, 2, 3])}, + {"x": ("x", [0, 1, 2], {"foo": "bar"}, {"bar": "baz"})}, + ) + actual = ds.reindex(x=[0, 1]) + expected = Dataset( + {"data": ("x", [1, 2])}, + {"x": ("x", [0, 1], {"foo": "bar"}, {"bar": "baz"})}, + ) + assert_identical(actual, expected) + assert actual.x.encoding == expected.x.encoding + + def test_reindex_warning(self) -> None: + data = create_test_data() + + with pytest.raises(ValueError): + # DataArray with different dimension raises Future warning + ind = xr.DataArray([0.0, 1.0], dims=["new_dim"], name="ind") + data.reindex(dim2=ind) + + # Should not warn + ind = xr.DataArray([0.0, 1.0], dims=["dim2"], name="ind") + with warnings.catch_warnings(record=True) as ws: + data.reindex(dim2=ind) + assert len(ws) == 0 + + def test_reindex_variables_copied(self) -> None: + data = create_test_data() + reindexed_data = data.reindex(copy=False) + for k in data.variables: + assert reindexed_data.variables[k] is not data.variables[k] + + def test_reindex_method(self) -> None: + ds = Dataset({"x": ("y", [10, 20]), "y": [0, 1]}) + y = [-0.5, 0.5, 1.5] + actual = ds.reindex(y=y, method="backfill") + expected = Dataset({"x": ("y", [10, 20, np.nan]), "y": y}) + assert_identical(expected, actual) + + actual = ds.reindex(y=y, method="backfill", tolerance=0.1) + expected = Dataset({"x": ("y", 3 * [np.nan]), "y": y}) + assert_identical(expected, actual) + + actual = ds.reindex(y=y, method="backfill", tolerance=[0.1, 0.5, 0.1]) + expected = Dataset({"x": ("y", [np.nan, 20, np.nan]), "y": y}) + assert_identical(expected, actual) + + actual = ds.reindex(y=[0.1, 0.1, 1], tolerance=[0, 0.1, 0], method="nearest") + expected = Dataset({"x": ("y", [np.nan, 10, 20]), "y": [0.1, 0.1, 1]}) + assert_identical(expected, actual) + + actual = ds.reindex(y=y, method="pad") + expected = Dataset({"x": ("y", [np.nan, 10, 20]), "y": y}) + assert_identical(expected, actual) + + alt = Dataset({"y": y}) + actual = ds.reindex_like(alt, method="pad") + assert_identical(expected, actual) + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"x": 2, "z": 1}]) + def test_reindex_fill_value(self, fill_value) -> None: + ds = Dataset({"x": ("y", [10, 20]), "z": ("y", [-20, -10]), "y": [0, 1]}) + y = [0, 1, 2] + actual = ds.reindex(y=y, fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_x = fill_value_z = np.nan + elif isinstance(fill_value, dict): + fill_value_x = fill_value["x"] + fill_value_z = fill_value["z"] + else: + fill_value_x = fill_value_z = fill_value + expected = Dataset( + { + "x": ("y", [10, 20, fill_value_x]), + "z": ("y", [-20, -10, fill_value_z]), + "y": y, + } + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"x": 2, "z": 1}]) + def test_reindex_like_fill_value(self, fill_value) -> None: + ds = Dataset({"x": ("y", [10, 20]), "z": ("y", [-20, -10]), "y": [0, 1]}) + y = [0, 1, 2] + alt = Dataset({"y": y}) + actual = ds.reindex_like(alt, fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_x = fill_value_z = np.nan + elif isinstance(fill_value, dict): + fill_value_x = fill_value["x"] + fill_value_z = fill_value["z"] + else: + fill_value_x = fill_value_z = fill_value + expected = Dataset( + { + "x": ("y", [10, 20, fill_value_x]), + "z": ("y", [-20, -10, fill_value_z]), + "y": y, + } + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_reindex_str_dtype(self, dtype) -> None: + data = Dataset({"data": ("x", [1, 2]), "x": np.array(["a", "b"], dtype=dtype)}) + + actual = data.reindex(x=data.x) + expected = data + + assert_identical(expected, actual) + assert actual.x.dtype == expected.x.dtype + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": 2, "bar": 1}]) + def test_align_fill_value(self, fill_value) -> None: + x = Dataset({"foo": DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})}) + y = Dataset({"bar": DataArray([1, 2], dims=["x"], coords={"x": [1, 3]})}) + x2, y2 = align(x, y, join="outer", fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_foo = fill_value_bar = np.nan + elif isinstance(fill_value, dict): + fill_value_foo = fill_value["foo"] + fill_value_bar = fill_value["bar"] + else: + fill_value_foo = fill_value_bar = fill_value + + expected_x2 = Dataset( + { + "foo": DataArray( + [1, 2, fill_value_foo], dims=["x"], coords={"x": [1, 2, 3]} + ) + } + ) + expected_y2 = Dataset( + { + "bar": DataArray( + [1, fill_value_bar, 2], dims=["x"], coords={"x": [1, 2, 3]} + ) + } + ) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_align(self) -> None: + left = create_test_data() + right = left.copy(deep=True) + right["dim3"] = ("dim3", list("cdefghijkl")) + right["var3"][:-2] = right["var3"][2:].values + right["var3"][-2:] = np.random.randn(*right["var3"][-2:].shape) + right["numbers"][:-2] = right["numbers"][2:].values + right["numbers"][-2:] = -10 + + intersection = list("cdefghij") + union = list("abcdefghijkl") + + left2, right2 = align(left, right, join="inner") + assert_array_equal(left2["dim3"], intersection) + assert_identical(left2, right2) + + left2, right2 = align(left, right, join="outer") + + assert_array_equal(left2["dim3"], union) + assert_equal(left2["dim3"].variable, right2["dim3"].variable) + + assert_identical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) + assert np.isnan(left2["var3"][-2:]).all() + assert np.isnan(right2["var3"][:2]).all() + + left2, right2 = align(left, right, join="left") + assert_equal(left2["dim3"].variable, right2["dim3"].variable) + assert_equal(left2["dim3"].variable, left["dim3"].variable) + + assert_identical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) + assert np.isnan(right2["var3"][:2]).all() + + left2, right2 = align(left, right, join="right") + assert_equal(left2["dim3"].variable, right2["dim3"].variable) + assert_equal(left2["dim3"].variable, right["dim3"].variable) + + assert_identical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) + + assert np.isnan(left2["var3"][-2:]).all() + + with pytest.raises(ValueError, match=r"invalid value for join"): + align(left, right, join="foobar") # type: ignore[call-overload] + with pytest.raises(TypeError): + align(left, right, foo="bar") # type: ignore[call-overload] + + def test_align_exact(self) -> None: + left = xr.Dataset(coords={"x": [0, 1]}) + right = xr.Dataset(coords={"x": [1, 2]}) + + left1, left2 = xr.align(left, left, join="exact") + assert_identical(left1, left) + assert_identical(left2, left) + + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): + xr.align(left, right, join="exact") + + def test_align_override(self) -> None: + left = xr.Dataset(coords={"x": [0, 1, 2]}) + right = xr.Dataset(coords={"x": [0.1, 1.1, 2.1], "y": [1, 2, 3]}) + expected_right = xr.Dataset(coords={"x": [0, 1, 2], "y": [1, 2, 3]}) + + new_left, new_right = xr.align(left, right, join="override") + assert_identical(left, new_left) + assert_identical(new_right, expected_right) + + new_left, new_right = xr.align(left, right, exclude="x", join="override") + assert_identical(left, new_left) + assert_identical(right, new_right) + + new_left, new_right = xr.align( + left.isel(x=0, drop=True), right, exclude="x", join="override" + ) + assert_identical(left.isel(x=0, drop=True), new_left) + assert_identical(right, new_right) + + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): + xr.align(left.isel(x=0).expand_dims("x"), right, join="override") + + def test_align_exclude(self) -> None: + x = Dataset( + { + "foo": DataArray( + [[1, 2], [3, 4]], dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4]} + ) + } + ) + y = Dataset( + { + "bar": DataArray( + [[1, 2], [3, 4]], dims=["x", "y"], coords={"x": [1, 3], "y": [5, 6]} + ) + } + ) + x2, y2 = align(x, y, exclude=["y"], join="outer") + + expected_x2 = Dataset( + { + "foo": DataArray( + [[1, 2], [3, 4], [np.nan, np.nan]], + dims=["x", "y"], + coords={"x": [1, 2, 3], "y": [3, 4]}, + ) + } + ) + expected_y2 = Dataset( + { + "bar": DataArray( + [[1, 2], [np.nan, np.nan], [3, 4]], + dims=["x", "y"], + coords={"x": [1, 2, 3], "y": [5, 6]}, + ) + } + ) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_align_nocopy(self) -> None: + x = Dataset({"foo": DataArray([1, 2, 3], coords=[("x", [1, 2, 3])])}) + y = Dataset({"foo": DataArray([1, 2], coords=[("x", [1, 2])])}) + expected_x2 = x + expected_y2 = Dataset( + {"foo": DataArray([1, 2, np.nan], coords=[("x", [1, 2, 3])])} + ) + + x2, y2 = align(x, y, copy=False, join="outer") + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + assert source_ndarray(x["foo"].data) is source_ndarray(x2["foo"].data) + + x2, y2 = align(x, y, copy=True, join="outer") + assert source_ndarray(x["foo"].data) is not source_ndarray(x2["foo"].data) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_align_indexes(self) -> None: + x = Dataset({"foo": DataArray([1, 2, 3], dims="x", coords=[("x", [1, 2, 3])])}) + (x2,) = align(x, indexes={"x": [2, 3, 1]}) + expected_x2 = Dataset( + {"foo": DataArray([2, 3, 1], dims="x", coords={"x": [2, 3, 1]})} + ) + + assert_identical(expected_x2, x2) + + def test_align_non_unique(self) -> None: + x = Dataset({"foo": ("x", [3, 4, 5]), "x": [0, 0, 1]}) + x1, x2 = align(x, x) + assert_identical(x1, x) + assert_identical(x2, x) + + y = Dataset({"bar": ("x", [6, 7]), "x": [0, 1]}) + with pytest.raises(ValueError, match=r"cannot reindex or align"): + align(x, y) + + def test_align_str_dtype(self) -> None: + a = Dataset({"foo": ("x", [0, 1])}, coords={"x": ["a", "b"]}) + b = Dataset({"foo": ("x", [1, 2])}, coords={"x": ["b", "c"]}) + + expected_a = Dataset( + {"foo": ("x", [0, 1, np.nan])}, coords={"x": ["a", "b", "c"]} + ) + expected_b = Dataset( + {"foo": ("x", [np.nan, 1, 2])}, coords={"x": ["a", "b", "c"]} + ) + + actual_a, actual_b = xr.align(a, b, join="outer") + + assert_identical(expected_a, actual_a) + assert expected_a.x.dtype == actual_a.x.dtype + + assert_identical(expected_b, actual_b) + assert expected_b.x.dtype == actual_b.x.dtype + + @pytest.mark.parametrize("join", ["left", "override"]) + def test_align_index_var_attrs(self, join) -> None: + # regression test https://github.com/pydata/xarray/issues/6852 + # aligning two objects should have no side effect on their index variable + # metadata. + + ds = Dataset(coords={"x": ("x", [1, 2, 3], {"units": "m"})}) + ds_noattr = Dataset(coords={"x": ("x", [1, 2, 3])}) + + xr.align(ds_noattr, ds, join=join) + + assert ds.x.attrs == {"units": "m"} + assert ds_noattr.x.attrs == {} + + def test_broadcast(self) -> None: + ds = Dataset( + {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])} + ) + expected = Dataset( + { + "foo": (("x", "y"), [[0, 0]]), + "bar": (("x", "y"), [[1, 1]]), + "baz": (("x", "y"), [[2, 3]]), + }, + {"c": ("x", [4])}, + ) + (actual,) = broadcast(ds) + assert_identical(expected, actual) + + ds_x = Dataset({"foo": ("x", [1])}) + ds_y = Dataset({"bar": ("y", [2, 3])}) + expected_x = Dataset({"foo": (("x", "y"), [[1, 1]])}) + expected_y = Dataset({"bar": (("x", "y"), [[2, 3]])}) + actual_x, actual_y = broadcast(ds_x, ds_y) + assert_identical(expected_x, actual_x) + assert_identical(expected_y, actual_y) + + array_y = ds_y["bar"] + expected_y2 = expected_y["bar"] + actual_x2, actual_y2 = broadcast(ds_x, array_y) + assert_identical(expected_x, actual_x2) + assert_identical(expected_y2, actual_y2) + + def test_broadcast_nocopy(self) -> None: + # Test that data is not copied if not needed + x = Dataset({"foo": (("x", "y"), [[1, 1]])}) + y = Dataset({"bar": ("y", [2, 3])}) + + (actual_x,) = broadcast(x) + assert_identical(x, actual_x) + assert source_ndarray(actual_x["foo"].data) is source_ndarray(x["foo"].data) + + actual_x, actual_y = broadcast(x, y) + assert_identical(x, actual_x) + assert source_ndarray(actual_x["foo"].data) is source_ndarray(x["foo"].data) + + def test_broadcast_exclude(self) -> None: + x = Dataset( + { + "foo": DataArray( + [[1, 2], [3, 4]], dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4]} + ), + "bar": DataArray(5), + } + ) + y = Dataset( + { + "foo": DataArray( + [[1, 2]], dims=["z", "y"], coords={"z": [1], "y": [5, 6]} + ) + } + ) + x2, y2 = broadcast(x, y, exclude=["y"]) + + expected_x2 = Dataset( + { + "foo": DataArray( + [[[1, 2]], [[3, 4]]], + dims=["x", "z", "y"], + coords={"z": [1], "x": [1, 2], "y": [3, 4]}, + ), + "bar": DataArray( + [[5], [5]], dims=["x", "z"], coords={"x": [1, 2], "z": [1]} + ), + } + ) + expected_y2 = Dataset( + { + "foo": DataArray( + [[[1, 2]], [[1, 2]]], + dims=["x", "z", "y"], + coords={"z": [1], "x": [1, 2], "y": [5, 6]}, + ) + } + ) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_broadcast_misaligned(self) -> None: + x = Dataset({"foo": DataArray([1, 2, 3], coords=[("x", [-1, -2, -3])])}) + y = Dataset( + { + "bar": DataArray( + [[1, 2], [3, 4]], + dims=["y", "x"], + coords={"y": [1, 2], "x": [10, -3]}, + ) + } + ) + x2, y2 = broadcast(x, y) + expected_x2 = Dataset( + { + "foo": DataArray( + [[3, 3], [2, 2], [1, 1], [np.nan, np.nan]], + dims=["x", "y"], + coords={"y": [1, 2], "x": [-3, -2, -1, 10]}, + ) + } + ) + expected_y2 = Dataset( + { + "bar": DataArray( + [[2, 4], [np.nan, np.nan], [np.nan, np.nan], [1, 3]], + dims=["x", "y"], + coords={"y": [1, 2], "x": [-3, -2, -1, 10]}, + ) + } + ) + assert_identical(expected_x2, x2) + assert_identical(expected_y2, y2) + + def test_broadcast_multi_index(self) -> None: + # GH6430 + ds = Dataset( + {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))}, + {"x": ["a", "b", "c"], "y": [1, 2, 3, 4]}, + ) + stacked = ds.stack(space=["x", "y"]) + broadcasted, _ = broadcast(stacked, stacked.space) + + assert broadcasted.xindexes["x"] is broadcasted.xindexes["space"] + assert broadcasted.xindexes["y"] is broadcasted.xindexes["space"] + + def test_variable_indexing(self) -> None: + data = create_test_data() + v = data["var1"] + d1 = data["dim1"] + d2 = data["dim2"] + assert_equal(v, v[d1.values]) + assert_equal(v, v[d1]) + assert_equal(v[:3], v[d1 < 3]) + assert_equal(v[:, 3:], v[:, d2 >= 1.5]) + assert_equal(v[:3, 3:], v[d1 < 3, d2 >= 1.5]) + assert_equal(v[:3, :2], v[range(3), range(2)]) + assert_equal(v[:3, :2], v.loc[d1[:3], d2[:2]]) + + def test_drop_variables(self) -> None: + data = create_test_data() + + assert_identical(data, data.drop_vars([])) + + expected = Dataset({k: data[k] for k in data.variables if k != "time"}) + actual = data.drop_vars("time") + assert_identical(expected, actual) + actual = data.drop_vars(["time"]) + assert_identical(expected, actual) + + with pytest.raises( + ValueError, + match=re.escape( + "These variables cannot be found in this dataset: ['not_found_here']" + ), + ): + data.drop_vars("not_found_here") + + actual = data.drop_vars("not_found_here", errors="ignore") + assert_identical(data, actual) + + actual = data.drop_vars(["not_found_here"], errors="ignore") + assert_identical(data, actual) + + actual = data.drop_vars(["time", "not_found_here"], errors="ignore") + assert_identical(expected, actual) + + # deprecated approach with `drop` works (straight copy paste from above) + + with pytest.warns(DeprecationWarning): + actual = data.drop("not_found_here", errors="ignore") + assert_identical(data, actual) + + with pytest.warns(DeprecationWarning): + actual = data.drop(["not_found_here"], errors="ignore") + assert_identical(data, actual) + + with pytest.warns(DeprecationWarning): + actual = data.drop(["time", "not_found_here"], errors="ignore") + assert_identical(expected, actual) + + with pytest.warns(DeprecationWarning): + actual = data.drop({"time", "not_found_here"}, errors="ignore") + assert_identical(expected, actual) + + def test_drop_multiindex_level(self) -> None: + data = create_test_multiindex() + expected = data.drop_vars(["x", "level_1", "level_2"]) + with pytest.warns(DeprecationWarning): + actual = data.drop_vars("level_1") + assert_identical(expected, actual) + + def test_drop_index_labels(self) -> None: + data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) + + with pytest.warns(DeprecationWarning): + actual = data.drop(["a"], dim="x") + expected = data.isel(x=[1]) + assert_identical(expected, actual) + + with pytest.warns(DeprecationWarning): + actual = data.drop(["a", "b"], dim="x") + expected = data.isel(x=slice(0, 0)) + assert_identical(expected, actual) + + with pytest.raises(KeyError): + # not contained in axis + with pytest.warns(DeprecationWarning): + data.drop(["c"], dim="x") + + with pytest.warns(DeprecationWarning): + actual = data.drop(["c"], dim="x", errors="ignore") + assert_identical(data, actual) + + with pytest.raises(ValueError): + data.drop(["c"], dim="x", errors="wrong_value") # type: ignore[arg-type] + + with pytest.warns(DeprecationWarning): + actual = data.drop(["a", "b", "c"], "x", errors="ignore") + expected = data.isel(x=slice(0, 0)) + assert_identical(expected, actual) + + # DataArrays as labels are a nasty corner case as they are not + # Iterable[Hashable] - DataArray.__iter__ yields scalar DataArrays. + actual = data.drop_sel(x=DataArray(["a", "b", "c"]), errors="ignore") + expected = data.isel(x=slice(0, 0)) + assert_identical(expected, actual) + with pytest.warns(DeprecationWarning): + data.drop(DataArray(["a", "b", "c"]), dim="x", errors="ignore") + assert_identical(expected, actual) + + actual = data.drop_sel(y=[1]) + expected = data.isel(y=[0, 2]) + assert_identical(expected, actual) + + with pytest.raises(KeyError, match=r"not found in axis"): + data.drop_sel(x=0) + + def test_drop_labels_by_keyword(self) -> None: + data = Dataset( + {"A": (["x", "y"], np.random.randn(2, 6)), "x": ["a", "b"], "y": range(6)} + ) + # Basic functionality. + assert len(data.coords["x"]) == 2 + + with pytest.warns(DeprecationWarning): + ds1 = data.drop(["a"], dim="x") + ds2 = data.drop_sel(x="a") + ds3 = data.drop_sel(x=["a"]) + ds4 = data.drop_sel(x=["a", "b"]) + ds5 = data.drop_sel(x=["a", "b"], y=range(0, 6, 2)) + + arr = DataArray(range(3), dims=["c"]) + with pytest.warns(DeprecationWarning): + data.drop(arr.coords) + with pytest.warns(DeprecationWarning): + data.drop(arr.xindexes) + + assert_array_equal(ds1.coords["x"], ["b"]) + assert_array_equal(ds2.coords["x"], ["b"]) + assert_array_equal(ds3.coords["x"], ["b"]) + assert ds4.coords["x"].size == 0 + assert ds5.coords["x"].size == 0 + assert_array_equal(ds5.coords["y"], [1, 3, 5]) + + # Error handling if user tries both approaches. + with pytest.raises(ValueError): + data.drop(labels=["a"], x="a") + with pytest.raises(ValueError): + data.drop(labels=["a"], dim="x", x="a") + warnings.filterwarnings("ignore", r"\W*drop") + with pytest.raises(ValueError): + data.drop(dim="x", x="a") + + def test_drop_labels_by_position(self) -> None: + data = Dataset( + {"A": (["x", "y"], np.random.randn(2, 6)), "x": ["a", "b"], "y": range(6)} + ) + # Basic functionality. + assert len(data.coords["x"]) == 2 + + actual = data.drop_isel(x=0) + expected = data.drop_sel(x="a") + assert_identical(expected, actual) + + actual = data.drop_isel(x=[0]) + expected = data.drop_sel(x=["a"]) + assert_identical(expected, actual) + + actual = data.drop_isel(x=[0, 1]) + expected = data.drop_sel(x=["a", "b"]) + assert_identical(expected, actual) + assert actual.coords["x"].size == 0 + + actual = data.drop_isel(x=[0, 1], y=range(0, 6, 2)) + expected = data.drop_sel(x=["a", "b"], y=range(0, 6, 2)) + assert_identical(expected, actual) + assert actual.coords["x"].size == 0 + + with pytest.raises(KeyError): + data.drop_isel(z=1) + + def test_drop_indexes(self) -> None: + ds = Dataset( + coords={ + "x": ("x", [0, 1, 2]), + "y": ("y", [3, 4, 5]), + "foo": ("x", ["a", "a", "b"]), + } + ) + + actual = ds.drop_indexes("x") + assert "x" not in actual.xindexes + assert type(actual.x.variable) is Variable + + actual = ds.drop_indexes(["x", "y"]) + assert "x" not in actual.xindexes + assert "y" not in actual.xindexes + assert type(actual.x.variable) is Variable + assert type(actual.y.variable) is Variable + + with pytest.raises( + ValueError, + match=r"The coordinates \('not_a_coord',\) are not found in the dataset coordinates", + ): + ds.drop_indexes("not_a_coord") + + with pytest.raises(ValueError, match="those coordinates do not have an index"): + ds.drop_indexes("foo") + + actual = ds.drop_indexes(["foo", "not_a_coord"], errors="ignore") + assert_identical(actual, ds) + + # test index corrupted + midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = Dataset(coords=midx_coords) + + with pytest.raises(ValueError, match=".*would corrupt the following index.*"): + ds.drop_indexes("a") + + def test_drop_dims(self) -> None: + data = xr.Dataset( + { + "A": (["x", "y"], np.random.randn(2, 3)), + "B": ("x", np.random.randn(2)), + "x": ["a", "b"], + "z": np.pi, + } + ) + + actual = data.drop_dims("x") + expected = data.drop_vars(["A", "B", "x"]) + assert_identical(expected, actual) + + actual = data.drop_dims("y") + expected = data.drop_vars("A") + assert_identical(expected, actual) + + actual = data.drop_dims(["x", "y"]) + expected = data.drop_vars(["A", "B", "x"]) + assert_identical(expected, actual) + + with pytest.raises((ValueError, KeyError)): + data.drop_dims("z") # not a dimension + + with pytest.raises((ValueError, KeyError)): + data.drop_dims(None) # type:ignore[arg-type] + + actual = data.drop_dims("z", errors="ignore") + assert_identical(data, actual) + + # should this be allowed? + actual = data.drop_dims(None, errors="ignore") # type:ignore[arg-type] + assert_identical(data, actual) + + with pytest.raises(ValueError): + actual = data.drop_dims("z", errors="wrong_value") # type: ignore[arg-type] + + actual = data.drop_dims(["x", "y", "z"], errors="ignore") + expected = data.drop_vars(["A", "B", "x"]) + assert_identical(expected, actual) + + def test_copy(self) -> None: + data = create_test_data() + data.attrs["Test"] = [1, 2, 3] + + for copied in [data.copy(deep=False), copy(data)]: + assert_identical(data, copied) + assert data.encoding == copied.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.indexes.safe_cast_to_index. + # Limiting the test to data variables. + for k in data.data_vars: + v0 = data.variables[k] + v1 = copied.variables[k] + assert source_ndarray(v0.data) is source_ndarray(v1.data) + copied["foo"] = ("z", np.arange(5)) + assert "foo" not in data + + copied.attrs["foo"] = "bar" + assert "foo" not in data.attrs + assert data.attrs["Test"] is copied.attrs["Test"] + + for copied in [data.copy(deep=True), deepcopy(data)]: + assert_identical(data, copied) + for k, v0 in data.variables.items(): + v1 = copied.variables[k] + assert v0 is not v1 + + assert data.attrs["Test"] is not copied.attrs["Test"] + + def test_copy_with_data(self) -> None: + orig = create_test_data() + new_data = {k: np.random.randn(*v.shape) for k, v in orig.data_vars.items()} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + assert_identical(expected, actual) + + @pytest.mark.xfail(raises=AssertionError) + @pytest.mark.parametrize( + "deep, expected_orig", + [ + [ + True, + xr.DataArray( + xr.IndexVariable("a", np.array([1, 2])), + coords={"a": [1, 2]}, + dims=["a"], + ), + ], + [ + False, + xr.DataArray( + xr.IndexVariable("a", np.array([999, 2])), + coords={"a": [999, 2]}, + dims=["a"], + ), + ], + ], + ) + def test_copy_coords(self, deep, expected_orig) -> None: + """The test fails for the shallow copy, and apparently only on Windows + for some reason. In windows coords seem to be immutable unless it's one + dataset deep copied from another.""" + ds = xr.DataArray( + np.ones([2, 2, 2]), + coords={"a": [1, 2], "b": ["x", "y"], "c": [0, 1]}, + dims=["a", "b", "c"], + name="value", + ).to_dataset() + ds_cp = ds.copy(deep=deep) + new_a = np.array([999, 2]) + ds_cp.coords["a"] = ds_cp.a.copy(data=new_a) + + expected_cp = xr.DataArray( + xr.IndexVariable("a", new_a), + coords={"a": [999, 2]}, + dims=["a"], + ) + assert_identical(ds_cp.coords["a"], expected_cp) + + assert_identical(ds.coords["a"], expected_orig) + + def test_copy_with_data_errors(self) -> None: + orig = create_test_data() + new_var1 = np.arange(orig["var1"].size).reshape(orig["var1"].shape) + with pytest.raises(ValueError, match=r"Data must be dict-like"): + orig.copy(data=new_var1) # type: ignore[arg-type] + with pytest.raises(ValueError, match=r"only contain variables in original"): + orig.copy(data={"not_in_original": new_var1}) + with pytest.raises(ValueError, match=r"contain all variables in original"): + orig.copy(data={"var1": new_var1}) + + def test_drop_encoding(self) -> None: + orig = create_test_data() + vencoding = {"scale_factor": 10} + orig.encoding = {"foo": "bar"} + + for k, v in orig.variables.items(): + orig[k].encoding = vencoding + + actual = orig.drop_encoding() + assert actual.encoding == {} + for k, v in actual.variables.items(): + assert v.encoding == {} + + assert_equal(actual, orig) + + def test_rename(self) -> None: + data = create_test_data() + newnames = { + "var1": "renamed_var1", + "dim2": "renamed_dim2", + } + renamed = data.rename(newnames) + + variables = dict(data.variables) + for nk, nv in newnames.items(): + variables[nv] = variables.pop(nk) + + for k, v in variables.items(): + dims = list(v.dims) + for name, newname in newnames.items(): + if name in dims: + dims[dims.index(name)] = newname + + assert_equal( + Variable(dims, v.values, v.attrs), + renamed[k].variable.to_base_variable(), + ) + assert v.encoding == renamed[k].encoding + assert type(v) is type(renamed.variables[k]) # noqa: E721 + + assert "var1" not in renamed + assert "dim2" not in renamed + + with pytest.raises(ValueError, match=r"cannot rename 'not_a_var'"): + data.rename({"not_a_var": "nada"}) + + with pytest.raises(ValueError, match=r"'var1' conflicts"): + data.rename({"var2": "var1"}) + + # verify that we can rename a variable without accessing the data + var1 = data["var1"] + data["var1"] = (var1.dims, InaccessibleArray(var1.values)) + renamed = data.rename(newnames) + with pytest.raises(UnexpectedDataAccess): + renamed["renamed_var1"].values + + # https://github.com/python/mypy/issues/10008 + renamed_kwargs = data.rename(**newnames) # type: ignore[arg-type] + assert_identical(renamed, renamed_kwargs) + + def test_rename_old_name(self) -> None: + # regtest for GH1477 + data = create_test_data() + + with pytest.raises(ValueError, match=r"'samecol' conflicts"): + data.rename({"var1": "samecol", "var2": "samecol"}) + + # This shouldn't cause any problems. + data.rename({"var1": "var2", "var2": "var1"}) + + def test_rename_same_name(self) -> None: + data = create_test_data() + newnames = {"var1": "var1", "dim2": "dim2"} + renamed = data.rename(newnames) + assert_identical(renamed, data) + + def test_rename_dims(self) -> None: + original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) + expected = Dataset( + {"x": ("x_new", [0, 1, 2]), "y": ("x_new", [10, 11, 12]), "z": 42} + ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # setting index for non-dimension variables + expected = expected.set_coords("x") + actual = original.rename_dims({"x": "x_new"}) + assert_identical(expected, actual, check_default_indexes=False) + actual_2 = original.rename_dims(x="x_new") + assert_identical(expected, actual_2, check_default_indexes=False) + + # Test to raise ValueError + dims_dict_bad = {"x_bad": "x_new"} + with pytest.raises(ValueError): + original.rename_dims(dims_dict_bad) + + with pytest.raises(ValueError): + original.rename_dims({"x": "z"}) + + def test_rename_vars(self) -> None: + original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) + expected = Dataset( + {"x_new": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42} + ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # setting index for non-dimension variables + expected = expected.set_coords("x_new") + actual = original.rename_vars({"x": "x_new"}) + assert_identical(expected, actual, check_default_indexes=False) + actual_2 = original.rename_vars(x="x_new") + assert_identical(expected, actual_2, check_default_indexes=False) + + # Test to raise ValueError + names_dict_bad = {"x_bad": "x_new"} + with pytest.raises(ValueError): + original.rename_vars(names_dict_bad) + + def test_rename_dimension_coord(self) -> None: + # rename a dimension corodinate to a non-dimension coordinate + # should preserve index + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + + actual = original.rename_vars({"x": "x_new"}) + assert "x_new" in actual.xindexes + + actual_2 = original.rename_dims({"x": "x_new"}) + assert "x" in actual_2.xindexes + + def test_rename_dimension_coord_warnings(self) -> None: + # create a dimension coordinate by renaming a dimension or coordinate + # should raise a warning (no index created) + ds = Dataset(coords={"x": ("y", [0, 1])}) + + with pytest.warns( + UserWarning, match="rename 'x' to 'y' does not create an index.*" + ): + ds.rename(x="y") + + ds = Dataset(coords={"y": ("x", [0, 1])}) + + with pytest.warns( + UserWarning, match="rename 'x' to 'y' does not create an index.*" + ): + ds.rename(x="y") + + # No operation should not raise a warning + ds = Dataset( + data_vars={"data": (("x", "y"), np.ones((2, 3)))}, + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ds.rename(x="x") + + def test_rename_multiindex(self) -> None: + midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + original = Dataset({}, midx_coords) + + midx_renamed = midx.rename(["a", "c"]) + midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x") + expected = Dataset({}, midx_coords_renamed) + + actual = original.rename({"b": "c"}) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"'a' conflicts"): + with pytest.warns(UserWarning, match="does not create an index anymore"): + original.rename({"x": "a"}) + + with pytest.raises(ValueError, match=r"'x' conflicts"): + with pytest.warns(UserWarning, match="does not create an index anymore"): + original.rename({"a": "x"}) + + with pytest.raises(ValueError, match=r"'b' conflicts"): + original.rename({"a": "b"}) + + def test_rename_perserve_attrs_encoding(self) -> None: + # test propagate attrs/encoding to new variable(s) created from Index object + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + expected = Dataset(coords={"y": ("y", [0, 1, 2])}) + for ds, dim in zip([original, expected], ["x", "y"]): + ds[dim].attrs = {"foo": "bar"} + ds[dim].encoding = {"foo": "bar"} + + actual = original.rename({"x": "y"}) + assert_identical(actual, expected) + + @requires_cftime + def test_rename_does_not_change_CFTimeIndex_type(self) -> None: + # make sure CFTimeIndex is not converted to DatetimeIndex #3522 + + time = xr.cftime_range(start="2000", periods=6, freq="2MS", calendar="noleap") + orig = Dataset(coords={"time": time}) + + renamed = orig.rename(time="time_new") + assert "time_new" in renamed.xindexes + # TODO: benbovy - flexible indexes: update when CFTimeIndex + # inherits from xarray.Index + assert isinstance(renamed.xindexes["time_new"].to_pandas_index(), CFTimeIndex) + assert renamed.xindexes["time_new"].to_pandas_index().name == "time_new" + + # check original has not changed + assert "time" in orig.xindexes + assert isinstance(orig.xindexes["time"].to_pandas_index(), CFTimeIndex) + assert orig.xindexes["time"].to_pandas_index().name == "time" + + # note: rename_dims(time="time_new") drops "ds.indexes" + renamed = orig.rename_dims() + assert isinstance(renamed.xindexes["time"].to_pandas_index(), CFTimeIndex) + + renamed = orig.rename_vars() + assert isinstance(renamed.xindexes["time"].to_pandas_index(), CFTimeIndex) + + def test_rename_does_not_change_DatetimeIndex_type(self) -> None: + # make sure DatetimeIndex is conderved on rename + + time = pd.date_range(start="2000", periods=6, freq="2MS") + orig = Dataset(coords={"time": time}) + + renamed = orig.rename(time="time_new") + assert "time_new" in renamed.xindexes + # TODO: benbovy - flexible indexes: update when DatetimeIndex + # inherits from xarray.Index? + assert isinstance(renamed.xindexes["time_new"].to_pandas_index(), DatetimeIndex) + assert renamed.xindexes["time_new"].to_pandas_index().name == "time_new" + + # check original has not changed + assert "time" in orig.xindexes + assert isinstance(orig.xindexes["time"].to_pandas_index(), DatetimeIndex) + assert orig.xindexes["time"].to_pandas_index().name == "time" + + # note: rename_dims(time="time_new") drops "ds.indexes" + renamed = orig.rename_dims() + assert isinstance(renamed.xindexes["time"].to_pandas_index(), DatetimeIndex) + + renamed = orig.rename_vars() + assert isinstance(renamed.xindexes["time"].to_pandas_index(), DatetimeIndex) + + def test_swap_dims(self) -> None: + original = Dataset({"x": [1, 2, 3], "y": ("x", list("abc")), "z": 42}) + expected = Dataset({"z": 42}, {"x": ("y", [1, 2, 3]), "y": list("abc")}) + actual = original.swap_dims({"x": "y"}) + assert_identical(expected, actual) + assert isinstance(actual.variables["y"], IndexVariable) + assert isinstance(actual.variables["x"], Variable) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) + + roundtripped = actual.swap_dims({"y": "x"}) + assert_identical(original.set_coords("y"), roundtripped) + + with pytest.raises(ValueError, match=r"cannot swap"): + original.swap_dims({"y": "x"}) + with pytest.raises(ValueError, match=r"replacement dimension"): + original.swap_dims({"x": "z"}) + + expected = Dataset( + {"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])} + ) + actual = original.swap_dims({"x": "u"}) + assert_identical(expected, actual) + + # as kwargs + expected = Dataset( + {"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])} + ) + actual = original.swap_dims(x="u") + assert_identical(expected, actual) + + # handle multiindex case + midx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) + + original = Dataset({"x": [1, 2, 3], "y": ("x", midx), "z": 42}) + + midx_coords = Coordinates.from_pandas_multiindex(midx, "y") + midx_coords["x"] = ("y", [1, 2, 3]) + expected = Dataset({"z": 42}, midx_coords) + + actual = original.swap_dims({"x": "y"}) + assert_identical(expected, actual) + assert isinstance(actual.variables["y"], IndexVariable) + assert isinstance(actual.variables["x"], Variable) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) + + def test_expand_dims_error(self) -> None: + original = Dataset( + { + "x": ("a", np.random.randn(3)), + "y": (["b", "a"], np.random.randn(4, 3)), + "z": ("a", np.random.randn(3)), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + + with pytest.raises(ValueError, match=r"already exists"): + original.expand_dims(dim=["x"]) + + # Make sure it raises true error also for non-dimensional coordinates + # which has dimension. + original = original.set_coords("z") + with pytest.raises(ValueError, match=r"already exists"): + original.expand_dims(dim=["z"]) + + original = Dataset( + { + "x": ("a", np.random.randn(3)), + "y": (["b", "a"], np.random.randn(4, 3)), + "z": ("a", np.random.randn(3)), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + with pytest.raises(TypeError, match=r"value of new dimension"): + original.expand_dims({"d": 3.2}) + with pytest.raises(ValueError, match=r"both keyword and positional"): + original.expand_dims({"d": 4}, e=4) + + def test_expand_dims_int(self) -> None: + original = Dataset( + {"x": ("a", np.random.randn(3)), "y": (["b", "a"], np.random.randn(4, 3))}, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + + actual = original.expand_dims(["z"], [1]) + expected = Dataset( + { + "x": original["x"].expand_dims("z", 1), + "y": original["y"].expand_dims("z", 1), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + # make sure squeeze restores the original data set. + roundtripped = actual.squeeze("z") + assert_identical(original, roundtripped) + + # another test with a negative axis + actual = original.expand_dims(["z"], [-1]) + expected = Dataset( + { + "x": original["x"].expand_dims("z", -1), + "y": original["y"].expand_dims("z", -1), + }, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + assert_identical(expected, actual) + # make sure squeeze restores the original data set. + roundtripped = actual.squeeze("z") + assert_identical(original, roundtripped) + + def test_expand_dims_coords(self) -> None: + original = Dataset({"x": ("a", np.array([1, 2, 3]))}) + expected = Dataset( + {"x": (("b", "a"), np.array([[1, 2, 3], [1, 2, 3]]))}, coords={"b": [1, 2]} + ) + actual = original.expand_dims(dict(b=[1, 2])) + assert_identical(expected, actual) + assert "b" not in original._coord_names + + def test_expand_dims_existing_scalar_coord(self) -> None: + original = Dataset({"x": 1}, {"a": 2}) + expected = Dataset({"x": (("a",), [1])}, {"a": [2]}) + actual = original.expand_dims("a") + assert_identical(expected, actual) + + def test_isel_expand_dims_roundtrip(self) -> None: + original = Dataset({"x": (("a",), [1])}, {"a": [2]}) + actual = original.isel(a=0).expand_dims("a") + assert_identical(actual, original) + + def test_expand_dims_mixed_int_and_coords(self) -> None: + # Test expanding one dimension to have size > 1 that doesn't have + # coordinates, and also expanding another dimension to have size > 1 + # that DOES have coordinates. + original = Dataset( + {"x": ("a", np.random.randn(3)), "y": (["b", "a"], np.random.randn(4, 3))}, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + ) + + actual = original.expand_dims({"d": 4, "e": ["l", "m", "n"]}) + + expected = Dataset( + { + "x": xr.DataArray( + original["x"].values * np.ones([4, 3, 3]), + coords=dict(d=range(4), e=["l", "m", "n"], a=np.linspace(0, 1, 3)), + dims=["d", "e", "a"], + ).drop_vars("d"), + "y": xr.DataArray( + original["y"].values * np.ones([4, 3, 4, 3]), + coords=dict( + d=range(4), + e=["l", "m", "n"], + b=np.linspace(0, 1, 4), + a=np.linspace(0, 1, 3), + ), + dims=["d", "e", "b", "a"], + ).drop_vars("d"), + }, + coords={"c": np.linspace(0, 1, 5)}, + ) + assert_identical(actual, expected) + + def test_expand_dims_kwargs_python36plus(self) -> None: + original = Dataset( + {"x": ("a", np.random.randn(3)), "y": (["b", "a"], np.random.randn(4, 3))}, + coords={ + "a": np.linspace(0, 1, 3), + "b": np.linspace(0, 1, 4), + "c": np.linspace(0, 1, 5), + }, + attrs={"key": "entry"}, + ) + other_way = original.expand_dims(e=["l", "m", "n"]) + other_way_expected = Dataset( + { + "x": xr.DataArray( + original["x"].values * np.ones([3, 3]), + coords=dict(e=["l", "m", "n"], a=np.linspace(0, 1, 3)), + dims=["e", "a"], + ), + "y": xr.DataArray( + original["y"].values * np.ones([3, 4, 3]), + coords=dict( + e=["l", "m", "n"], + b=np.linspace(0, 1, 4), + a=np.linspace(0, 1, 3), + ), + dims=["e", "b", "a"], + ), + }, + coords={"c": np.linspace(0, 1, 5)}, + attrs={"key": "entry"}, + ) + assert_identical(other_way_expected, other_way) + + @pytest.mark.parametrize("create_index_for_new_dim_flag", [True, False]) + def test_expand_dims_create_index_data_variable( + self, create_index_for_new_dim_flag + ): + # data variables should not gain an index ever + ds = Dataset({"x": 0}) + + if create_index_for_new_dim_flag: + with pytest.warns(UserWarning, match="No index created"): + expanded = ds.expand_dims( + "x", create_index_for_new_dim=create_index_for_new_dim_flag + ) + else: + expanded = ds.expand_dims( + "x", create_index_for_new_dim=create_index_for_new_dim_flag + ) + + # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 + expected = Dataset({"x": ("x", [0])}).drop_indexes("x").reset_coords("x") + + assert_identical(expanded, expected, check_default_indexes=False) + assert expanded.indexes == {} + + def test_expand_dims_create_index_coordinate_variable(self): + # coordinate variables should gain an index only if create_index_for_new_dim is True (the default) + ds = Dataset(coords={"x": 0}) + expanded = ds.expand_dims("x") + expected = Dataset({"x": ("x", [0])}) + assert_identical(expanded, expected) + + expanded_no_index = ds.expand_dims("x", create_index_for_new_dim=False) + + # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 + expected = Dataset(coords={"x": ("x", [0])}).drop_indexes("x") + + assert_identical(expanded_no_index, expected, check_default_indexes=False) + assert expanded_no_index.indexes == {} + + def test_expand_dims_create_index_from_iterable(self): + ds = Dataset(coords={"x": 0}) + expanded = ds.expand_dims(x=[0, 1]) + expected = Dataset({"x": ("x", [0, 1])}) + assert_identical(expanded, expected) + + expanded_no_index = ds.expand_dims(x=[0, 1], create_index_for_new_dim=False) + + # TODO Can't just create the expected dataset directly using constructor because of GH issue 8959 + expected = Dataset(coords={"x": ("x", [0, 1])}).drop_indexes("x") + + assert_identical(expanded, expected, check_default_indexes=False) + assert expanded_no_index.indexes == {} + + def test_expand_dims_non_nanosecond_conversion(self) -> None: + # Regression test for https://github.com/pydata/xarray/issues/7493#issuecomment-1953091000 + with pytest.warns(UserWarning, match="non-nanosecond precision"): + ds = Dataset().expand_dims({"time": [np.datetime64("2018-01-01", "s")]}) + assert ds.time.dtype == np.dtype("datetime64[ns]") + + def test_set_index(self) -> None: + expected = create_test_multiindex() + mindex = expected["x"].to_index() + indexes = [mindex.get_level_values(n) for n in mindex.names] + coords = {idx.name: ("x", idx) for idx in indexes} + ds = Dataset({}, coords=coords) + + obj = ds.set_index(x=mindex.names) + assert_identical(obj, expected) + + # ensure pre-existing indexes involved are removed + # (level_2 should be a coordinate with no index) + ds = create_test_multiindex() + coords = {"x": coords["level_1"], "level_2": coords["level_2"]} + expected = Dataset({}, coords=coords) + + obj = ds.set_index(x="level_1") + assert_identical(obj, expected) + + # ensure set_index with no existing index and a single data var given + # doesn't return multi-index + ds = Dataset(data_vars={"x_var": ("x", [0, 1, 2])}) + expected = Dataset(coords={"x": [0, 1, 2]}) + assert_identical(ds.set_index(x="x_var"), expected) + + with pytest.raises(ValueError, match=r"bar variable\(s\) do not exist"): + ds.set_index(foo="bar") + + with pytest.raises(ValueError, match=r"dimension mismatch.*"): + ds.set_index(y="x_var") + + ds = Dataset(coords={"x": 1}) + with pytest.raises( + ValueError, match=r".*cannot set a PandasIndex.*scalar variable.*" + ): + ds.set_index(x="x") + + def test_set_index_deindexed_coords(self) -> None: + # test de-indexed coordinates are converted to base variable + # https://github.com/pydata/xarray/issues/6969 + one = ["a", "a", "b", "b"] + two = [1, 2, 1, 2] + three = ["c", "c", "d", "d"] + four = [3, 4, 3, 4] + + midx_12 = pd.MultiIndex.from_arrays([one, two], names=["one", "two"]) + midx_34 = pd.MultiIndex.from_arrays([three, four], names=["three", "four"]) + + coords = Coordinates.from_pandas_multiindex(midx_12, "x") + coords["three"] = ("x", three) + coords["four"] = ("x", four) + ds = xr.Dataset(coords=coords) + actual = ds.set_index(x=["three", "four"]) + + coords_expected = Coordinates.from_pandas_multiindex(midx_34, "x") + coords_expected["one"] = ("x", one) + coords_expected["two"] = ("x", two) + expected = xr.Dataset(coords=coords_expected) + + assert_identical(actual, expected) + + def test_reset_index(self) -> None: + ds = create_test_multiindex() + mindex = ds["x"].to_index() + indexes = [mindex.get_level_values(n) for n in mindex.names] + coords = {idx.name: ("x", idx) for idx in indexes} + expected = Dataset({}, coords=coords) + + obj = ds.reset_index("x") + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 + + ds = Dataset(coords={"y": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match=r".*not coordinates with an index"): + ds.reset_index("y") + + def test_reset_index_keep_attrs(self) -> None: + coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) + ds = Dataset({}, {"coord_1": coord_1}) + obj = ds.reset_index("coord_1") + assert ds.coord_1.attrs == obj.coord_1.attrs + assert len(obj.xindexes) == 0 + + def test_reset_index_drop_dims(self) -> None: + ds = Dataset(coords={"x": [1, 2]}) + reset = ds.reset_index("x", drop=True) + assert len(reset.dims) == 0 + + @pytest.mark.parametrize( + ["arg", "drop", "dropped", "converted", "renamed"], + [ + ("foo", False, [], [], {"bar": "x"}), + ("foo", True, ["foo"], [], {"bar": "x"}), + ("x", False, ["x"], ["foo", "bar"], {}), + ("x", True, ["x", "foo", "bar"], [], {}), + (["foo", "bar"], False, ["x"], ["foo", "bar"], {}), + (["foo", "bar"], True, ["x", "foo", "bar"], [], {}), + (["x", "foo"], False, ["x"], ["foo", "bar"], {}), + (["foo", "x"], True, ["x", "foo", "bar"], [], {}), + ], + ) + def test_reset_index_drop_convert( + self, + arg: str | list[str], + drop: bool, + dropped: list[str], + converted: list[str], + renamed: dict[str, str], + ) -> None: + # regressions https://github.com/pydata/xarray/issues/6946 and + # https://github.com/pydata/xarray/issues/6989 + # check that multi-index dimension or level coordinates are dropped, converted + # from IndexVariable to Variable or renamed to dimension as expected + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("foo", "bar")) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = xr.Dataset(coords=midx_coords) + reset = ds.reset_index(arg, drop=drop) + + for name in dropped: + assert name not in reset.variables + for name in converted: + assert_identical(reset[name].variable, ds[name].variable.to_base_variable()) + for old_name, new_name in renamed.items(): + assert_identical(ds[old_name].variable, reset[new_name].variable) + + def test_reorder_levels(self) -> None: + ds = create_test_multiindex() + mindex = ds["x"].to_index() + midx = mindex.reorder_levels(["level_2", "level_1"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + expected = Dataset({}, coords=midx_coords) + + # check attrs propagated + ds["level_1"].attrs["foo"] = "bar" + expected["level_1"].attrs["foo"] = "bar" + + reindexed = ds.reorder_levels(x=["level_2", "level_1"]) + assert_identical(reindexed, expected) + + ds = Dataset({}, coords={"x": [1, 2]}) + with pytest.raises(ValueError, match=r"has no MultiIndex"): + ds.reorder_levels(x=["level_1", "level_2"]) + + def test_set_xindex(self) -> None: + ds = Dataset( + coords={"foo": ("x", ["a", "a", "b", "b"]), "bar": ("x", [0, 1, 2, 3])} + ) + + actual = ds.set_xindex("foo") + expected = ds.set_index(x="foo").rename_vars(x="foo") + assert_identical(actual, expected, check_default_indexes=False) + + actual_mindex = ds.set_xindex(["foo", "bar"]) + expected_mindex = ds.set_index(x=["foo", "bar"]) + assert_identical(actual_mindex, expected_mindex) + + class NotAnIndex: ... + + with pytest.raises(TypeError, match=".*not a subclass of xarray.Index"): + ds.set_xindex("foo", NotAnIndex) # type: ignore + + with pytest.raises(ValueError, match="those variables don't exist"): + ds.set_xindex("not_a_coordinate", PandasIndex) + + ds["data_var"] = ("x", [1, 2, 3, 4]) + + with pytest.raises(ValueError, match="those variables are data variables"): + ds.set_xindex("data_var", PandasIndex) + + ds2 = Dataset(coords={"x": ("x", [0, 1, 2, 3])}) + + with pytest.raises(ValueError, match="those coordinates already have an index"): + ds2.set_xindex("x", PandasIndex) + + def test_set_xindex_options(self) -> None: + ds = Dataset(coords={"foo": ("x", ["a", "a", "b", "b"])}) + + class IndexWithOptions(Index): + def __init__(self, opt): + self.opt = opt + + @classmethod + def from_variables(cls, variables, options): + return cls(options["opt"]) + + indexed = ds.set_xindex("foo", IndexWithOptions, opt=1) + assert getattr(indexed.xindexes["foo"], "opt") == 1 + + def test_stack(self) -> None: + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ) + + midx_expected = pd.MultiIndex.from_product( + [[0, 1], ["a", "b"]], names=["x", "y"] + ) + midx_coords_expected = Coordinates.from_pandas_multiindex(midx_expected, "z") + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3])}, coords=midx_coords_expected + ) + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + + actual = ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "x", "y"] + + actual = ds.stack(z=[...]) + assert_identical(expected, actual) + + # non list dims with ellipsis + actual = ds.stack(z=(...,)) + assert_identical(expected, actual) + + # ellipsis with given dim + actual = ds.stack(z=[..., "y"]) + assert_identical(expected, actual) + + midx_expected = pd.MultiIndex.from_product( + [["a", "b"], [0, 1]], names=["y", "x"] + ) + midx_coords_expected = Coordinates.from_pandas_multiindex(midx_expected, "z") + expected = Dataset( + data_vars={"b": ("z", [0, 2, 1, 3])}, coords=midx_coords_expected + ) + expected["x"].attrs["foo"] = "bar" + + actual = ds.stack(z=["y", "x"]) + assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "y", "x"] + + @pytest.mark.parametrize( + "create_index,expected_keys", + [ + (True, ["z", "x", "y"]), + (False, []), + (None, ["z", "x", "y"]), + ], + ) + def test_stack_create_index(self, create_index, expected_keys) -> None: + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ) + + actual = ds.stack(z=["x", "y"], create_index=create_index) + assert list(actual.xindexes) == expected_keys + + # TODO: benbovy (flexible indexes) - test error multiple indexes found + # along dimension + create_index=True + + def test_stack_multi_index(self) -> None: + # multi-index on a dimension to stack is discarded too + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + coords["y"] = [0, 1] + ds = xr.Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3], [4, 5], [6, 7]])}, + coords=coords, + ) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3, 4, 5, 6, 7])}, + coords={ + "x": ("z", np.repeat(midx.values, 2)), + "lvl1": ("z", np.repeat(midx.get_level_values("lvl1"), 2)), + "lvl2": ("z", np.repeat(midx.get_level_values("lvl2"), 2)), + "y": ("z", [0, 1, 0, 1] * 2), + }, + ) + actual = ds.stack(z=["x", "y"], create_index=False) + assert_identical(expected, actual) + assert len(actual.xindexes) == 0 + + with pytest.raises(ValueError, match=r"cannot create.*wraps a multi-index"): + ds.stack(z=["x", "y"], create_index=True) + + def test_stack_non_dim_coords(self) -> None: + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ).rename_vars(x="xx") + + exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["xx", "y"]) + exp_coords = Coordinates.from_pandas_multiindex(exp_index, "z") + expected = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords=exp_coords) + + actual = ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "xx", "y"] + + def test_unstack(self) -> None: + index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) + coords = Coordinates.from_pandas_multiindex(index, "z") + ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords=coords) + expected = Dataset( + {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} + ) + + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + + for dim in ["z", ["z"], None]: + actual = ds.unstack(dim) + assert_identical(actual, expected) + + def test_unstack_errors(self) -> None: + ds = Dataset({"x": [1, 2, 3]}) + with pytest.raises( + ValueError, + match=re.escape("Dimensions ('foo',) not found in data dimensions ('x',)"), + ): + ds.unstack("foo") + with pytest.raises(ValueError, match=r".*do not have exactly one multi-index"): + ds.unstack("x") + + ds = Dataset({"da": [1, 2]}, coords={"y": ("x", [1, 1]), "z": ("x", [0, 0])}) + ds = ds.set_index(x=("y", "z")) + + with pytest.raises( + ValueError, match="Cannot unstack MultiIndex containing duplicates" + ): + ds.unstack("x") + + def test_unstack_fill_value(self) -> None: + ds = xr.Dataset( + {"var": (("x",), np.arange(6)), "other_var": (("x",), np.arange(3, 9))}, + coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)}, + ) + # make ds incomplete + ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"]) + # test fill_value + actual1 = ds.unstack("index", fill_value=-1) + expected1 = ds.unstack("index").fillna(-1).astype(int) + assert actual1["var"].dtype == int + assert_equal(actual1, expected1) + + actual2 = ds["var"].unstack("index", fill_value=-1) + expected2 = ds["var"].unstack("index").fillna(-1).astype(int) + assert_equal(actual2, expected2) + + actual3 = ds.unstack("index", fill_value={"var": -1, "other_var": 1}) + expected3 = ds.unstack("index").fillna({"var": -1, "other_var": 1}).astype(int) + assert_equal(actual3, expected3) + + @requires_sparse + def test_unstack_sparse(self) -> None: + ds = xr.Dataset( + {"var": (("x",), np.arange(6))}, + coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)}, + ) + # make ds incomplete + ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"]) + # test fill_value + actual1 = ds.unstack("index", sparse=True) + expected1 = ds.unstack("index") + assert isinstance(actual1["var"].data, sparse_array_type) + assert actual1["var"].variable._to_dense().equals(expected1["var"].variable) + assert actual1["var"].data.density < 1.0 + + actual2 = ds["var"].unstack("index", sparse=True) + expected2 = ds["var"].unstack("index") + assert isinstance(actual2.data, sparse_array_type) + assert actual2.variable._to_dense().equals(expected2.variable) + assert actual2.data.density < 1.0 + + midx = pd.MultiIndex.from_arrays([np.arange(3), np.arange(3)], names=["a", "b"]) + coords = Coordinates.from_pandas_multiindex(midx, "z") + coords["foo"] = np.arange(4) + coords["bar"] = np.arange(5) + ds_eye = Dataset( + {"var": (("z", "foo", "bar"), np.ones((3, 4, 5)))}, coords=coords + ) + actual3 = ds_eye.unstack(sparse=True, fill_value=0) + assert isinstance(actual3["var"].data, sparse_array_type) + expected3 = xr.Dataset( + { + "var": ( + ("foo", "bar", "a", "b"), + np.broadcast_to(np.eye(3, 3), (4, 5, 3, 3)), + ) + }, + coords={ + "foo": np.arange(4), + "bar": np.arange(5), + "a": np.arange(3), + "b": np.arange(3), + }, + ) + actual3["var"].data = actual3["var"].data.todense() + assert_equal(expected3, actual3) + + def test_stack_unstack_fast(self) -> None: + ds = Dataset( + { + "a": ("x", [0, 1]), + "b": (("x", "y"), [[0, 1], [2, 3]]), + "x": [0, 1], + "y": ["a", "b"], + } + ) + actual = ds.stack(z=["x", "y"]).unstack("z") + assert actual.broadcast_equals(ds) + + actual = ds[["b"]].stack(z=["x", "y"]).unstack("z") + assert actual.identical(ds[["b"]]) + + def test_stack_unstack_slow(self) -> None: + ds = Dataset( + data_vars={ + "a": ("x", [0, 1]), + "b": (("x", "y"), [[0, 1], [2, 3]]), + }, + coords={"x": [0, 1], "y": ["a", "b"]}, + ) + stacked = ds.stack(z=["x", "y"]) + actual = stacked.isel(z=slice(None, None, -1)).unstack("z") + assert actual.broadcast_equals(ds) + + stacked = ds[["b"]].stack(z=["x", "y"]) + actual = stacked.isel(z=slice(None, None, -1)).unstack("z") + assert actual.identical(ds[["b"]]) + + def test_to_stacked_array_invalid_sample_dims(self) -> None: + data = xr.Dataset( + data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])}, + coords={"y": ["u", "v", "w"]}, + ) + with pytest.raises( + ValueError, + match=r"Variables in the dataset must contain all ``sample_dims`` \(\['y'\]\) but 'b' misses \['y'\]", + ): + data.to_stacked_array("features", sample_dims=["y"]) + + def test_to_stacked_array_name(self) -> None: + name = "adf9d" + + # make a two dimensional dataset + a, b = create_test_stacked_array() + D = xr.Dataset({"a": a, "b": b}) + sample_dims = ["x"] + + y = D.to_stacked_array("features", sample_dims, name=name) + assert y.name == name + + def test_to_stacked_array_dtype_dims(self) -> None: + # make a two dimensional dataset + a, b = create_test_stacked_array() + D = xr.Dataset({"a": a, "b": b}) + sample_dims = ["x"] + y = D.to_stacked_array("features", sample_dims) + assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype + assert y.dims == ("x", "features") + + def test_to_stacked_array_to_unstacked_dataset(self) -> None: + # single dimension: regression test for GH4049 + arr = xr.DataArray(np.arange(3), coords=[("x", [0, 1, 2])]) + data = xr.Dataset({"a": arr, "b": arr}) + stacked = data.to_stacked_array("y", sample_dims=["x"]) + unstacked = stacked.to_unstacked_dataset("y") + assert_identical(unstacked, data) + + # make a two dimensional dataset + a, b = create_test_stacked_array() + D = xr.Dataset({"a": a, "b": b}) + sample_dims = ["x"] + y = D.to_stacked_array("features", sample_dims).transpose("x", "features") + + x = y.to_unstacked_dataset("features") + assert_identical(D, x) + + # test on just one sample + x0 = y[0].to_unstacked_dataset("features") + d0 = D.isel(x=0) + assert_identical(d0, x0) + + def test_to_stacked_array_to_unstacked_dataset_different_dimension(self) -> None: + # test when variables have different dimensionality + a, b = create_test_stacked_array() + sample_dims = ["x"] + D = xr.Dataset({"a": a, "b": b.isel(y=0)}) + + y = D.to_stacked_array("features", sample_dims) + x = y.to_unstacked_dataset("features") + assert_identical(D, x) + + def test_to_stacked_array_preserves_dtype(self) -> None: + # regression test for bug found in https://github.com/pydata/xarray/pull/8872#issuecomment-2081218616 + ds = xr.Dataset( + data_vars={ + "a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), + "b": ("x", [6, 7]), + }, + coords={"y": ["u", "v", "w"]}, + ) + stacked = ds.to_stacked_array("z", sample_dims=["x"]) + + # coordinate created from variables names should be of string dtype + data = np.array(["a", "a", "a", "b"], dtype=" None: + data = create_test_data(seed=0) + expected = data.copy() + var2 = Variable("dim1", np.arange(8)) + actual = data + actual.update({"var2": var2}) + expected["var2"] = var2 + assert_identical(expected, actual) + + actual = data.copy() + actual.update(data) + assert_identical(expected, actual) + + other = Dataset(attrs={"new": "attr"}) + actual = data.copy() + actual.update(other) + assert_identical(expected, actual) + + def test_update_overwrite_coords(self) -> None: + data = Dataset({"a": ("x", [1, 2])}, {"b": 3}) + data.update(Dataset(coords={"b": 4})) + expected = Dataset({"a": ("x", [1, 2])}, {"b": 4}) + assert_identical(data, expected) + + data = Dataset({"a": ("x", [1, 2])}, {"b": 3}) + data.update(Dataset({"c": 5}, coords={"b": 4})) + expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 4}) + assert_identical(data, expected) + + data = Dataset({"a": ("x", [1, 2])}, {"b": 3}) + data.update({"c": DataArray(5, coords={"b": 4})}) + expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 3}) + assert_identical(data, expected) + + def test_update_multiindex_level(self) -> None: + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + data.update({"level_1": range(4)}) + + def test_update_auto_align(self) -> None: + ds = Dataset({"x": ("t", [3, 4])}, {"t": [0, 1]}) + + expected1 = Dataset( + {"x": ("t", [3, 4]), "y": ("t", [np.nan, 5])}, {"t": [0, 1]} + ) + actual1 = ds.copy() + other1 = {"y": ("t", [5]), "t": [1]} + with pytest.raises(ValueError, match=r"conflicting sizes"): + actual1.update(other1) + actual1.update(Dataset(other1)) + assert_identical(expected1, actual1) + + actual2 = ds.copy() + other2 = Dataset({"y": ("t", [5]), "t": [100]}) + actual2.update(other2) + expected2 = Dataset( + {"x": ("t", [3, 4]), "y": ("t", [np.nan] * 2)}, {"t": [0, 1]} + ) + assert_identical(expected2, actual2) + + def test_getitem(self) -> None: + data = create_test_data() + assert isinstance(data["var1"], DataArray) + assert_equal(data["var1"].variable, data.variables["var1"]) + with pytest.raises(KeyError): + data["notfound"] + with pytest.raises(KeyError): + data[["var1", "notfound"]] + + actual1 = data[["var1", "var2"]] + expected1 = Dataset({"var1": data["var1"], "var2": data["var2"]}) + assert_equal(expected1, actual1) + + actual2 = data["numbers"] + expected2 = DataArray( + data["numbers"].variable, + {"dim3": data["dim3"], "numbers": data["numbers"]}, + dims="dim3", + name="numbers", + ) + assert_identical(expected2, actual2) + + actual3 = data[dict(dim1=0)] + expected3 = data.isel(dim1=0) + assert_identical(expected3, actual3) + + def test_getitem_hashable(self) -> None: + data = create_test_data() + data[(3, 4)] = data["var1"] + 1 + expected = data["var1"] + 1 + expected.name = (3, 4) + assert_identical(expected, data[(3, 4)]) + with pytest.raises(KeyError, match=r"('var1', 'var2')"): + data[("var1", "var2")] + + def test_getitem_multiple_dtype(self) -> None: + keys = ["foo", 1] + dataset = Dataset({key: ("dim0", range(1)) for key in keys}) + assert_identical(dataset, dataset[keys]) + + def test_virtual_variables_default_coords(self) -> None: + dataset = Dataset({"foo": ("x", range(10))}) + expected1 = DataArray(range(10), dims="x", name="x") + actual1 = dataset["x"] + assert_identical(expected1, actual1) + assert isinstance(actual1.variable, IndexVariable) + + actual2 = dataset[["x", "foo"]] + expected2 = dataset.assign_coords(x=range(10)) + assert_identical(expected2, actual2) + + def test_virtual_variables_time(self) -> None: + # access virtual variables + data = create_test_data() + assert_array_equal( + data["time.month"].values, data.variables["time"].to_index().month + ) + assert_array_equal(data["time.season"].values, "DJF") + # test virtual variable math + assert_array_equal(data["time.dayofyear"] + 1, 2 + np.arange(20)) + assert_array_equal(np.sin(data["time.dayofyear"]), np.sin(1 + np.arange(20))) + # ensure they become coordinates + expected = Dataset({}, {"dayofyear": data["time.dayofyear"]}) + actual = data[["time.dayofyear"]] + assert_equal(expected, actual) + # non-coordinate variables + ds = Dataset({"t": ("x", pd.date_range("2000-01-01", periods=3))}) + assert (ds["t.year"] == 2000).all() + + def test_virtual_variable_same_name(self) -> None: + # regression test for GH367 + times = pd.date_range("2000-01-01", freq="h", periods=5) + data = Dataset({"time": times}) + actual = data["time.time"] + expected = DataArray(times.time, [("time", times)], name="time") + assert_identical(actual, expected) + + def test_time_season(self) -> None: + time = xr.date_range("2000-01-01", periods=12, freq="ME", use_cftime=False) + ds = Dataset({"t": time}) + seas = ["DJF"] * 2 + ["MAM"] * 3 + ["JJA"] * 3 + ["SON"] * 3 + ["DJF"] + assert_array_equal(seas, ds["t.season"]) + + def test_slice_virtual_variable(self) -> None: + data = create_test_data() + assert_equal( + data["time.dayofyear"][:10].variable, Variable(["time"], 1 + np.arange(10)) + ) + assert_equal(data["time.dayofyear"][0].variable, Variable([], 1)) + + def test_setitem(self) -> None: + # assign a variable + var = Variable(["dim1"], np.random.randn(8)) + data1 = create_test_data() + data1["A"] = var + data2 = data1.copy() + data2["A"] = var + assert_identical(data1, data2) + # assign a dataset array + dv = 2 * data2["A"] + data1["B"] = dv.variable + data2["B"] = dv + assert_identical(data1, data2) + # can't assign an ND array without dimensions + with pytest.raises(ValueError, match=r"without explicit dimension names"): + data2["C"] = var.values.reshape(2, 4) + # but can assign a 1D array + data1["C"] = var.values + data2["C"] = ("C", var.values) + assert_identical(data1, data2) + # can assign a scalar + data1["scalar"] = 0 + data2["scalar"] = ([], 0) + assert_identical(data1, data2) + # can't use the same dimension name as a scalar var + with pytest.raises(ValueError, match=r"already exists as a scalar"): + data1["newvar"] = ("scalar", [3, 4, 5]) + # can't resize a used dimension + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): + data1["dim1"] = data1["dim1"][:5] + # override an existing value + data1["A"] = 3 * data2["A"] + assert_equal(data1["A"], 3 * data2["A"]) + # can't assign a dataset to a single key + with pytest.raises(TypeError, match="Cannot assign a Dataset to a single key"): + data1["D"] = xr.Dataset() + + # test assignment with positional and label-based indexing + data3 = data1[["var1", "var2"]] + data3["var3"] = data3.var1.isel(dim1=0) + data4 = data3.copy() + err_msg = ( + "can only set locations defined by dictionaries from Dataset.loc. Got: a" + ) + with pytest.raises(TypeError, match=err_msg): + data1.loc["a"] = 0 + err_msg = r"Variables \['A', 'B', 'scalar'\] in new values not available in original dataset:" + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": 1}] = data1[{"dim2": 2}] + err_msg = "Variable 'var3': indexer {'dim2': 0} not available" + with pytest.raises(ValueError, match=err_msg): + data1[{"dim2": 0}] = 0.0 + err_msg = "Variable 'var1': indexer {'dim2': 10} not available" + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": 10}] = data3[{"dim2": 2}] + err_msg = "Variable 'var1': dimension 'dim2' appears in new values" + with pytest.raises(KeyError, match=err_msg): + data4[{"dim2": 2}] = data3[{"dim2": [2]}] + err_msg = ( + "Variable 'var2': dimension order differs between original and new data" + ) + data3["var2"] = data3["var2"].T + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3]}] + data3["var2"] = data3["var2"].T + err_msg = r"cannot align objects.*not equal along these coordinates.*" + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3, 4]}] + err_msg = "Dataset assignment only accepts DataArrays, Datasets, and scalars." + with pytest.raises(TypeError, match=err_msg): + data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values + data5 = data4.astype(str) + data5["var4"] = data4["var1"] + # convert to `np.str_('a')` once `numpy<2.0` has been dropped + err_msg = "could not convert string to float: .*'a'.*" + with pytest.raises(ValueError, match=err_msg): + data5[{"dim2": 1}] = "a" + + data4[{"dim2": 0}] = 0.0 + data4[{"dim2": 1}] = data3[{"dim2": 2}] + data4.loc[{"dim2": 1.5}] = 1.0 + data4.loc[{"dim2": 2.0}] = data3.loc[{"dim2": 2.5}] + for v, dat3 in data3.items(): + dat4 = data4[v] + assert_array_equal(dat4[{"dim2": 0}], 0.0) + assert_array_equal(dat4[{"dim2": 1}], dat3[{"dim2": 2}]) + assert_array_equal(dat4.loc[{"dim2": 1.5}], 1.0) + assert_array_equal(dat4.loc[{"dim2": 2.0}], dat3.loc[{"dim2": 2.5}]) + unchanged = [1.0, 2.5, 3.0, 3.5, 4.0] + assert_identical( + dat4.loc[{"dim2": unchanged}], dat3.loc[{"dim2": unchanged}] + ) + + def test_setitem_pandas(self) -> None: + ds = self.make_example_math_dataset() + ds["x"] = np.arange(3) + ds_copy = ds.copy() + ds_copy["bar"] = ds["bar"].to_pandas() + + assert_equal(ds, ds_copy) + + def test_setitem_auto_align(self) -> None: + ds = Dataset() + ds["x"] = ("y", range(3)) + ds["y"] = 1 + np.arange(3) + expected = Dataset({"x": ("y", range(3)), "y": 1 + np.arange(3)}) + assert_identical(ds, expected) + + ds["y"] = DataArray(range(3), dims="y") + expected = Dataset({"x": ("y", range(3))}, {"y": range(3)}) + assert_identical(ds, expected) + + ds["x"] = DataArray([1, 2], coords=[("y", [0, 1])]) + expected = Dataset({"x": ("y", [1, 2, np.nan])}, {"y": range(3)}) + assert_identical(ds, expected) + + ds["x"] = 42 + expected = Dataset({"x": 42, "y": range(3)}) + assert_identical(ds, expected) + + ds["x"] = DataArray([4, 5, 6, 7], coords=[("y", [0, 1, 2, 3])]) + expected = Dataset({"x": ("y", [4, 5, 6])}, {"y": range(3)}) + assert_identical(ds, expected) + + def test_setitem_dimension_override(self) -> None: + # regression test for GH-3377 + ds = xr.Dataset({"x": [0, 1, 2]}) + ds["x"] = ds["x"][:2] + expected = Dataset({"x": [0, 1]}) + assert_identical(ds, expected) + + ds = xr.Dataset({"x": [0, 1, 2]}) + ds["x"] = np.array([0, 1]) + assert_identical(ds, expected) + + ds = xr.Dataset({"x": [0, 1, 2]}) + ds.coords["x"] = [0, 1] + assert_identical(ds, expected) + + def test_setitem_with_coords(self) -> None: + # Regression test for GH:2068 + ds = create_test_data() + + other = DataArray( + np.arange(10), dims="dim3", coords={"numbers": ("dim3", np.arange(10))} + ) + expected = ds.copy() + expected["var3"] = other.drop_vars("numbers") + actual = ds.copy() + actual["var3"] = other + assert_identical(expected, actual) + assert "numbers" in other.coords # should not change other + + # with alignment + other = ds["var3"].isel(dim3=slice(1, -1)) + other["numbers"] = ("dim3", np.arange(8)) + actual = ds.copy() + actual["var3"] = other + assert "numbers" in other.coords # should not change other + expected = ds.copy() + expected["var3"] = ds["var3"].isel(dim3=slice(1, -1)) + assert_identical(expected, actual) + + # with non-duplicate coords + other = ds["var3"].isel(dim3=slice(1, -1)) + other["numbers"] = ("dim3", np.arange(8)) + other["position"] = ("dim3", np.arange(8)) + actual = ds.copy() + actual["var3"] = other + assert "position" in actual + assert "position" in other.coords + + # assigning a coordinate-only dataarray + actual = ds.copy() + other = actual["numbers"] + other[0] = 10 + actual["numbers"] = other + assert actual["numbers"][0] == 10 + + # GH: 2099 + ds = Dataset( + {"var": ("x", [1, 2, 3])}, + coords={"x": [0, 1, 2], "z1": ("x", [1, 2, 3]), "z2": ("x", [1, 2, 3])}, + ) + ds["var"] = ds["var"] * 2 + assert np.allclose(ds["var"], [2, 4, 6]) + + def test_setitem_align_new_indexes(self) -> None: + ds = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) + ds["bar"] = DataArray([2, 3, 4], [("x", [1, 2, 3])]) + expected = Dataset( + {"foo": ("x", [1, 2, 3]), "bar": ("x", [np.nan, 2, 3])}, {"x": [0, 1, 2]} + ) + assert_identical(ds, expected) + + def test_setitem_vectorized(self) -> None: + # Regression test for GH:7030 + # Positional indexing + da = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + ds = xr.Dataset({"da": da}) + b = xr.DataArray([[0, 0], [1, 0]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + ds[index] = xr.Dataset({"da": w}) + assert (ds[index]["da"] == w).all() + + # Indexing with coordinates + da = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + ds = xr.Dataset({"da": da}) + ds.coords["b"] = [2, 4, 6] + b = xr.DataArray([[2, 2], [4, 2]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + ds.loc[index] = xr.Dataset({"da": w}, coords={"b": ds.coords["b"]}) + assert (ds.loc[index]["da"] == w).all() + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_setitem_str_dtype(self, dtype) -> None: + ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) + # test Dataset update + ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"]) + + assert np.issubdtype(ds.x.dtype, dtype) + + def test_setitem_using_list(self) -> None: + # assign a list of variables + var1 = Variable(["dim1"], np.random.randn(8)) + var2 = Variable(["dim1"], np.random.randn(8)) + actual = create_test_data() + expected = actual.copy() + expected["A"] = var1 + expected["B"] = var2 + actual[["A", "B"]] = [var1, var2] + assert_identical(actual, expected) + # assign a list of dataset arrays + dv = 2 * expected[["A", "B"]] + actual[["C", "D"]] = [d.variable for d in dv.data_vars.values()] + expected[["C", "D"]] = dv + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "var_list, data, error_regex", + [ + ( + ["A", "B"], + [Variable(["dim1"], np.random.randn(8))], + r"Different lengths", + ), + ([], [Variable(["dim1"], np.random.randn(8))], r"Empty list of variables"), + (["A", "B"], xr.DataArray([1, 2]), r"assign single DataArray"), + ], + ) + def test_setitem_using_list_errors(self, var_list, data, error_regex) -> None: + actual = create_test_data() + with pytest.raises(ValueError, match=error_regex): + actual[var_list] = data + + def test_assign(self) -> None: + ds = Dataset() + actual = ds.assign(x=[0, 1, 2], y=2) + expected = Dataset({"x": [0, 1, 2], "y": 2}) + assert_identical(actual, expected) + assert list(actual.variables) == ["x", "y"] + assert_identical(ds, Dataset()) + + actual = actual.assign(y=lambda ds: ds.x**2) + expected = Dataset({"y": ("x", [0, 1, 4]), "x": [0, 1, 2]}) + assert_identical(actual, expected) + + actual = actual.assign_coords(z=2) + expected = Dataset({"y": ("x", [0, 1, 4])}, {"z": 2, "x": [0, 1, 2]}) + assert_identical(actual, expected) + + def test_assign_coords(self) -> None: + ds = Dataset() + + actual = ds.assign(x=[0, 1, 2], y=2) + actual = actual.assign_coords(x=list("abc")) + expected = Dataset({"x": list("abc"), "y": 2}) + assert_identical(actual, expected) + + actual = ds.assign(x=[0, 1, 2], y=[2, 3]) + actual = actual.assign_coords({"y": [2.0, 3.0]}) + expected = ds.assign(x=[0, 1, 2], y=[2.0, 3.0]) + assert_identical(actual, expected) + + def test_assign_attrs(self) -> None: + expected = Dataset(attrs=dict(a=1, b=2)) + new = Dataset() + actual = new.assign_attrs(a=1, b=2) + assert_identical(actual, expected) + assert new.attrs == {} + + expected.attrs["c"] = 3 + new_actual = actual.assign_attrs({"c": 3}) + assert_identical(new_actual, expected) + assert actual.attrs == dict(a=1, b=2) + + def test_assign_multiindex_level(self) -> None: + data = create_test_multiindex() + with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): + data.assign(level_1=range(4)) + data.assign_coords(level_1=range(4)) + + def test_assign_new_multiindex(self) -> None: + midx = pd.MultiIndex.from_arrays([["a", "a", "b", "b"], [0, 1, 0, 1]]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + + ds = Dataset(coords={"x": [1, 2]}) + expected = Dataset(coords=midx_coords) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + actual = ds.assign(x=midx) + assert_identical(actual, expected) + + @pytest.mark.parametrize("orig_coords", [{}, {"x": range(4)}]) + def test_assign_coords_new_multiindex(self, orig_coords) -> None: + ds = Dataset(coords=orig_coords) + midx = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b"], [0, 1, 0, 1]], names=("one", "two") + ) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + + expected = Dataset(coords=midx_coords) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + actual = ds.assign_coords({"x": midx}) + assert_identical(actual, expected) + + actual = ds.assign_coords(midx_coords) + assert_identical(actual, expected) + + def test_assign_coords_existing_multiindex(self) -> None: + data = create_test_multiindex() + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + updated = data.assign_coords(x=range(4)) + # https://github.com/pydata/xarray/issues/7097 (coord names updated) + assert len(updated.coords) == 1 + + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + updated = data.assign(x=range(4)) + # https://github.com/pydata/xarray/issues/7097 (coord names updated) + assert len(updated.coords) == 1 + + def test_assign_all_multiindex_coords(self) -> None: + data = create_test_multiindex() + actual = data.assign(x=range(4), level_1=range(4), level_2=range(4)) + # no error but multi-index dropped in favor of single indexes for each level + assert ( + actual.xindexes["x"] + is not actual.xindexes["level_1"] + is not actual.xindexes["level_2"] + ) + + def test_assign_coords_custom_index_side_effect(self) -> None: + # test that assigning new coordinates do not reset other dimension coord indexes + # to default (pandas) index (https://github.com/pydata/xarray/issues/7346) + class CustomIndex(PandasIndex): + pass + + ds = ( + Dataset(coords={"x": [1, 2, 3]}) + .drop_indexes("x") + .set_xindex("x", CustomIndex) + ) + actual = ds.assign_coords(y=[4, 5, 6]) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_custom_index(self) -> None: + class CustomIndex(Index): + pass + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + ds = Dataset() + actual = ds.assign_coords(coords) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_no_default_index(self) -> None: + coords = Coordinates({"y": [1, 2, 3]}, indexes={}) + ds = Dataset() + actual = ds.assign_coords(coords) + expected = coords.to_dataset() + assert_identical(expected, actual, check_default_indexes=False) + assert "y" not in actual.xindexes + + def test_merge_multiindex_level(self) -> None: + data = create_test_multiindex() + + other = Dataset({"level_1": ("x", [0, 1])}) + with pytest.raises(ValueError, match=r".*conflicting dimension sizes.*"): + data.merge(other) + + other = Dataset({"level_1": ("x", range(4))}) + with pytest.raises( + ValueError, match=r"unable to determine.*coordinates or not.*" + ): + data.merge(other) + + # `other` Dataset coordinates are ignored (bug or feature?) + other = Dataset(coords={"level_1": ("x", range(4))}) + assert_identical(data.merge(other), data) + + def test_setitem_original_non_unique_index(self) -> None: + # regression test for GH943 + original = Dataset({"data": ("x", np.arange(5))}, coords={"x": [0, 1, 2, 0, 1]}) + expected = Dataset({"data": ("x", np.arange(5))}, {"x": range(5)}) + + actual = original.copy() + actual["x"] = list(range(5)) + assert_identical(actual, expected) + + actual = original.copy() + actual["x"] = ("x", list(range(5))) + assert_identical(actual, expected) + + actual = original.copy() + actual.coords["x"] = list(range(5)) + assert_identical(actual, expected) + + def test_setitem_both_non_unique_index(self) -> None: + # regression test for GH956 + names = ["joaquin", "manolo", "joaquin"] + values = np.random.randint(0, 256, (3, 4, 4)) + array = DataArray( + values, dims=["name", "row", "column"], coords=[names, range(4), range(4)] + ) + expected = Dataset({"first": array, "second": array}) + actual = array.rename("first").to_dataset() + actual["second"] = array + assert_identical(expected, actual) + + def test_setitem_multiindex_level(self) -> None: + data = create_test_multiindex() + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + data["level_1"] = range(4) + + def test_delitem(self) -> None: + data = create_test_data() + all_items = set(data.variables) + assert set(data.variables) == all_items + del data["var1"] + assert set(data.variables) == all_items - {"var1"} + del data["numbers"] + assert set(data.variables) == all_items - {"var1", "numbers"} + assert "numbers" not in data.coords + + expected = Dataset() + actual = Dataset({"y": ("x", [1, 2])}) + del actual["y"] + assert_identical(expected, actual) + + def test_delitem_multiindex_level(self) -> None: + data = create_test_multiindex() + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del data["level_1"] + + def test_squeeze(self) -> None: + data = Dataset({"foo": (["x", "y", "z"], [[[1], [2]]])}) + test_args: list[list] = [[], [["x"]], [["x", "z"]]] + for args in test_args: + + def get_args(v): + return [set(args[0]) & set(v.dims)] if args else [] + + expected = Dataset( + {k: v.squeeze(*get_args(v)) for k, v in data.variables.items()} + ) + expected = expected.set_coords(data.coords) + assert_identical(expected, data.squeeze(*args)) + # invalid squeeze + with pytest.raises(ValueError, match=r"cannot select a dimension"): + data.squeeze("y") + + def test_squeeze_drop(self) -> None: + data = Dataset({"foo": ("x", [1])}, {"x": [0]}) + expected = Dataset({"foo": 1}) + selected = data.squeeze(drop=True) + assert_identical(expected, selected) + + expected = Dataset({"foo": 1}, {"x": 0}) + selected = data.squeeze(drop=False) + assert_identical(expected, selected) + + data = Dataset({"foo": (("x", "y"), [[1]])}, {"x": [0], "y": [0]}) + expected = Dataset({"foo": 1}) + selected = data.squeeze(drop=True) + assert_identical(expected, selected) + + expected = Dataset({"foo": ("x", [1])}, {"x": [0]}) + selected = data.squeeze(dim="y", drop=True) + assert_identical(expected, selected) + + data = Dataset({"foo": (("x",), [])}, {"x": []}) + selected = data.squeeze(drop=True) + assert_identical(data, selected) + + def test_to_dataarray(self) -> None: + ds = Dataset( + {"a": 1, "b": ("x", [1, 2, 3])}, + coords={"c": 42}, + attrs={"Conventions": "None"}, + ) + data = [[1, 1, 1], [1, 2, 3]] + coords = {"c": 42, "variable": ["a", "b"]} + dims = ("variable", "x") + expected = DataArray(data, coords, dims, attrs=ds.attrs) + actual = ds.to_dataarray() + assert_identical(expected, actual) + + actual = ds.to_dataarray("abc", name="foo") + expected = expected.rename({"variable": "abc"}).rename("foo") + assert_identical(expected, actual) + + def test_to_and_from_dataframe(self) -> None: + x = np.random.randn(10) + y = np.random.randn(10) + t = list("abcdefghij") + cat = pd.Categorical(["a", "b"] * 5) + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t), "cat": ("t", cat)}) + expected = pd.DataFrame( + np.array([x, y]).T, columns=["a", "b"], index=pd.Index(t, name="t") + ) + expected["cat"] = cat + actual = ds.to_dataframe() + # use the .equals method to check all DataFrame metadata + assert expected.equals(actual), (expected, actual) + + # verify coords are included + actual = ds.set_coords("b").to_dataframe() + assert expected.equals(actual), (expected, actual) + + # check roundtrip + assert_identical(ds, Dataset.from_dataframe(actual)) + assert isinstance(ds["cat"].variable.data.dtype, pd.CategoricalDtype) + # test a case with a MultiIndex + w = np.random.randn(2, 3) + cat = pd.Categorical(["a", "a", "c"]) + ds = Dataset({"w": (("x", "y"), w), "cat": ("y", cat)}) + ds["y"] = ("y", list("abc")) + exp_index = pd.MultiIndex.from_arrays( + [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] + ) + expected = pd.DataFrame( + {"w": w.reshape(-1), "cat": pd.Categorical(["a", "a", "c", "a", "a", "c"])}, + index=exp_index, + ) + actual = ds.to_dataframe() + assert expected.equals(actual) + + # check roundtrip + # from_dataframe attempts to broadcast across because it doesn't know better, so cat must be converted + ds["cat"] = (("x", "y"), np.stack((ds["cat"].to_numpy(), ds["cat"].to_numpy()))) + assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual)) + + # Check multiindex reordering + new_order = ["x", "y"] + # revert broadcasting fix above for 1d arrays + ds["cat"] = ("y", cat) + actual = ds.to_dataframe(dim_order=new_order) + assert expected.equals(actual) + + new_order = ["y", "x"] + exp_index = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b", "c", "c"], [0, 1, 0, 1, 0, 1]], names=["y", "x"] + ) + expected = pd.DataFrame( + { + "w": w.transpose().reshape(-1), + "cat": pd.Categorical(["a", "a", "a", "a", "c", "c"]), + }, + index=exp_index, + ) + actual = ds.to_dataframe(dim_order=new_order) + assert expected.equals(actual) + + invalid_order = ["x"] + with pytest.raises( + ValueError, match="does not match the set of dimensions of this" + ): + ds.to_dataframe(dim_order=invalid_order) + + invalid_order = ["x", "z"] + with pytest.raises( + ValueError, match="does not match the set of dimensions of this" + ): + ds.to_dataframe(dim_order=invalid_order) + + # check pathological cases + df = pd.DataFrame([1]) + actual = Dataset.from_dataframe(df) + expected = Dataset({0: ("index", [1])}, {"index": [0]}) + assert_identical(expected, actual) + + df = pd.DataFrame() + actual = Dataset.from_dataframe(df) + expected = Dataset(coords={"index": []}) + assert_identical(expected, actual) + + # GH697 + df = pd.DataFrame({"A": []}) + actual = Dataset.from_dataframe(df) + expected = Dataset({"A": DataArray([], dims=("index",))}, {"index": []}) + assert_identical(expected, actual) + + # regression test for GH278 + # use int64 to ensure consistent results for the pandas .equals method + # on windows (which requires the same dtype) + ds = Dataset({"x": pd.Index(["bar"]), "a": ("y", np.array([1], "int64"))}).isel( + x=0 + ) + # use .loc to ensure consistent results on Python 3 + actual = ds.to_dataframe().loc[:, ["a", "x"]] + expected = pd.DataFrame( + [[1, "bar"]], index=pd.Index([0], name="y"), columns=["a", "x"] + ) + assert expected.equals(actual), (expected, actual) + + ds = Dataset({"x": np.array([0], "int64"), "y": np.array([1], "int64")}) + actual = ds.to_dataframe() + idx = pd.MultiIndex.from_arrays([[0], [1]], names=["x", "y"]) + expected = pd.DataFrame([[]], index=idx) + assert expected.equals(actual), (expected, actual) + + def test_from_dataframe_categorical_index(self) -> None: + cat = pd.CategoricalDtype( + categories=["foo", "bar", "baz", "qux", "quux", "corge"] + ) + i1 = pd.Series(["foo", "bar", "foo"], dtype=cat) + i2 = pd.Series(["bar", "bar", "baz"], dtype=cat) + + df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]}) + ds = df.set_index("i1").to_xarray() + assert len(ds["i1"]) == 3 + + ds = df.set_index(["i1", "i2"]).to_xarray() + assert len(ds["i1"]) == 2 + assert len(ds["i2"]) == 2 + + def test_from_dataframe_categorical_index_string_categories(self) -> None: + cat = pd.CategoricalIndex( + pd.Categorical.from_codes( + np.array([1, 1, 0, 2]), + categories=pd.Index(["foo", "bar", "baz"], dtype="string"), + ) + ) + ser = pd.Series(1, index=cat) + ds = ser.to_xarray() + assert ds.coords.dtypes["index"] == np.dtype("O") + + @requires_sparse + def test_from_dataframe_sparse(self) -> None: + import sparse + + df_base = pd.DataFrame( + {"x": range(10), "y": list("abcdefghij"), "z": np.arange(0, 100, 10)} + ) + + ds_sparse = Dataset.from_dataframe(df_base.set_index("x"), sparse=True) + ds_dense = Dataset.from_dataframe(df_base.set_index("x"), sparse=False) + assert isinstance(ds_sparse["y"].data, sparse.COO) + assert isinstance(ds_sparse["z"].data, sparse.COO) + ds_sparse["y"].data = ds_sparse["y"].data.todense() + ds_sparse["z"].data = ds_sparse["z"].data.todense() + assert_identical(ds_dense, ds_sparse) + + ds_sparse = Dataset.from_dataframe(df_base.set_index(["x", "y"]), sparse=True) + ds_dense = Dataset.from_dataframe(df_base.set_index(["x", "y"]), sparse=False) + assert isinstance(ds_sparse["z"].data, sparse.COO) + ds_sparse["z"].data = ds_sparse["z"].data.todense() + assert_identical(ds_dense, ds_sparse) + + def test_to_and_from_empty_dataframe(self) -> None: + # GH697 + expected = pd.DataFrame({"foo": []}) + ds = Dataset.from_dataframe(expected) + assert len(ds["foo"]) == 0 + actual = ds.to_dataframe() + assert len(actual) == 0 + assert expected.equals(actual) + + def test_from_dataframe_multiindex(self) -> None: + index = pd.MultiIndex.from_product([["a", "b"], [1, 2, 3]], names=["x", "y"]) + df = pd.DataFrame({"z": np.arange(6)}, index=index) + + expected = Dataset( + {"z": (("x", "y"), [[0, 1, 2], [3, 4, 5]])}, + coords={"x": ["a", "b"], "y": [1, 2, 3]}, + ) + actual = Dataset.from_dataframe(df) + assert_identical(actual, expected) + + df2 = df.iloc[[3, 2, 1, 0, 4, 5], :] + actual = Dataset.from_dataframe(df2) + assert_identical(actual, expected) + + df3 = df.iloc[:4, :] + expected3 = Dataset( + {"z": (("x", "y"), [[0, 1, 2], [3, np.nan, np.nan]])}, + coords={"x": ["a", "b"], "y": [1, 2, 3]}, + ) + actual = Dataset.from_dataframe(df3) + assert_identical(actual, expected3) + + df_nonunique = df.iloc[[0, 0], :] + with pytest.raises(ValueError, match=r"non-unique MultiIndex"): + Dataset.from_dataframe(df_nonunique) + + def test_from_dataframe_unsorted_levels(self) -> None: + # regression test for GH-4186 + index = pd.MultiIndex( + levels=[["b", "a"], ["foo"]], codes=[[0, 1], [0, 0]], names=["lev1", "lev2"] + ) + df = pd.DataFrame({"c1": [0, 2], "c2": [1, 3]}, index=index) + expected = Dataset( + { + "c1": (("lev1", "lev2"), [[0], [2]]), + "c2": (("lev1", "lev2"), [[1], [3]]), + }, + coords={"lev1": ["b", "a"], "lev2": ["foo"]}, + ) + actual = Dataset.from_dataframe(df) + assert_identical(actual, expected) + + def test_from_dataframe_non_unique_columns(self) -> None: + # regression test for GH449 + df = pd.DataFrame(np.zeros((2, 2))) + df.columns = ["foo", "foo"] + with pytest.raises(ValueError, match=r"non-unique columns"): + Dataset.from_dataframe(df) + + def test_convert_dataframe_with_many_types_and_multiindex(self) -> None: + # regression test for GH737 + df = pd.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + "c": np.arange(3, 6).astype("u1"), + "d": np.arange(4.0, 7.0, dtype="float64"), + "e": [True, False, True], + "f": pd.Categorical(list("abc")), + "g": pd.date_range("20130101", periods=3), + "h": pd.date_range("20130101", periods=3, tz="America/New_York"), + } + ) + df.index = pd.MultiIndex.from_product([["a"], range(3)], names=["one", "two"]) + roundtripped = Dataset.from_dataframe(df).to_dataframe() + # we can't do perfectly, but we should be at least as faithful as + # np.asarray + expected = df.apply(np.asarray) + assert roundtripped.equals(expected) + + @pytest.mark.parametrize("encoding", [True, False]) + @pytest.mark.parametrize("data", [True, "list", "array"]) + def test_to_and_from_dict( + self, encoding: bool, data: bool | Literal["list", "array"] + ) -> None: + # + # Dimensions: (t: 10) + # Coordinates: + # * t (t) U1" + expected_no_data["coords"]["t"].update({"dtype": endiantype, "shape": (10,)}) + expected_no_data["data_vars"]["a"].update({"dtype": "float64", "shape": (10,)}) + expected_no_data["data_vars"]["b"].update({"dtype": "float64", "shape": (10,)}) + actual_no_data = ds.to_dict(data=False, encoding=encoding) + assert expected_no_data == actual_no_data + + # verify coords are included roundtrip + expected_ds = ds.set_coords("b") + actual2 = Dataset.from_dict(expected_ds.to_dict(data=data, encoding=encoding)) + + assert_identical(expected_ds, actual2) + if encoding: + assert set(expected_ds.variables) == set(actual2.variables) + for vv in ds.variables: + np.testing.assert_equal(expected_ds[vv].encoding, actual2[vv].encoding) + + # test some incomplete dicts: + # this one has no attrs field, the dims are strings, and x, y are + # np.arrays + + d = { + "coords": {"t": {"dims": "t", "data": t}}, + "dims": "t", + "data_vars": {"a": {"dims": "t", "data": x}, "b": {"dims": "t", "data": y}}, + } + assert_identical(ds, Dataset.from_dict(d)) + + # this is kind of a flattened version with no coords, or data_vars + d = { + "a": {"dims": "t", "data": x}, + "t": {"data": t, "dims": "t"}, + "b": {"dims": "t", "data": y}, + } + assert_identical(ds, Dataset.from_dict(d)) + + # this one is missing some necessary information + d = { + "a": {"data": x}, + "t": {"data": t, "dims": "t"}, + "b": {"dims": "t", "data": y}, + } + with pytest.raises( + ValueError, match=r"cannot convert dict without the key 'dims'" + ): + Dataset.from_dict(d) + + def test_to_and_from_dict_with_time_dim(self) -> None: + x = np.random.randn(10, 3) + y = np.random.randn(10, 3) + t = pd.date_range("20130101", periods=10) + lat = [77.7, 83.2, 76] + ds = Dataset( + { + "a": (["t", "lat"], x), + "b": (["t", "lat"], y), + "t": ("t", t), + "lat": ("lat", lat), + } + ) + roundtripped = Dataset.from_dict(ds.to_dict()) + assert_identical(ds, roundtripped) + + @pytest.mark.parametrize("data", [True, "list", "array"]) + def test_to_and_from_dict_with_nan_nat( + self, data: bool | Literal["list", "array"] + ) -> None: + x = np.random.randn(10, 3) + y = np.random.randn(10, 3) + y[2] = np.nan + t = pd.Series(pd.date_range("20130101", periods=10)) + t[2] = np.nan + + lat = [77.7, 83.2, 76] + ds = Dataset( + { + "a": (["t", "lat"], x), + "b": (["t", "lat"], y), + "t": ("t", t), + "lat": ("lat", lat), + } + ) + roundtripped = Dataset.from_dict(ds.to_dict(data=data)) + assert_identical(ds, roundtripped) + + def test_to_dict_with_numpy_attrs(self) -> None: + # this doesn't need to roundtrip + x = np.random.randn(10) + y = np.random.randn(10) + t = list("abcdefghij") + attrs = { + "created": np.float64(1998), + "coords": np.array([37, -110.1, 100]), + "maintainer": "bar", + } + ds = Dataset({"a": ("t", x, attrs), "b": ("t", y, attrs), "t": ("t", t)}) + expected_attrs = { + "created": attrs["created"].item(), # type: ignore[attr-defined] + "coords": attrs["coords"].tolist(), # type: ignore[attr-defined] + "maintainer": "bar", + } + actual = ds.to_dict() + + # check that they are identical + assert expected_attrs == actual["data_vars"]["a"]["attrs"] + + def test_pickle(self) -> None: + data = create_test_data() + roundtripped = pickle.loads(pickle.dumps(data)) + assert_identical(data, roundtripped) + # regression test for #167: + assert data.sizes == roundtripped.sizes + + def test_lazy_load(self) -> None: + store = InaccessibleVariableDataStore() + create_test_data().dump_to_store(store) + + for decode_cf in [True, False]: + ds = open_dataset(store, decode_cf=decode_cf) + with pytest.raises(UnexpectedDataAccess): + ds.load() + with pytest.raises(UnexpectedDataAccess): + ds["var1"].values + + # these should not raise UnexpectedDataAccess: + ds.isel(time=10) + ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) + + def test_lazy_load_duck_array(self) -> None: + store = AccessibleAsDuckArrayDataStore() + create_test_data().dump_to_store(store) + + for decode_cf in [True, False]: + ds = open_dataset(store, decode_cf=decode_cf) + with pytest.raises(UnexpectedDataAccess): + ds["var1"].values + + # these should not raise UnexpectedDataAccess: + ds.var1.data + ds.isel(time=10) + ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) + repr(ds) + + # preserve the duck array type and don't cast to array + assert isinstance(ds["var1"].load().data, DuckArrayWrapper) + assert isinstance( + ds["var1"].isel(dim2=0, dim1=0).load().data, DuckArrayWrapper + ) + + ds.close() + + def test_dropna(self) -> None: + x = np.random.randn(4, 4) + x[::2, 0] = np.nan + y = np.random.randn(4) + y[-1] = np.nan + ds = Dataset({"foo": (("a", "b"), x), "bar": (("b", y))}) + + expected = ds.isel(a=slice(1, None, 2)) + actual = ds.dropna("a") + assert_identical(actual, expected) + + expected = ds.isel(b=slice(1, 3)) + actual = ds.dropna("b") + assert_identical(actual, expected) + + actual = ds.dropna("b", subset=["foo", "bar"]) + assert_identical(actual, expected) + + expected = ds.isel(b=slice(1, None)) + actual = ds.dropna("b", subset=["foo"]) + assert_identical(actual, expected) + + expected = ds.isel(b=slice(3)) + actual = ds.dropna("b", subset=["bar"]) + assert_identical(actual, expected) + + actual = ds.dropna("a", subset=[]) + assert_identical(actual, ds) + + actual = ds.dropna("a", subset=["bar"]) + assert_identical(actual, ds) + + actual = ds.dropna("a", how="all") + assert_identical(actual, ds) + + actual = ds.dropna("b", how="all", subset=["bar"]) + expected = ds.isel(b=[0, 1, 2]) + assert_identical(actual, expected) + + actual = ds.dropna("b", thresh=1, subset=["bar"]) + assert_identical(actual, expected) + + actual = ds.dropna("b", thresh=2) + assert_identical(actual, ds) + + actual = ds.dropna("b", thresh=4) + expected = ds.isel(b=[1, 2, 3]) + assert_identical(actual, expected) + + actual = ds.dropna("a", thresh=3) + expected = ds.isel(a=[1, 3]) + assert_identical(actual, ds) + + with pytest.raises( + ValueError, + match=r"'foo' not found in data dimensions \('a', 'b'\)", + ): + ds.dropna("foo") + with pytest.raises(ValueError, match=r"invalid how"): + ds.dropna("a", how="somehow") # type: ignore[arg-type] + with pytest.raises(TypeError, match=r"must specify how or thresh"): + ds.dropna("a", how=None) # type: ignore[arg-type] + + def test_fillna(self) -> None: + ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) + + # fill with -1 + actual1 = ds.fillna(-1) + expected = Dataset({"a": ("x", [-1, 1, -1, 3])}, {"x": [0, 1, 2, 3]}) + assert_identical(expected, actual1) + + actual2 = ds.fillna({"a": -1}) + assert_identical(expected, actual2) + + other = Dataset({"a": -1}) + actual3 = ds.fillna(other) + assert_identical(expected, actual3) + + actual4 = ds.fillna({"a": other.a}) + assert_identical(expected, actual4) + + # fill with range(4) + b = DataArray(range(4), coords=[("x", range(4))]) + actual5 = ds.fillna(b) + expected = b.rename("a").to_dataset() + assert_identical(expected, actual5) + + actual6 = ds.fillna(expected) + assert_identical(expected, actual6) + + actual7 = ds.fillna(np.arange(4)) + assert_identical(expected, actual7) + + actual8 = ds.fillna(b[:3]) + assert_identical(expected, actual8) + + # okay to only include some data variables + ds["b"] = np.nan + actual9 = ds.fillna({"a": -1}) + expected = Dataset( + {"a": ("x", [-1, 1, -1, 3]), "b": np.nan}, {"x": [0, 1, 2, 3]} + ) + assert_identical(expected, actual9) + + # but new data variables is not okay + with pytest.raises(ValueError, match=r"must be contained"): + ds.fillna({"x": 0}) + + # empty argument should be OK + result1 = ds.fillna({}) + assert_identical(ds, result1) + + result2 = ds.fillna(Dataset(coords={"c": 42})) + expected = ds.assign_coords(c=42) + assert_identical(expected, result2) + + da = DataArray(range(5), name="a", attrs={"attr": "da"}) + actual10 = da.fillna(1) + assert actual10.name == "a" + assert actual10.attrs == da.attrs + + ds = Dataset({"a": da}, attrs={"attr": "ds"}) + actual11 = ds.fillna({"a": 1}) + assert actual11.attrs == ds.attrs + assert actual11.a.name == "a" + assert actual11.a.attrs == ds.a.attrs + + @pytest.mark.parametrize( + "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] + ) + def test_propagate_attrs(self, func) -> None: + da = DataArray(range(5), name="a", attrs={"attr": "da"}) + ds = Dataset({"a": da}, attrs={"attr": "ds"}) + + # test defaults + assert func(ds).attrs == ds.attrs + with set_options(keep_attrs=False): + assert func(ds).attrs != ds.attrs + assert func(ds).a.attrs != ds.a.attrs + + with set_options(keep_attrs=False): + assert func(ds).attrs != ds.attrs + assert func(ds).a.attrs != ds.a.attrs + + with set_options(keep_attrs=True): + assert func(ds).attrs == ds.attrs + assert func(ds).a.attrs == ds.a.attrs + + def test_where(self) -> None: + ds = Dataset({"a": ("x", range(5))}) + expected1 = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])}) + actual1 = ds.where(ds > 1) + assert_identical(expected1, actual1) + + actual2 = ds.where(ds.a > 1) + assert_identical(expected1, actual2) + + actual3 = ds.where(ds.a.values > 1) + assert_identical(expected1, actual3) + + actual4 = ds.where(True) + assert_identical(ds, actual4) + + expected5 = ds.copy(deep=True) + expected5["a"].values = np.array([np.nan] * 5) + actual5 = ds.where(False) + assert_identical(expected5, actual5) + + # 2d + ds = Dataset({"a": (("x", "y"), [[0, 1], [2, 3]])}) + expected6 = Dataset({"a": (("x", "y"), [[np.nan, 1], [2, 3]])}) + actual6 = ds.where(ds > 0) + assert_identical(expected6, actual6) + + # attrs + da = DataArray(range(5), name="a", attrs={"attr": "da"}) + actual7 = da.where(da.values > 1) + assert actual7.name == "a" + assert actual7.attrs == da.attrs + + ds = Dataset({"a": da}, attrs={"attr": "ds"}) + actual8 = ds.where(ds > 0) + assert actual8.attrs == ds.attrs + assert actual8.a.name == "a" + assert actual8.a.attrs == ds.a.attrs + + # lambda + ds = Dataset({"a": ("x", range(5))}) + expected9 = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])}) + actual9 = ds.where(lambda x: x > 1) + assert_identical(expected9, actual9) + + def test_where_other(self) -> None: + ds = Dataset({"a": ("x", range(5))}, {"x": range(5)}) + expected = Dataset({"a": ("x", [-1, -1, 2, 3, 4])}, {"x": range(5)}) + actual = ds.where(ds > 1, -1) + assert_equal(expected, actual) + assert actual.a.dtype == int + + actual = ds.where(lambda x: x > 1, -1) + assert_equal(expected, actual) + + actual = ds.where(ds > 1, other=-1, drop=True) + expected_nodrop = ds.where(ds > 1, -1) + _, expected = xr.align(actual, expected_nodrop, join="left") + assert_equal(actual, expected) + assert actual.a.dtype == int + + with pytest.raises(ValueError, match=r"cannot align .* are not equal"): + ds.where(ds > 1, ds.isel(x=slice(3))) + + with pytest.raises(ValueError, match=r"exact match required"): + ds.where(ds > 1, ds.assign(b=2)) + + def test_where_drop(self) -> None: + # if drop=True + + # 1d + # data array case + array = DataArray(range(5), coords=[range(5)], dims=["x"]) + expected1 = DataArray(range(5)[2:], coords=[range(5)[2:]], dims=["x"]) + actual1 = array.where(array > 1, drop=True) + assert_identical(expected1, actual1) + + # dataset case + ds = Dataset({"a": array}) + expected2 = Dataset({"a": expected1}) + + actual2 = ds.where(ds > 1, drop=True) + assert_identical(expected2, actual2) + + actual3 = ds.where(ds.a > 1, drop=True) + assert_identical(expected2, actual3) + + with pytest.raises(TypeError, match=r"must be a"): + ds.where(np.arange(5) > 1, drop=True) + + # 1d with odd coordinates + array = DataArray( + np.array([2, 7, 1, 8, 3]), coords=[np.array([3, 1, 4, 5, 9])], dims=["x"] + ) + expected4 = DataArray( + np.array([7, 8, 3]), coords=[np.array([1, 5, 9])], dims=["x"] + ) + actual4 = array.where(array > 2, drop=True) + assert_identical(expected4, actual4) + + # 1d multiple variables + ds = Dataset({"a": (("x"), [0, 1, 2, 3]), "b": (("x"), [4, 5, 6, 7])}) + expected5 = Dataset( + {"a": (("x"), [np.nan, 1, 2, 3]), "b": (("x"), [4, 5, 6, np.nan])} + ) + actual5 = ds.where((ds > 0) & (ds < 7), drop=True) + assert_identical(expected5, actual5) + + # 2d + ds = Dataset({"a": (("x", "y"), [[0, 1], [2, 3]])}) + expected6 = Dataset({"a": (("x", "y"), [[np.nan, 1], [2, 3]])}) + actual6 = ds.where(ds > 0, drop=True) + assert_identical(expected6, actual6) + + # 2d with odd coordinates + ds = Dataset( + {"a": (("x", "y"), [[0, 1], [2, 3]])}, + coords={ + "x": [4, 3], + "y": [1, 2], + "z": (["x", "y"], [[np.e, np.pi], [np.pi * np.e, np.pi * 3]]), + }, + ) + expected7 = Dataset( + {"a": (("x", "y"), [[3]])}, + coords={"x": [3], "y": [2], "z": (["x", "y"], [[np.pi * 3]])}, + ) + actual7 = ds.where(ds > 2, drop=True) + assert_identical(expected7, actual7) + + # 2d multiple variables + ds = Dataset( + {"a": (("x", "y"), [[0, 1], [2, 3]]), "b": (("x", "y"), [[4, 5], [6, 7]])} + ) + expected8 = Dataset( + { + "a": (("x", "y"), [[np.nan, 1], [2, 3]]), + "b": (("x", "y"), [[4, 5], [6, 7]]), + } + ) + actual8 = ds.where(ds > 0, drop=True) + assert_identical(expected8, actual8) + + # mixed dimensions: PR#6690, Issue#6227 + ds = xr.Dataset( + { + "a": ("x", [1, 2, 3]), + "b": ("y", [2, 3, 4]), + "c": (("x", "y"), np.arange(9).reshape((3, 3))), + } + ) + expected9 = xr.Dataset( + { + "a": ("x", [np.nan, 3]), + "b": ("y", [np.nan, 3, 4]), + "c": (("x", "y"), np.arange(3.0, 9.0).reshape((2, 3))), + } + ) + actual9 = ds.where(ds > 2, drop=True) + assert actual9.sizes["x"] == 2 + assert_identical(expected9, actual9) + + def test_where_drop_empty(self) -> None: + # regression test for GH1341 + array = DataArray(np.random.rand(100, 10), dims=["nCells", "nVertLevels"]) + mask = DataArray(np.zeros((100,), dtype="bool"), dims="nCells") + actual = array.where(mask, drop=True) + expected = DataArray(np.zeros((0, 10)), dims=["nCells", "nVertLevels"]) + assert_identical(expected, actual) + + def test_where_drop_no_indexes(self) -> None: + ds = Dataset({"foo": ("x", [0.0, 1.0])}) + expected = Dataset({"foo": ("x", [1.0])}) + actual = ds.where(ds == 1, drop=True) + assert_identical(expected, actual) + + def test_reduce(self) -> None: + data = create_test_data() + + assert len(data.mean().coords) == 0 + + actual = data.max() + expected = Dataset({k: v.max() for k, v in data.data_vars.items()}) + assert_equal(expected, actual) + + assert_equal(data.min(dim=["dim1"]), data.min(dim="dim1")) + + for reduct, expected_dims in [ + ("dim2", ["dim3", "time", "dim1"]), + (["dim2", "time"], ["dim3", "dim1"]), + (("dim2", "time"), ["dim3", "dim1"]), + ((), ["dim2", "dim3", "time", "dim1"]), + ]: + actual_dims = list(data.min(dim=reduct).dims) + assert actual_dims == expected_dims + + assert_equal(data.mean(dim=[]), data) + + with pytest.raises(ValueError): + data.mean(axis=0) + + def test_reduce_coords(self) -> None: + # regression test for GH1470 + data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4}) + expected = xr.Dataset({"a": 2}, coords={"b": 4}) + actual = data.mean("x") + assert_identical(actual, expected) + + # should be consistent + actual = data["a"].mean("x").to_dataset() + assert_identical(actual, expected) + + def test_mean_uint_dtype(self) -> None: + data = xr.Dataset( + { + "a": (("x", "y"), np.arange(6).reshape(3, 2).astype("uint")), + "b": (("x",), np.array([0.1, 0.2, np.nan])), + } + ) + actual = data.mean("x", skipna=True) + expected = xr.Dataset( + {"a": data["a"].mean("x"), "b": data["b"].mean("x", skipna=True)} + ) + assert_identical(actual, expected) + + def test_reduce_bad_dim(self) -> None: + data = create_test_data() + with pytest.raises( + ValueError, + match=r"Dimensions \('bad_dim',\) not found in data dimensions", + ): + data.mean(dim="bad_dim") + + def test_reduce_cumsum(self) -> None: + data = xr.Dataset( + {"a": 1, "b": ("x", [1, 2]), "c": (("x", "y"), [[np.nan, 3], [0, 4]])} + ) + assert_identical(data.fillna(0), data.cumsum("y")) + + expected = xr.Dataset( + {"a": 1, "b": ("x", [1, 3]), "c": (("x", "y"), [[0, 3], [0, 7]])} + ) + assert_identical(expected, data.cumsum()) + + @pytest.mark.parametrize( + "reduct, expected", + [ + ("dim1", ["dim2", "dim3", "time", "dim1"]), + ("dim2", ["dim3", "time", "dim1", "dim2"]), + ("dim3", ["dim2", "time", "dim1", "dim3"]), + ("time", ["dim2", "dim3", "dim1"]), + ], + ) + @pytest.mark.parametrize("func", ["cumsum", "cumprod"]) + def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: + data = create_test_data() + with pytest.raises( + ValueError, + match=r"Dimensions \('bad_dim',\) not found in data dimensions", + ): + getattr(data, func)(dim="bad_dim") + + # ensure dimensions are correct + actual = getattr(data, func)(dim=reduct).dims + assert list(actual) == expected + + def test_reduce_non_numeric(self) -> None: + data1 = create_test_data(seed=44, use_extension_array=True) + data2 = create_test_data(seed=44) + add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]} + for v, dims in sorted(add_vars.items()): + size = tuple(data1.sizes[d] for d in dims) + data = np.random.randint(0, 100, size=size).astype(np.str_) + data1[v] = (dims, data, {"foo": "variable"}) + # var4 is extension array categorical and should be dropped + assert ( + "var4" not in data1.mean() + and "var5" not in data1.mean() + and "var6" not in data1.mean() + ) + assert_equal(data1.mean(), data2.mean()) + assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) + assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") + + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of DataArray:DeprecationWarning" + ) + def test_reduce_strings(self) -> None: + expected = Dataset({"x": "a"}) + ds = Dataset({"x": ("y", ["a", "b"])}) + ds.coords["y"] = [-10, 10] + actual = ds.min() + assert_identical(expected, actual) + + expected = Dataset({"x": "b"}) + actual = ds.max() + assert_identical(expected, actual) + + expected = Dataset({"x": 0}) + actual = ds.argmin() + assert_identical(expected, actual) + + expected = Dataset({"x": 1}) + actual = ds.argmax() + assert_identical(expected, actual) + + expected = Dataset({"x": -10}) + actual = ds.idxmin() + assert_identical(expected, actual) + + expected = Dataset({"x": 10}) + actual = ds.idxmax() + assert_identical(expected, actual) + + expected = Dataset({"x": b"a"}) + ds = Dataset({"x": ("y", np.array(["a", "b"], "S1"))}) + actual = ds.min() + assert_identical(expected, actual) + + expected = Dataset({"x": "a"}) + ds = Dataset({"x": ("y", np.array(["a", "b"], "U1"))}) + actual = ds.min() + assert_identical(expected, actual) + + def test_reduce_dtypes(self) -> None: + # regression test for GH342 + expected = Dataset({"x": 1}) + actual = Dataset({"x": True}).sum() + assert_identical(expected, actual) + + # regression test for GH505 + expected = Dataset({"x": 3}) + actual = Dataset({"x": ("y", np.array([1, 2], "uint16"))}).sum() + assert_identical(expected, actual) + + expected = Dataset({"x": 1 + 1j}) + actual = Dataset({"x": ("y", [1, 1j])}).sum() + assert_identical(expected, actual) + + def test_reduce_keep_attrs(self) -> None: + data = create_test_data() + _attrs = {"attr1": "value1", "attr2": 2929} + + attrs = dict(_attrs) + data.attrs = attrs + + # Test dropped attrs + ds = data.mean() + assert ds.attrs == {} + for v in ds.data_vars.values(): + assert v.attrs == {} + + # Test kept attrs + ds = data.mean(keep_attrs=True) + assert ds.attrs == attrs + for k, v in ds.data_vars.items(): + assert v.attrs == data[k].attrs + + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of DataArray:DeprecationWarning" + ) + def test_reduce_argmin(self) -> None: + # regression test for #205 + ds = Dataset({"a": ("x", [0, 1])}) + expected = Dataset({"a": ([], 0)}) + actual = ds.argmin() + assert_identical(expected, actual) + + actual = ds.argmin("x") + assert_identical(expected, actual) + + def test_reduce_scalars(self) -> None: + ds = Dataset({"x": ("a", [2, 2]), "y": 2, "z": ("b", [2])}) + expected = Dataset({"x": 0, "y": 0, "z": 0}) + actual = ds.var() + assert_identical(expected, actual) + + expected = Dataset({"x": 0, "y": 0, "z": ("b", [0])}) + actual = ds.var("a") + assert_identical(expected, actual) + + def test_reduce_only_one_axis(self) -> None: + def mean_only_one_axis(x, axis): + if not isinstance(axis, integer_types): + raise TypeError("non-integer axis") + return x.mean(axis) + + ds = Dataset({"a": (["x", "y"], [[0, 1, 2, 3, 4]])}) + expected = Dataset({"a": ("x", [2])}) + actual = ds.reduce(mean_only_one_axis, "y") + assert_identical(expected, actual) + + with pytest.raises( + TypeError, match=r"missing 1 required positional argument: 'axis'" + ): + ds.reduce(mean_only_one_axis) + + def test_reduce_no_axis(self) -> None: + def total_sum(x): + return np.sum(x.flatten()) + + ds = Dataset({"a": (["x", "y"], [[0, 1, 2, 3, 4]])}) + expected = Dataset({"a": ((), 10)}) + actual = ds.reduce(total_sum) + assert_identical(expected, actual) + + with pytest.raises(TypeError, match=r"unexpected keyword argument 'axis'"): + ds.reduce(total_sum, dim="x") + + def test_reduce_keepdims(self) -> None: + ds = Dataset( + {"a": (["x", "y"], [[0, 1, 2, 3, 4]])}, + coords={ + "y": [0, 1, 2, 3, 4], + "x": [0], + "lat": (["x", "y"], [[0, 1, 2, 3, 4]]), + "c": -999.0, + }, + ) + + # Shape should match behaviour of numpy reductions with keepdims=True + # Coordinates involved in the reduction should be removed + actual = ds.mean(keepdims=True) + expected = Dataset( + {"a": (["x", "y"], np.mean(ds.a, keepdims=True).data)}, coords={"c": ds.c} + ) + assert_identical(expected, actual) + + actual = ds.mean("x", keepdims=True) + expected = Dataset( + {"a": (["x", "y"], np.mean(ds.a, axis=0, keepdims=True).data)}, + coords={"y": ds.y, "c": ds.c}, + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize("skipna", [True, False, None]) + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + def test_quantile(self, q, skipna, compute_backend) -> None: + ds = create_test_data(seed=123) + ds.var1.data[0, 0] = np.nan + + for dim in [None, "dim1", ["dim1"]]: + ds_quantile = ds.quantile(q, dim=dim, skipna=skipna) + if is_scalar(q): + assert "quantile" not in ds_quantile.dims + else: + assert "quantile" in ds_quantile.dims + + for var, dar in ds.data_vars.items(): + assert var in ds_quantile + assert_identical( + ds_quantile[var], dar.quantile(q, dim=dim, skipna=skipna) + ) + dim = ["dim1", "dim2"] + ds_quantile = ds.quantile(q, dim=dim, skipna=skipna) + assert "dim3" in ds_quantile.dims + assert all(d not in ds_quantile.dims for d in dim) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize("skipna", [True, False]) + def test_quantile_skipna(self, skipna, compute_backend) -> None: + q = 0.1 + dim = "time" + ds = Dataset({"a": ([dim], np.arange(0, 11))}) + ds = ds.where(ds >= 1) + + result = ds.quantile(q=q, dim=dim, skipna=skipna) + + value = 1.9 if skipna else np.nan + expected = Dataset({"a": value}, coords={"quantile": q}) + assert_identical(result, expected) + + @pytest.mark.parametrize("method", ["midpoint", "lower"]) + def test_quantile_method(self, method) -> None: + ds = create_test_data(seed=123) + q = [0.25, 0.5, 0.75] + + result = ds.quantile(q, method=method) + + assert_identical(result.var1, ds.var1.quantile(q, method=method)) + assert_identical(result.var2, ds.var2.quantile(q, method=method)) + assert_identical(result.var3, ds.var3.quantile(q, method=method)) + + @pytest.mark.parametrize("method", ["midpoint", "lower"]) + def test_quantile_interpolation_deprecated(self, method) -> None: + ds = create_test_data(seed=123) + q = [0.25, 0.5, 0.75] + + with warnings.catch_warnings(record=True) as w: + ds.quantile(q, interpolation=method) + + # ensure the warning is only raised once + assert len(w) == 1 + + with warnings.catch_warnings(record=True): + with pytest.raises(TypeError, match="interpolation and method keywords"): + ds.quantile(q, method=method, interpolation=method) + + @requires_bottleneck + def test_rank(self) -> None: + ds = create_test_data(seed=1234) + # only ds.var3 depends on dim3 + z = ds.rank("dim3") + assert ["var3"] == list(z.data_vars) + # same as dataarray version + x = z.var3 + y = ds.var3.rank("dim3") + assert_equal(x, y) + # coordinates stick + assert list(z.coords) == list(ds.coords) + assert list(x.coords) == list(y.coords) + # invalid dim + with pytest.raises( + ValueError, + match=re.escape( + "Dimension 'invalid_dim' not found in data dimensions ('dim3', 'dim1')" + ), + ): + x.rank("invalid_dim") + + def test_rank_use_bottleneck(self) -> None: + ds = Dataset({"a": ("x", [0, np.nan, 2]), "b": ("y", [4, 6, 3, 4])}) + with xr.set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + ds.rank("x") + + def test_count(self) -> None: + ds = Dataset({"x": ("a", [np.nan, 1]), "y": 0, "z": np.nan}) + expected = Dataset({"x": 1, "y": 1, "z": 0}) + actual = ds.count() + assert_identical(expected, actual) + + def test_map(self) -> None: + data = create_test_data() + data.attrs["foo"] = "bar" + + assert_identical(data.map(np.mean), data.mean()) + + expected = data.mean(keep_attrs=True) + actual = data.map(lambda x: x.mean(keep_attrs=True), keep_attrs=True) + assert_identical(expected, actual) + + assert_identical(data.map(lambda x: x, keep_attrs=True), data.drop_vars("time")) + + def scale(x, multiple=1): + return multiple * x + + actual = data.map(scale, multiple=2) + assert_equal(actual["var1"], 2 * data["var1"]) + assert_identical(actual["numbers"], data["numbers"]) + + actual = data.map(np.asarray) + expected = data.drop_vars("time") # time is not used on a data var + assert_equal(expected, actual) + + def test_apply_pending_deprecated_map(self) -> None: + data = create_test_data() + data.attrs["foo"] = "bar" + + with pytest.warns(PendingDeprecationWarning): + assert_identical(data.apply(np.mean), data.mean()) + + def make_example_math_dataset(self): + variables = { + "bar": ("x", np.arange(100, 400, 100)), + "foo": (("x", "y"), 1.0 * np.arange(12).reshape(3, 4)), + } + coords = {"abc": ("x", ["a", "b", "c"]), "y": 10 * np.arange(4)} + ds = Dataset(variables, coords) + ds["foo"][0, 0] = np.nan + return ds + + def test_dataset_number_math(self) -> None: + ds = self.make_example_math_dataset() + + assert_identical(ds, +ds) + assert_identical(ds, ds + 0) + assert_identical(ds, 0 + ds) + assert_identical(ds, ds + np.array(0)) + assert_identical(ds, np.array(0) + ds) + + actual = ds.copy(deep=True) + actual += 0 + assert_identical(ds, actual) + + def test_unary_ops(self) -> None: + ds = self.make_example_math_dataset() + + assert_identical(ds.map(abs), abs(ds)) + assert_identical(ds.map(lambda x: x + 4), ds + 4) + + for func in [ + lambda x: x.isnull(), + lambda x: x.round(), + lambda x: x.astype(int), + ]: + assert_identical(ds.map(func), func(ds)) + + assert_identical(ds.isnull(), ~ds.notnull()) + + # don't actually patch these methods in + with pytest.raises(AttributeError): + ds.item + with pytest.raises(AttributeError): + ds.searchsorted + + def test_dataset_array_math(self) -> None: + ds = self.make_example_math_dataset() + + expected = ds.map(lambda x: x - ds["foo"]) + assert_identical(expected, ds - ds["foo"]) + assert_identical(expected, -ds["foo"] + ds) + assert_identical(expected, ds - ds["foo"].variable) + assert_identical(expected, -ds["foo"].variable + ds) + actual = ds.copy(deep=True) + actual -= ds["foo"] + assert_identical(expected, actual) + + expected = ds.map(lambda x: x + ds["bar"]) + assert_identical(expected, ds + ds["bar"]) + actual = ds.copy(deep=True) + actual += ds["bar"] + assert_identical(expected, actual) + + expected = Dataset({"bar": ds["bar"] + np.arange(3)}) + assert_identical(expected, ds[["bar"]] + np.arange(3)) + assert_identical(expected, np.arange(3) + ds[["bar"]]) + + def test_dataset_dataset_math(self) -> None: + ds = self.make_example_math_dataset() + + assert_identical(ds, ds + 0 * ds) + assert_identical(ds, ds + {"foo": 0, "bar": 0}) + + expected = ds.map(lambda x: 2 * x) + assert_identical(expected, 2 * ds) + assert_identical(expected, ds + ds) + assert_identical(expected, ds + ds.data_vars) + assert_identical(expected, ds + dict(ds.data_vars)) + + actual = ds.copy(deep=True) + expected_id = id(actual) + actual += ds + assert_identical(expected, actual) + assert expected_id == id(actual) + + assert_identical(ds == ds, ds.notnull()) + + subsampled = ds.isel(y=slice(2)) + expected = 2 * subsampled + assert_identical(expected, subsampled + ds) + assert_identical(expected, ds + subsampled) + + def test_dataset_math_auto_align(self) -> None: + ds = self.make_example_math_dataset() + subset = ds.isel(y=[1, 3]) + expected = 2 * subset + actual = ds + subset + assert_identical(expected, actual) + + actual = ds.isel(y=slice(1)) + ds.isel(y=slice(1, None)) + expected = 2 * ds.drop_sel(y=ds.y) + assert_equal(actual, expected) + + actual = ds + ds[["bar"]] + expected = (2 * ds[["bar"]]).merge(ds.coords) + assert_identical(expected, actual) + + assert_identical(ds + Dataset(), ds.coords.to_dataset()) + assert_identical(Dataset() + Dataset(), Dataset()) + + ds2 = Dataset(coords={"bar": 42}) + assert_identical(ds + ds2, ds.coords.merge(ds2)) + + # maybe unary arithmetic with empty datasets should raise instead? + assert_identical(Dataset() + 1, Dataset()) + + actual = ds.copy(deep=True) + other = ds.isel(y=slice(2)) + actual += other + expected = ds + other.reindex_like(ds) + assert_identical(expected, actual) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_dataset_math_errors(self) -> None: + ds = self.make_example_math_dataset() + + with pytest.raises(TypeError): + ds["foo"] += ds + with pytest.raises(TypeError): + ds["foo"].variable += ds + with pytest.raises(ValueError, match=r"must have the same"): + ds += ds[["bar"]] + + # verify we can rollback in-place operations if something goes wrong + # nb. inplace datetime64 math actually will work with an integer array + # but not floats thanks to numpy's inconsistent handling + other = DataArray(np.datetime64("2000-01-01"), coords={"c": 2}) + actual = ds.copy(deep=True) + with pytest.raises(TypeError): + actual += other + assert_identical(actual, ds) + + def test_dataset_transpose(self) -> None: + ds = Dataset( + { + "a": (("x", "y"), np.random.randn(3, 4)), + "b": (("y", "x"), np.random.randn(4, 3)), + }, + coords={ + "x": range(3), + "y": range(4), + "xy": (("x", "y"), np.random.randn(3, 4)), + }, + ) + + actual = ds.transpose() + expected = Dataset( + {"a": (("y", "x"), ds.a.values.T), "b": (("x", "y"), ds.b.values.T)}, + coords={ + "x": ds.x.values, + "y": ds.y.values, + "xy": (("y", "x"), ds.xy.values.T), + }, + ) + assert_identical(expected, actual) + + actual = ds.transpose(...) + expected = ds + assert_identical(expected, actual) + + actual = ds.transpose("x", "y") + expected = ds.map(lambda x: x.transpose("x", "y", transpose_coords=True)) + assert_identical(expected, actual) + + ds = create_test_data() + actual = ds.transpose() + for k in ds.variables: + assert actual[k].dims[::-1] == ds[k].dims + + new_order = ("dim2", "dim3", "dim1", "time") + actual = ds.transpose(*new_order) + for k in ds.variables: + expected_dims = tuple(d for d in new_order if d in ds[k].dims) + assert actual[k].dims == expected_dims + + # same as above but with ellipsis + new_order = ("dim2", "dim3", "dim1", "time") + actual = ds.transpose("dim2", "dim3", ...) + for k in ds.variables: + expected_dims = tuple(d for d in new_order if d in ds[k].dims) + assert actual[k].dims == expected_dims + + # test missing dimension, raise error + with pytest.raises(ValueError): + ds.transpose(..., "not_a_dim") + + # test missing dimension, ignore error + actual = ds.transpose(..., "not_a_dim", missing_dims="ignore") + expected_ell = ds.transpose(...) + assert_identical(expected_ell, actual) + + # test missing dimension, raise warning + with pytest.warns(UserWarning): + actual = ds.transpose(..., "not_a_dim", missing_dims="warn") + assert_identical(expected_ell, actual) + + assert "T" not in dir(ds) + + def test_dataset_ellipsis_transpose_different_ordered_vars(self) -> None: + # https://github.com/pydata/xarray/issues/1081#issuecomment-544350457 + ds = Dataset( + dict( + a=(("w", "x", "y", "z"), np.ones((2, 3, 4, 5))), + b=(("x", "w", "y", "z"), np.zeros((3, 2, 4, 5))), + ) + ) + result = ds.transpose(..., "z", "y") + assert list(result["a"].dims) == list("wxzy") + assert list(result["b"].dims) == list("xwzy") + + def test_dataset_retains_period_index_on_transpose(self) -> None: + ds = create_test_data() + ds["time"] = pd.period_range("2000-01-01", periods=20) + + transposed = ds.transpose() + + assert isinstance(transposed.time.to_index(), pd.PeriodIndex) + + def test_dataset_diff_n1_simple(self) -> None: + ds = Dataset({"foo": ("x", [5, 5, 6, 6])}) + actual = ds.diff("x") + expected = Dataset({"foo": ("x", [0, 1, 0])}) + assert_equal(expected, actual) + + def test_dataset_diff_n1_label(self) -> None: + ds = Dataset({"foo": ("x", [5, 5, 6, 6])}, {"x": [0, 1, 2, 3]}) + actual = ds.diff("x", label="lower") + expected = Dataset({"foo": ("x", [0, 1, 0])}, {"x": [0, 1, 2]}) + assert_equal(expected, actual) + + actual = ds.diff("x", label="upper") + expected = Dataset({"foo": ("x", [0, 1, 0])}, {"x": [1, 2, 3]}) + assert_equal(expected, actual) + + def test_dataset_diff_n1(self) -> None: + ds = create_test_data(seed=1) + actual = ds.diff("dim2") + expected_dict = {} + expected_dict["var1"] = DataArray( + np.diff(ds["var1"].values, axis=1), + {"dim2": ds["dim2"].values[1:]}, + ["dim1", "dim2"], + ) + expected_dict["var2"] = DataArray( + np.diff(ds["var2"].values, axis=1), + {"dim2": ds["dim2"].values[1:]}, + ["dim1", "dim2"], + ) + expected_dict["var3"] = ds["var3"] + expected = Dataset(expected_dict, coords={"time": ds["time"].values}) + expected.coords["numbers"] = ("dim3", ds["numbers"].values) + assert_equal(expected, actual) + + def test_dataset_diff_n2(self) -> None: + ds = create_test_data(seed=1) + actual = ds.diff("dim2", n=2) + expected_dict = {} + expected_dict["var1"] = DataArray( + np.diff(ds["var1"].values, axis=1, n=2), + {"dim2": ds["dim2"].values[2:]}, + ["dim1", "dim2"], + ) + expected_dict["var2"] = DataArray( + np.diff(ds["var2"].values, axis=1, n=2), + {"dim2": ds["dim2"].values[2:]}, + ["dim1", "dim2"], + ) + expected_dict["var3"] = ds["var3"] + expected = Dataset(expected_dict, coords={"time": ds["time"].values}) + expected.coords["numbers"] = ("dim3", ds["numbers"].values) + assert_equal(expected, actual) + + def test_dataset_diff_exception_n_neg(self) -> None: + ds = create_test_data(seed=1) + with pytest.raises(ValueError, match=r"must be non-negative"): + ds.diff("dim2", n=-1) + + def test_dataset_diff_exception_label_str(self) -> None: + ds = create_test_data(seed=1) + with pytest.raises(ValueError, match=r"'label' argument has to"): + ds.diff("dim2", label="raise_me") # type: ignore[arg-type] + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": -10}]) + def test_shift(self, fill_value) -> None: + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) + actual = ds.shift(x=1, fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + elif isinstance(fill_value, dict): + fill_value = fill_value.get("foo", np.nan) + expected = Dataset({"foo": ("x", [fill_value, 1, 2])}, coords, attrs) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"dimensions"): + ds.shift(foo=123) + + def test_roll_coords(self) -> None: + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) + actual = ds.roll(x=1, roll_coords=True) + + ex_coords = {"bar": ("x", list("cab")), "x": [2, -4, 3]} + expected = Dataset({"foo": ("x", [3, 1, 2])}, ex_coords, attrs) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"dimensions"): + ds.roll(foo=123, roll_coords=True) + + def test_roll_no_coords(self) -> None: + coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} + attrs = {"meta": "data"} + ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) + actual = ds.roll(x=1) + + expected = Dataset({"foo": ("x", [3, 1, 2])}, coords, attrs) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"dimensions"): + ds.roll(abc=321) + + def test_roll_multidim(self) -> None: + # regression test for 2445 + arr = xr.DataArray( + [[1, 2, 3], [4, 5, 6]], + coords={"x": range(3), "y": range(2)}, + dims=("y", "x"), + ) + actual = arr.roll(x=1, roll_coords=True) + expected = xr.DataArray( + [[3, 1, 2], [6, 4, 5]], coords=[("y", [0, 1]), ("x", [2, 0, 1])] + ) + assert_identical(expected, actual) + + def test_real_and_imag(self) -> None: + attrs = {"foo": "bar"} + ds = Dataset({"x": ((), 1 + 2j, attrs)}, attrs=attrs) + + expected_re = Dataset({"x": ((), 1, attrs)}, attrs=attrs) + assert_identical(ds.real, expected_re) + + expected_im = Dataset({"x": ((), 2, attrs)}, attrs=attrs) + assert_identical(ds.imag, expected_im) + + def test_setattr_raises(self) -> None: + ds = Dataset({}, coords={"scalar": 1}, attrs={"foo": "bar"}) + with pytest.raises(AttributeError, match=r"cannot set attr"): + ds.scalar = 2 + with pytest.raises(AttributeError, match=r"cannot set attr"): + ds.foo = 2 + with pytest.raises(AttributeError, match=r"cannot set attr"): + ds.other = 2 + + def test_filter_by_attrs(self) -> None: + precip = dict(standard_name="convective_precipitation_flux") + temp0 = dict(standard_name="air_potential_temperature", height="0 m") + temp10 = dict(standard_name="air_potential_temperature", height="10 m") + ds = Dataset( + { + "temperature_0": (["t"], [0], temp0), + "temperature_10": (["t"], [0], temp10), + "precipitation": (["t"], [0], precip), + }, + coords={"time": (["t"], [0], dict(axis="T", long_name="time_in_seconds"))}, + ) + + # Test return empty Dataset. + ds.filter_by_attrs(standard_name="invalid_standard_name") + new_ds = ds.filter_by_attrs(standard_name="invalid_standard_name") + assert not bool(new_ds.data_vars) + + # Test return one DataArray. + new_ds = ds.filter_by_attrs(standard_name="convective_precipitation_flux") + assert new_ds["precipitation"].standard_name == "convective_precipitation_flux" + + assert_equal(new_ds["precipitation"], ds["precipitation"]) + + # Test filter coordinates + new_ds = ds.filter_by_attrs(long_name="time_in_seconds") + assert new_ds["time"].long_name == "time_in_seconds" + assert not bool(new_ds.data_vars) + + # Test return more than one DataArray. + new_ds = ds.filter_by_attrs(standard_name="air_potential_temperature") + assert len(new_ds.data_vars) == 2 + for var in new_ds.data_vars: + assert new_ds[var].standard_name == "air_potential_temperature" + + # Test callable. + new_ds = ds.filter_by_attrs(height=lambda v: v is not None) + assert len(new_ds.data_vars) == 2 + for var in new_ds.data_vars: + assert new_ds[var].standard_name == "air_potential_temperature" + + new_ds = ds.filter_by_attrs(height="10 m") + assert len(new_ds.data_vars) == 1 + for var in new_ds.data_vars: + assert new_ds[var].height == "10 m" + + # Test return empty Dataset due to conflicting filters + new_ds = ds.filter_by_attrs( + standard_name="convective_precipitation_flux", height="0 m" + ) + assert not bool(new_ds.data_vars) + + # Test return one DataArray with two filter conditions + new_ds = ds.filter_by_attrs( + standard_name="air_potential_temperature", height="0 m" + ) + for var in new_ds.data_vars: + assert new_ds[var].standard_name == "air_potential_temperature" + assert new_ds[var].height == "0 m" + assert new_ds[var].height != "10 m" + + # Test return empty Dataset due to conflicting callables + new_ds = ds.filter_by_attrs( + standard_name=lambda v: False, height=lambda v: True + ) + assert not bool(new_ds.data_vars) + + def test_binary_op_propagate_indexes(self) -> None: + ds = Dataset( + {"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})} + ) + expected = ds.xindexes["x"] + actual = (ds * 2).xindexes["x"] + assert expected is actual + + def test_binary_op_join_setting(self) -> None: + # arithmetic_join applies to data array coordinates + missing_2 = xr.Dataset({"x": [0, 1]}) + missing_0 = xr.Dataset({"x": [1, 2]}) + with xr.set_options(arithmetic_join="outer"): + actual = missing_2 + missing_0 + expected = xr.Dataset({"x": [0, 1, 2]}) + assert_equal(actual, expected) + + # arithmetic join also applies to data_vars + ds1 = xr.Dataset({"foo": 1, "bar": 2}) + ds2 = xr.Dataset({"bar": 2, "baz": 3}) + expected = xr.Dataset({"bar": 4}) # default is inner joining + actual = ds1 + ds2 + assert_equal(actual, expected) + + with xr.set_options(arithmetic_join="outer"): + expected = xr.Dataset({"foo": np.nan, "bar": 4, "baz": np.nan}) + actual = ds1 + ds2 + assert_equal(actual, expected) + + with xr.set_options(arithmetic_join="left"): + expected = xr.Dataset({"foo": np.nan, "bar": 4}) + actual = ds1 + ds2 + assert_equal(actual, expected) + + with xr.set_options(arithmetic_join="right"): + expected = xr.Dataset({"bar": 4, "baz": np.nan}) + actual = ds1 + ds2 + assert_equal(actual, expected) + + @pytest.mark.parametrize( + ["keep_attrs", "expected"], + ( + pytest.param(False, {}, id="False"), + pytest.param(True, {"foo": "a", "bar": "b"}, id="True"), + ), + ) + def test_binary_ops_keep_attrs(self, keep_attrs, expected) -> None: + ds1 = xr.Dataset({"a": 1}, attrs={"foo": "a", "bar": "b"}) + ds2 = xr.Dataset({"a": 1}, attrs={"foo": "a", "baz": "c"}) + with xr.set_options(keep_attrs=keep_attrs): + ds_result = ds1 + ds2 + + assert ds_result.attrs == expected + + def test_full_like(self) -> None: + # For more thorough tests, see test_variable.py + # Note: testing data_vars with mismatched dtypes + ds = Dataset( + { + "d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]}), + "d2": DataArray([1.1, 2.2, 3.3], dims=["y"]), + }, + attrs={"foo": "bar"}, + ) + actual = full_like(ds, 2) + + expected = ds.copy(deep=True) + # https://github.com/python/mypy/issues/3004 + expected["d1"].values = [2, 2, 2] # type: ignore + expected["d2"].values = [2.0, 2.0, 2.0] # type: ignore + assert expected["d1"].dtype == int + assert expected["d2"].dtype == float + assert_identical(expected, actual) + + # override dtype + actual = full_like(ds, fill_value=True, dtype=bool) + expected = ds.copy(deep=True) + expected["d1"].values = [True, True, True] # type: ignore + expected["d2"].values = [True, True, True] # type: ignore + assert expected["d1"].dtype == bool + assert expected["d2"].dtype == bool + assert_identical(expected, actual) + + # with multiple fill values + actual = full_like(ds, {"d1": 1, "d2": 2.3}) + expected = ds.assign(d1=("x", [1, 1, 1]), d2=("y", [2.3, 2.3, 2.3])) + assert expected["d1"].dtype == int + assert expected["d2"].dtype == float + assert_identical(expected, actual) + + # override multiple dtypes + actual = full_like(ds, fill_value={"d1": 1, "d2": 2.3}, dtype={"d1": bool}) + expected = ds.assign(d1=("x", [True, True, True]), d2=("y", [2.3, 2.3, 2.3])) + assert expected["d1"].dtype == bool + assert expected["d2"].dtype == float + assert_identical(expected, actual) + + def test_combine_first(self) -> None: + dsx0 = DataArray([0, 0], [("x", ["a", "b"])]).to_dataset(name="dsx0") + dsx1 = DataArray([1, 1], [("x", ["b", "c"])]).to_dataset(name="dsx1") + + actual = dsx0.combine_first(dsx1) + expected = Dataset( + {"dsx0": ("x", [0, 0, np.nan]), "dsx1": ("x", [np.nan, 1, 1])}, + coords={"x": ["a", "b", "c"]}, + ) + assert_equal(actual, expected) + assert_equal(actual, xr.merge([dsx0, dsx1])) + + # works just like xr.merge([self, other]) + dsy2 = DataArray([2, 2, 2], [("x", ["b", "c", "d"])]).to_dataset(name="dsy2") + actual = dsx0.combine_first(dsy2) + expected = xr.merge([dsy2, dsx0]) + assert_equal(actual, expected) + + def test_sortby(self) -> None: + ds = Dataset( + { + "A": DataArray( + [[1, 2], [3, 4], [5, 6]], [("x", ["c", "b", "a"]), ("y", [1, 0])] + ), + "B": DataArray([[5, 6], [7, 8], [9, 10]], dims=["x", "y"]), + } + ) + + sorted1d = Dataset( + { + "A": DataArray( + [[5, 6], [3, 4], [1, 2]], [("x", ["a", "b", "c"]), ("y", [1, 0])] + ), + "B": DataArray([[9, 10], [7, 8], [5, 6]], dims=["x", "y"]), + } + ) + + sorted2d = Dataset( + { + "A": DataArray( + [[6, 5], [4, 3], [2, 1]], [("x", ["a", "b", "c"]), ("y", [0, 1])] + ), + "B": DataArray([[10, 9], [8, 7], [6, 5]], dims=["x", "y"]), + } + ) + + expected = sorted1d + dax = DataArray([100, 99, 98], [("x", ["c", "b", "a"])]) + actual = ds.sortby(dax) + assert_equal(actual, expected) + + # test descending order sort + actual = ds.sortby(dax, ascending=False) + assert_equal(actual, ds) + + # test alignment (fills in nan for 'c') + dax_short = DataArray([98, 97], [("x", ["b", "a"])]) + actual = ds.sortby(dax_short) + assert_equal(actual, expected) + + # test 1-D lexsort + # dax0 is sorted first to give indices of [1, 2, 0] + # and then dax1 would be used to move index 2 ahead of 1 + dax0 = DataArray([100, 95, 95], [("x", ["c", "b", "a"])]) + dax1 = DataArray([0, 1, 0], [("x", ["c", "b", "a"])]) + actual = ds.sortby([dax0, dax1]) # lexsort underneath gives [2, 1, 0] + assert_equal(actual, expected) + + expected = sorted2d + # test multi-dim sort by 1D dataarray values + day = DataArray([90, 80], [("y", [1, 0])]) + actual = ds.sortby([day, dax]) + assert_equal(actual, expected) + + # test exception-raising + with pytest.raises(KeyError): + actual = ds.sortby("z") + + with pytest.raises(ValueError) as excinfo: + actual = ds.sortby(ds["A"]) + assert "DataArray is not 1-D" in str(excinfo.value) + + expected = sorted1d + actual = ds.sortby("x") + assert_equal(actual, expected) + + # test pandas.MultiIndex + indices = (("b", 1), ("b", 0), ("a", 1), ("a", 0)) + midx = pd.MultiIndex.from_tuples(indices, names=["one", "two"]) + ds_midx = Dataset( + { + "A": DataArray( + [[1, 2], [3, 4], [5, 6], [7, 8]], [("x", midx), ("y", [1, 0])] + ), + "B": DataArray([[5, 6], [7, 8], [9, 10], [11, 12]], dims=["x", "y"]), + } + ) + actual = ds_midx.sortby("x") + midx_reversed = pd.MultiIndex.from_tuples( + tuple(reversed(indices)), names=["one", "two"] + ) + expected = Dataset( + { + "A": DataArray( + [[7, 8], [5, 6], [3, 4], [1, 2]], + [("x", midx_reversed), ("y", [1, 0])], + ), + "B": DataArray([[11, 12], [9, 10], [7, 8], [5, 6]], dims=["x", "y"]), + } + ) + assert_equal(actual, expected) + + # multi-dim sort by coordinate objects + expected = sorted2d + actual = ds.sortby(["x", "y"]) + assert_equal(actual, expected) + + # test descending order sort + actual = ds.sortby(["x", "y"], ascending=False) + assert_equal(actual, ds) + + def test_attribute_access(self) -> None: + ds = create_test_data(seed=1) + for key in ["var1", "var2", "var3", "time", "dim1", "dim2", "dim3", "numbers"]: + assert_equal(ds[key], getattr(ds, key)) + assert key in dir(ds) + + for key in ["dim3", "dim1", "numbers"]: + assert_equal(ds["var3"][key], getattr(ds.var3, key)) + assert key in dir(ds["var3"]) + # attrs + assert ds["var3"].attrs["foo"] == ds.var3.foo + assert "foo" in dir(ds["var3"]) + + def test_ipython_key_completion(self) -> None: + ds = create_test_data(seed=1) + actual = ds._ipython_key_completions_() + expected = ["var1", "var2", "var3", "time", "dim1", "dim2", "dim3", "numbers"] + for item in actual: + ds[item] # should not raise + assert sorted(actual) == sorted(expected) + + # for dataarray + actual = ds["var3"]._ipython_key_completions_() + expected = ["dim3", "dim1", "numbers"] + for item in actual: + ds["var3"][item] # should not raise + assert sorted(actual) == sorted(expected) + + # MultiIndex + ds_midx = ds.stack(dim12=["dim2", "dim3"]) + actual = ds_midx._ipython_key_completions_() + expected = [ + "var1", + "var2", + "var3", + "time", + "dim1", + "dim2", + "dim3", + "numbers", + "dim12", + ] + for item in actual: + ds_midx[item] # should not raise + assert sorted(actual) == sorted(expected) + + # coords + actual = ds.coords._ipython_key_completions_() + expected = ["time", "dim1", "dim2", "dim3", "numbers"] + for item in actual: + ds.coords[item] # should not raise + assert sorted(actual) == sorted(expected) + + actual = ds["var3"].coords._ipython_key_completions_() + expected = ["dim1", "dim3", "numbers"] + for item in actual: + ds["var3"].coords[item] # should not raise + assert sorted(actual) == sorted(expected) + + coords = Coordinates(ds.coords) + actual = coords._ipython_key_completions_() + expected = ["time", "dim2", "dim3", "numbers"] + for item in actual: + coords[item] # should not raise + assert sorted(actual) == sorted(expected) + + # data_vars + actual = ds.data_vars._ipython_key_completions_() + expected = ["var1", "var2", "var3", "dim1"] + for item in actual: + ds.data_vars[item] # should not raise + assert sorted(actual) == sorted(expected) + + def test_polyfit_output(self) -> None: + ds = create_test_data(seed=1) + + out = ds.polyfit("dim2", 2, full=False) + assert "var1_polyfit_coefficients" in out + + out = ds.polyfit("dim1", 2, full=True) + assert "var1_polyfit_coefficients" in out + assert "dim1_matrix_rank" in out + + out = ds.polyfit("time", 2) + assert len(out.data_vars) == 0 + + def test_polyfit_weighted(self) -> None: + # Make sure weighted polyfit does not change the original object (issue #5644) + ds = create_test_data(seed=1) + ds_copy = ds.copy(deep=True) + + ds.polyfit("dim2", 2, w=np.arange(ds.sizes["dim2"])) + xr.testing.assert_identical(ds, ds_copy) + + def test_polyfit_warnings(self) -> None: + ds = create_test_data(seed=1) + + with warnings.catch_warnings(record=True) as ws: + ds.var1.polyfit("dim2", 10, full=False) + assert len(ws) == 1 + assert ws[0].category == RankWarning + ds.var1.polyfit("dim2", 10, full=True) + assert len(ws) == 1 + + def test_pad(self) -> None: + ds = create_test_data(seed=1) + padded = ds.pad(dim2=(1, 1), constant_values=42) + + assert padded["dim2"].shape == (11,) + assert padded["var1"].shape == (8, 11) + assert padded["var2"].shape == (8, 11) + assert padded["var3"].shape == (10, 8) + assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} + + np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) + np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) + + @pytest.mark.parametrize( + ["keep_attrs", "attrs", "expected"], + [ + pytest.param(None, {"a": 1, "b": 2}, {"a": 1, "b": 2}, id="default"), + pytest.param(False, {"a": 1, "b": 2}, {}, id="False"), + pytest.param(True, {"a": 1, "b": 2}, {"a": 1, "b": 2}, id="True"), + ], + ) + def test_pad_keep_attrs(self, keep_attrs, attrs, expected) -> None: + ds = xr.Dataset( + {"a": ("x", [1, 2], attrs), "b": ("y", [1, 2], attrs)}, + coords={"c": ("x", [-1, 1], attrs), "d": ("y", [-1, 1], attrs)}, + attrs=attrs, + ) + expected = xr.Dataset( + {"a": ("x", [0, 1, 2, 0], expected), "b": ("y", [1, 2], attrs)}, + coords={ + "c": ("x", [np.nan, -1, 1, np.nan], expected), + "d": ("y", [-1, 1], attrs), + }, + attrs=expected, + ) + + keep_attrs_ = "default" if keep_attrs is None else keep_attrs + + with set_options(keep_attrs=keep_attrs_): + actual = ds.pad({"x": (1, 1)}, mode="constant", constant_values=0) + xr.testing.assert_identical(actual, expected) + + actual = ds.pad( + {"x": (1, 1)}, mode="constant", constant_values=0, keep_attrs=keep_attrs + ) + xr.testing.assert_identical(actual, expected) + + def test_astype_attrs(self) -> None: + data = create_test_data(seed=123) + data.attrs["foo"] = "bar" + + assert data.attrs == data.astype(float).attrs + assert data.var1.attrs == data.astype(float).var1.attrs + assert not data.astype(float, keep_attrs=False).attrs + assert not data.astype(float, keep_attrs=False).var1.attrs + + @pytest.mark.parametrize("parser", ["pandas", "python"]) + @pytest.mark.parametrize( + "engine", ["python", None, pytest.param("numexpr", marks=[requires_numexpr])] + ) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=[requires_dask])] + ) + def test_query(self, backend, engine, parser) -> None: + """Test querying a dataset.""" + + # setup test data + np.random.seed(42) + a = np.arange(0, 10, 1) + b = np.random.randint(0, 100, size=10) + c = np.linspace(0, 1, 20) + d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype( + object + ) + e = np.arange(0, 10 * 20).reshape(10, 20) + f = np.random.normal(0, 1, size=(10, 20, 30)) + if backend == "numpy": + ds = Dataset( + { + "a": ("x", a), + "b": ("x", b), + "c": ("y", c), + "d": ("z", d), + "e": (("x", "y"), e), + "f": (("x", "y", "z"), f), + }, + coords={ + "a2": ("x", a), + "b2": ("x", b), + "c2": ("y", c), + "d2": ("z", d), + "e2": (("x", "y"), e), + "f2": (("x", "y", "z"), f), + }, + ) + elif backend == "dask": + ds = Dataset( + { + "a": ("x", da.from_array(a, chunks=3)), + "b": ("x", da.from_array(b, chunks=3)), + "c": ("y", da.from_array(c, chunks=7)), + "d": ("z", da.from_array(d, chunks=12)), + "e": (("x", "y"), da.from_array(e, chunks=(3, 7))), + "f": (("x", "y", "z"), da.from_array(f, chunks=(3, 7, 12))), + }, + coords={ + "a2": ("x", a), + "b2": ("x", b), + "c2": ("y", c), + "d2": ("z", d), + "e2": (("x", "y"), e), + "f2": (("x", "y", "z"), f), + }, + ) + + # query single dim, single variable + with raise_if_dask_computes(): + actual = ds.query(x="a2 > 5", engine=engine, parser=parser) + expect = ds.isel(x=(a > 5)) + assert_identical(expect, actual) + + # query single dim, single variable, via dict + with raise_if_dask_computes(): + actual = ds.query(dict(x="a2 > 5"), engine=engine, parser=parser) + expect = ds.isel(dict(x=(a > 5))) + assert_identical(expect, actual) + + # query single dim, single variable + with raise_if_dask_computes(): + actual = ds.query(x="b2 > 50", engine=engine, parser=parser) + expect = ds.isel(x=(b > 50)) + assert_identical(expect, actual) + + # query single dim, single variable + with raise_if_dask_computes(): + actual = ds.query(y="c2 < .5", engine=engine, parser=parser) + expect = ds.isel(y=(c < 0.5)) + assert_identical(expect, actual) + + # query single dim, single string variable + if parser == "pandas": + # N.B., this query currently only works with the pandas parser + # xref https://github.com/pandas-dev/pandas/issues/40436 + with raise_if_dask_computes(): + actual = ds.query(z='d2 == "bar"', engine=engine, parser=parser) + expect = ds.isel(z=(d == "bar")) + assert_identical(expect, actual) + + # query single dim, multiple variables + with raise_if_dask_computes(): + actual = ds.query(x="(a2 > 5) & (b2 > 50)", engine=engine, parser=parser) + expect = ds.isel(x=((a > 5) & (b > 50))) + assert_identical(expect, actual) + + # query single dim, multiple variables with computation + with raise_if_dask_computes(): + actual = ds.query(x="(a2 * b2) > 250", engine=engine, parser=parser) + expect = ds.isel(x=(a * b) > 250) + assert_identical(expect, actual) + + # check pandas query syntax is supported + if parser == "pandas": + with raise_if_dask_computes(): + actual = ds.query( + x="(a2 > 5) and (b2 > 50)", engine=engine, parser=parser + ) + expect = ds.isel(x=((a > 5) & (b > 50))) + assert_identical(expect, actual) + + # query multiple dims via kwargs + with raise_if_dask_computes(): + actual = ds.query(x="a2 > 5", y="c2 < .5", engine=engine, parser=parser) + expect = ds.isel(x=(a > 5), y=(c < 0.5)) + assert_identical(expect, actual) + + # query multiple dims via kwargs + if parser == "pandas": + with raise_if_dask_computes(): + actual = ds.query( + x="a2 > 5", + y="c2 < .5", + z="d2 == 'bar'", + engine=engine, + parser=parser, + ) + expect = ds.isel(x=(a > 5), y=(c < 0.5), z=(d == "bar")) + assert_identical(expect, actual) + + # query multiple dims via dict + with raise_if_dask_computes(): + actual = ds.query( + dict(x="a2 > 5", y="c2 < .5"), engine=engine, parser=parser + ) + expect = ds.isel(dict(x=(a > 5), y=(c < 0.5))) + assert_identical(expect, actual) + + # query multiple dims via dict + if parser == "pandas": + with raise_if_dask_computes(): + actual = ds.query( + dict(x="a2 > 5", y="c2 < .5", z="d2 == 'bar'"), + engine=engine, + parser=parser, + ) + expect = ds.isel(dict(x=(a > 5), y=(c < 0.5), z=(d == "bar"))) + assert_identical(expect, actual) + + # test error handling + with pytest.raises(ValueError): + ds.query("a > 5") # type: ignore # must be dict or kwargs + with pytest.raises(ValueError): + ds.query(x=(a > 5)) + with pytest.raises(IndexError): + ds.query(y="a > 5") # wrong length dimension + with pytest.raises(IndexError): + ds.query(x="c < .5") # wrong length dimension + with pytest.raises(IndexError): + ds.query(x="e > 100") # wrong number of dimensions + with pytest.raises(UndefinedVariableError): + ds.query(x="spam > 50") # name not present + + +# pytest tests — new tests should go here, rather than in the class. + + +@pytest.mark.parametrize("parser", ["pandas", "python"]) +def test_eval(ds, parser) -> None: + """Currently much more minimal testing that `query` above, and much of the setup + isn't used. But the risks are fairly low — `query` shares much of the code, and + the method is currently experimental.""" + + actual = ds.eval("z1 + 5", parser=parser) + expect = ds["z1"] + 5 + assert_identical(expect, actual) + + # check pandas query syntax is supported + if parser == "pandas": + actual = ds.eval("(z1 > 5) and (z2 > 0)", parser=parser) + expect = (ds["z1"] > 5) & (ds["z2"] > 0) + assert_identical(expect, actual) + + +@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) +def test_isin(test_elements, backend) -> None: + expected = Dataset( + data_vars={ + "var1": (("dim1",), [0, 1]), + "var2": (("dim1",), [1, 1]), + "var3": (("dim1",), [0, 1]), + } + ).astype("bool") + + if backend == "dask": + expected = expected.chunk() + + result = Dataset( + data_vars={ + "var1": (("dim1",), [0, 1]), + "var2": (("dim1",), [1, 2]), + "var3": (("dim1",), [0, 1]), + } + ).isin(test_elements) + + assert_equal(result, expected) + + +def test_isin_dataset() -> None: + ds = Dataset({"x": [1, 2]}) + with pytest.raises(TypeError): + ds.isin(ds) + + +@pytest.mark.parametrize( + "unaligned_coords", + ( + {"x": [2, 1, 0]}, + {"x": (["x"], np.asarray([2, 1, 0]))}, + {"x": (["x"], np.asarray([1, 2, 0]))}, + {"x": pd.Index([2, 1, 0])}, + {"x": Variable(dims="x", data=[0, 2, 1])}, + {"x": IndexVariable(dims="x", data=[0, 1, 2])}, + {"y": 42}, + {"y": ("x", [2, 1, 0])}, + {"y": ("x", np.asarray([2, 1, 0]))}, + {"y": (["x"], np.asarray([2, 1, 0]))}, + ), +) +@pytest.mark.parametrize("coords", ({"x": ("x", [0, 1, 2])}, {"x": [0, 1, 2]})) +def test_dataset_constructor_aligns_to_explicit_coords( + unaligned_coords, coords +) -> None: + a = xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords) + + expected = xr.Dataset(coords=coords) + expected["a"] = a + + result = xr.Dataset({"a": a}, coords=coords) + + assert_equal(expected, result) + + +def test_error_message_on_set_supplied() -> None: + with pytest.raises(TypeError, match="has invalid type "): + xr.Dataset(dict(date=[1, 2, 3], sec={4})) + + +@pytest.mark.parametrize("unaligned_coords", ({"y": ("b", np.asarray([2, 1, 0]))},)) +def test_constructor_raises_with_invalid_coords(unaligned_coords) -> None: + with pytest.raises(ValueError, match="not a subset of the DataArray dimensions"): + xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords) + + +@pytest.mark.parametrize("ds", [3], indirect=True) +def test_dir_expected_attrs(ds) -> None: + some_expected_attrs = {"pipe", "mean", "isnull", "var1", "dim2", "numbers"} + result = dir(ds) + assert set(result) >= some_expected_attrs + + +def test_dir_non_string(ds) -> None: + # add a numbered key to ensure this doesn't break dir + ds[5] = "foo" + result = dir(ds) + assert 5 not in result + + # GH2172 + sample_data = np.random.uniform(size=[2, 2000, 10000]) + x = xr.Dataset({"sample_data": (sample_data.shape, sample_data)}) + x2 = x["sample_data"] + dir(x2) + + +def test_dir_unicode(ds) -> None: + ds["unicode"] = "uni" + result = dir(ds) + assert "unicode" in result + + +def test_raise_no_warning_for_nan_in_binary_ops() -> None: + with assert_no_warnings(): + Dataset(data_vars={"x": ("y", [1, 2, np.nan])}) > 0 + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("ds", (2,), indirect=True) +def test_raise_no_warning_assert_close(ds) -> None: + assert_allclose(ds, ds) + + +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize("edge_order", [1, 2]) +def test_differentiate(dask, edge_order) -> None: + rs = np.random.RandomState(42) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.randn(8, 6))}, + ) + if dask and has_dask: + da = da.chunk({"x": 4}) + + ds = xr.Dataset({"var": da}) + + # along x + actual = da.differentiate("x", edge_order) + expected_x = xr.DataArray( + np.gradient(da, da["x"], axis=0, edge_order=edge_order), + dims=da.dims, + coords=da.coords, + ) + assert_equal(expected_x, actual) + assert_equal( + ds["var"].differentiate("x", edge_order=edge_order), + ds.differentiate("x", edge_order=edge_order)["var"], + ) + # coordinate should not change + assert_equal(da["x"], actual["x"]) + + # along y + actual = da.differentiate("y", edge_order) + expected_y = xr.DataArray( + np.gradient(da, da["y"], axis=1, edge_order=edge_order), + dims=da.dims, + coords=da.coords, + ) + assert_equal(expected_y, actual) + assert_equal(actual, ds.differentiate("y", edge_order=edge_order)["var"]) + assert_equal( + ds["var"].differentiate("y", edge_order=edge_order), + ds.differentiate("y", edge_order=edge_order)["var"], + ) + + with pytest.raises(ValueError): + da.differentiate("x2d") + + +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") +@pytest.mark.parametrize("dask", [True, False]) +def test_differentiate_datetime(dask) -> None: + rs = np.random.RandomState(42) + coord = np.array( + [ + "2004-07-13", + "2006-01-13", + "2010-08-13", + "2010-09-13", + "2010-10-11", + "2010-12-13", + "2011-02-13", + "2012-08-13", + ], + dtype="datetime64", + ) + + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={"x": coord, "z": 3, "x2d": (("x", "y"), rs.randn(8, 6))}, + ) + if dask and has_dask: + da = da.chunk({"x": 4}) + + # along x + actual = da.differentiate("x", edge_order=1, datetime_unit="D") + expected_x = xr.DataArray( + np.gradient( + da, da["x"].variable._to_numeric(datetime_unit="D"), axis=0, edge_order=1 + ), + dims=da.dims, + coords=da.coords, + ) + assert_equal(expected_x, actual) + + actual2 = da.differentiate("x", edge_order=1, datetime_unit="h") + assert np.allclose(actual, actual2 * 24) + + # for datetime variable + actual = da["x"].differentiate("x", edge_order=1, datetime_unit="D") + assert np.allclose(actual, 1.0) + + # with different date unit + da = xr.DataArray(coord.astype("datetime64[ms]"), dims=["x"], coords={"x": coord}) + actual = da.differentiate("x", edge_order=1) + assert np.allclose(actual, 1.0) + + +@pytest.mark.skipif(not has_cftime, reason="Test requires cftime.") +@pytest.mark.parametrize("dask", [True, False]) +def test_differentiate_cftime(dask) -> None: + rs = np.random.RandomState(42) + coord = xr.cftime_range("2000", periods=8, freq="2ME") + + da = xr.DataArray( + rs.randn(8, 6), + coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.randn(8, 6))}, + dims=["time", "y"], + ) + + if dask and has_dask: + da = da.chunk({"time": 4}) + + actual = da.differentiate("time", edge_order=1, datetime_unit="D") + expected_data = np.gradient( + da, da["time"].variable._to_numeric(datetime_unit="D"), axis=0, edge_order=1 + ) + expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims) + assert_equal(expected, actual) + + actual2 = da.differentiate("time", edge_order=1, datetime_unit="h") + assert_allclose(actual, actual2 * 24) + + # Test the differentiation of datetimes themselves + actual = da["time"].differentiate("time", edge_order=1, datetime_unit="D") + assert_allclose(actual, xr.ones_like(da["time"]).astype(float)) + + +@pytest.mark.parametrize("dask", [True, False]) +def test_integrate(dask) -> None: + rs = np.random.RandomState(42) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={ + "x": coord, + "x2": (("x",), rs.randn(8)), + "z": 3, + "x2d": (("x", "y"), rs.randn(8, 6)), + }, + ) + if dask and has_dask: + da = da.chunk({"x": 4}) + + ds = xr.Dataset({"var": da}) + + # along x + actual = da.integrate("x") + # coordinate that contains x should be dropped. + expected_x = xr.DataArray( + trapezoid(da.compute(), da["x"], axis=0), + dims=["y"], + coords={k: v for k, v in da.coords.items() if "x" not in v.dims}, + ) + assert_allclose(expected_x, actual.compute()) + assert_equal(ds["var"].integrate("x"), ds.integrate("x")["var"]) + + # make sure result is also a dask array (if the source is dask array) + assert isinstance(actual.data, type(da.data)) + + # along y + actual = da.integrate("y") + expected_y = xr.DataArray( + trapezoid(da, da["y"], axis=1), + dims=["x"], + coords={k: v for k, v in da.coords.items() if "y" not in v.dims}, + ) + assert_allclose(expected_y, actual.compute()) + assert_equal(actual, ds.integrate("y")["var"]) + assert_equal(ds["var"].integrate("y"), ds.integrate("y")["var"]) + + # along x and y + actual = da.integrate(("y", "x")) + assert actual.ndim == 0 + + with pytest.raises(ValueError): + da.integrate("x2d") + + +@requires_scipy +@pytest.mark.parametrize("dask", [True, False]) +def test_cumulative_integrate(dask) -> None: + rs = np.random.RandomState(43) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={ + "x": coord, + "x2": (("x",), rs.randn(8)), + "z": 3, + "x2d": (("x", "y"), rs.randn(8, 6)), + }, + ) + if dask and has_dask: + da = da.chunk({"x": 4}) + + ds = xr.Dataset({"var": da}) + + # along x + actual = da.cumulative_integrate("x") + + from scipy.integrate import cumulative_trapezoid + + expected_x = xr.DataArray( + cumulative_trapezoid(da.compute(), da["x"], axis=0, initial=0.0), + dims=["x", "y"], + coords=da.coords, + ) + assert_allclose(expected_x, actual.compute()) + assert_equal( + ds["var"].cumulative_integrate("x"), + ds.cumulative_integrate("x")["var"], + ) + + # make sure result is also a dask array (if the source is dask array) + assert isinstance(actual.data, type(da.data)) + + # along y + actual = da.cumulative_integrate("y") + expected_y = xr.DataArray( + cumulative_trapezoid(da, da["y"], axis=1, initial=0.0), + dims=["x", "y"], + coords=da.coords, + ) + assert_allclose(expected_y, actual.compute()) + assert_equal(actual, ds.cumulative_integrate("y")["var"]) + assert_equal( + ds["var"].cumulative_integrate("y"), + ds.cumulative_integrate("y")["var"], + ) + + # along x and y + actual = da.cumulative_integrate(("y", "x")) + assert actual.ndim == 2 + + with pytest.raises(ValueError): + da.cumulative_integrate("x2d") + + +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize("which_datetime", ["np", "cftime"]) +def test_trapezoid_datetime(dask, which_datetime) -> None: + rs = np.random.RandomState(42) + if which_datetime == "np": + coord = np.array( + [ + "2004-07-13", + "2006-01-13", + "2010-08-13", + "2010-09-13", + "2010-10-11", + "2010-12-13", + "2011-02-13", + "2012-08-13", + ], + dtype="datetime64", + ) + else: + if not has_cftime: + pytest.skip("Test requires cftime.") + coord = xr.cftime_range("2000", periods=8, freq="2D") + + da = xr.DataArray( + rs.randn(8, 6), + coords={"time": coord, "z": 3, "t2d": (("time", "y"), rs.randn(8, 6))}, + dims=["time", "y"], + ) + + if dask and has_dask: + da = da.chunk({"time": 4}) + + actual = da.integrate("time", datetime_unit="D") + expected_data = trapezoid( + da.compute().data, + duck_array_ops.datetime_to_numeric(da["time"].data, datetime_unit="D"), + axis=0, + ) + expected = xr.DataArray( + expected_data, + dims=["y"], + coords={k: v for k, v in da.coords.items() if "time" not in v.dims}, + ) + assert_allclose(expected, actual.compute()) + + # make sure result is also a dask array (if the source is dask array) + assert isinstance(actual.data, type(da.data)) + + actual2 = da.integrate("time", datetime_unit="h") + assert_allclose(actual, actual2 / 24.0) + + +def test_no_dict() -> None: + d = Dataset() + with pytest.raises(AttributeError): + d.__dict__ + + +def test_subclass_slots() -> None: + """Test that Dataset subclasses must explicitly define ``__slots__``. + + .. note:: + As of 0.13.0, this is actually mitigated into a FutureWarning for any class + defined outside of the xarray package. + """ + with pytest.raises(AttributeError) as e: + + class MyDS(Dataset): + pass + + assert str(e.value) == "MyDS must explicitly define __slots__" + + +def test_weakref() -> None: + """Classes with __slots__ are incompatible with the weakref module unless they + explicitly state __weakref__ among their slots + """ + from weakref import ref + + ds = Dataset() + r = ref(ds) + assert r() is ds + + +def test_deepcopy_obj_array() -> None: + x0 = Dataset(dict(foo=DataArray(np.array([object()])))) + x1 = deepcopy(x0) + assert x0["foo"].values[0] is not x1["foo"].values[0] + + +def test_deepcopy_recursive() -> None: + # GH:issue:7111 + + # direct recursion + ds = xr.Dataset({"a": (["x"], [1, 2])}) + ds.attrs["other"] = ds + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + ds.copy(deep=True) + + # indirect recursion + ds2 = xr.Dataset({"b": (["y"], [3, 4])}) + ds.attrs["other"] = ds2 + ds2.attrs["other"] = ds + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + ds.copy(deep=True) + ds2.copy(deep=True) + + +def test_clip(ds) -> None: + result = ds.clip(min=0.5) + assert all((result.min(...) >= 0.5).values()) + + result = ds.clip(max=0.5) + assert all((result.max(...) <= 0.5).values()) + + result = ds.clip(min=0.25, max=0.75) + assert all((result.min(...) >= 0.25).values()) + assert all((result.max(...) <= 0.75).values()) + + result = ds.clip(min=ds.mean("y"), max=ds.mean("y")) + assert result.sizes == ds.sizes + + +class TestDropDuplicates: + @pytest.mark.parametrize("keep", ["first", "last", False]) + def test_drop_duplicates_1d(self, keep) -> None: + ds = xr.Dataset( + {"a": ("time", [0, 5, 6, 7]), "b": ("time", [9, 3, 8, 2])}, + coords={"time": [0, 0, 1, 2]}, + ) + + if keep == "first": + a = [0, 6, 7] + b = [9, 8, 2] + time = [0, 1, 2] + elif keep == "last": + a = [5, 6, 7] + b = [3, 8, 2] + time = [0, 1, 2] + else: + a = [6, 7] + b = [8, 2] + time = [1, 2] + + expected = xr.Dataset( + {"a": ("time", a), "b": ("time", b)}, coords={"time": time} + ) + result = ds.drop_duplicates("time", keep=keep) + assert_equal(expected, result) + + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions ('space',) not found in data dimensions ('time',)" + ), + ): + ds.drop_duplicates("space", keep=keep) + + +class TestNumpyCoercion: + def test_from_numpy(self) -> None: + ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) + + assert_identical(ds.as_numpy(), ds) + + @requires_dask + def test_from_dask(self) -> None: + ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) + ds_chunked = ds.chunk(1) + + assert_identical(ds_chunked.as_numpy(), ds.compute()) + + @requires_pint + def test_from_pint(self) -> None: + from pint import Quantity + + arr = np.array([1, 2, 3]) + ds = xr.Dataset( + {"a": ("x", Quantity(arr, units="Pa"))}, + coords={"lat": ("x", Quantity(arr + 3, units="m"))}, + ) + + expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)}) + assert_identical(ds.as_numpy(), expected) + + @requires_sparse + def test_from_sparse(self) -> None: + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO.from_numpy(arr) + ds = xr.Dataset( + {"a": (["x", "y"], sparr)}, coords={"elev": (("x", "y"), sparr + 3)} + ) + + expected = xr.Dataset( + {"a": (["x", "y"], arr)}, coords={"elev": (("x", "y"), arr + 3)} + ) + assert_identical(ds.as_numpy(), expected) + + @requires_cupy + def test_from_cupy(self) -> None: + import cupy as cp + + arr = np.array([1, 2, 3]) + ds = xr.Dataset( + {"a": ("x", cp.array(arr))}, coords={"lat": ("x", cp.array(arr + 3))} + ) + + expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)}) + assert_identical(ds.as_numpy(), expected) + + @requires_dask + @requires_pint + def test_from_pint_wrapping_dask(self) -> None: + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(arr) + ds = xr.Dataset( + {"a": ("x", Quantity(d, units="Pa"))}, + coords={"lat": ("x", Quantity(d, units="m") * 2)}, + ) + + result = ds.as_numpy() + expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)}) + assert_identical(result, expected) + + +def test_string_keys_typing() -> None: + """Tests that string keys to `variables` are permitted by mypy""" + + da = xr.DataArray(np.arange(10), dims=["x"]) + ds = xr.Dataset(dict(x=da)) + mapping = {"y": da} + ds.assign(variables=mapping) + + +def test_transpose_error() -> None: + # Transpose dataset with list as argument + # Should raise error + ds = xr.Dataset({"foo": (("x", "y"), [[21]]), "bar": (("x", "y"), [[12]])}) + + with pytest.raises( + TypeError, + match=re.escape( + "transpose requires dim to be passed as multiple arguments. Expected `'y', 'x'`. Received `['y', 'x']` instead" + ), + ): + ds.transpose(["y", "x"]) # type: ignore diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_datatree.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_datatree.py new file mode 100644 index 0000000..58fec20 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_datatree.py @@ -0,0 +1,904 @@ +from copy import copy, deepcopy +from textwrap import dedent + +import numpy as np +import pytest + +import xarray as xr +from xarray.core.datatree import DataTree +from xarray.core.datatree_ops import _MAPPED_DOCSTRING_ADDENDUM, insert_doc_addendum +from xarray.core.treenode import NotFoundInTreeError +from xarray.testing import assert_equal, assert_identical +from xarray.tests import create_test_data, source_ndarray + + +class TestTreeCreation: + def test_empty(self): + dt: DataTree = DataTree(name="root") + assert dt.name == "root" + assert dt.parent is None + assert dt.children == {} + assert_identical(dt.to_dataset(), xr.Dataset()) + + def test_unnamed(self): + dt: DataTree = DataTree() + assert dt.name is None + + def test_bad_names(self): + with pytest.raises(TypeError): + DataTree(name=5) # type: ignore[arg-type] + + with pytest.raises(ValueError): + DataTree(name="folder/data") + + +class TestFamilyTree: + def test_setparent_unnamed_child_node_fails(self): + john: DataTree = DataTree(name="john") + with pytest.raises(ValueError, match="unnamed"): + DataTree(parent=john) + + def test_create_two_children(self): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=root) + DataTree(name="set2", parent=set1) + + def test_create_full_tree(self, simple_datatree): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) + + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + expected = simple_datatree + assert root.identical(expected) + + +class TestNames: + def test_child_gets_named_on_attach(self): + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) # noqa + assert sue.name == "Sue" + + +class TestPaths: + def test_path_property(self): + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + john: DataTree = DataTree(children={"Mary": mary}) + assert sue.path == "/Mary/Sue" + assert john.path == "/" + + def test_path_roundtrip(self): + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + john: DataTree = DataTree(children={"Mary": mary}) + assert john[sue.path] is sue + + def test_same_tree(self): + mary: DataTree = DataTree() + kate: DataTree = DataTree() + john: DataTree = DataTree(children={"Mary": mary, "Kate": kate}) # noqa + assert mary.same_tree(kate) + + def test_relative_paths(self): + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + annie: DataTree = DataTree() + john: DataTree = DataTree(children={"Mary": mary, "Annie": annie}) + + result = sue.relative_to(john) + assert result == "Mary/Sue" + assert john.relative_to(sue) == "../.." + assert annie.relative_to(sue) == "../../Annie" + assert sue.relative_to(annie) == "../Mary/Sue" + assert sue.relative_to(sue) == "." + + evil_kate: DataTree = DataTree() + with pytest.raises( + NotFoundInTreeError, match="nodes do not lie within the same tree" + ): + sue.relative_to(evil_kate) + + +class TestStoreDatasets: + def test_create_with_data(self): + dat = xr.Dataset({"a": 0}) + john: DataTree = DataTree(name="john", data=dat) + + assert_identical(john.to_dataset(), dat) + + with pytest.raises(TypeError): + DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] + + def test_set_data(self): + john: DataTree = DataTree(name="john") + dat = xr.Dataset({"a": 0}) + john.ds = dat # type: ignore[assignment] + + assert_identical(john.to_dataset(), dat) + + with pytest.raises(TypeError): + john.ds = "junk" # type: ignore[assignment] + + def test_has_data(self): + john: DataTree = DataTree(name="john", data=xr.Dataset({"a": 0})) + assert john.has_data + + john_no_data: DataTree = DataTree(name="john", data=None) + assert not john_no_data.has_data + + def test_is_hollow(self): + john: DataTree = DataTree(data=xr.Dataset({"a": 0})) + assert john.is_hollow + + eve: DataTree = DataTree(children={"john": john}) + assert eve.is_hollow + + eve.ds = xr.Dataset({"a": 1}) # type: ignore[assignment] + assert not eve.is_hollow + + +class TestVariablesChildrenNameCollisions: + def test_parent_already_has_variable_with_childs_name(self): + dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + with pytest.raises(KeyError, match="already contains a data variable named a"): + DataTree(name="a", data=None, parent=dt) + + def test_assign_when_already_child_with_variables_name(self): + dt: DataTree = DataTree(data=None) + DataTree(name="a", data=None, parent=dt) + with pytest.raises(KeyError, match="names would collide"): + dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] + + dt.ds = xr.Dataset() # type: ignore[assignment] + + new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) + with pytest.raises(KeyError, match="names would collide"): + dt.ds = new_ds # type: ignore[assignment] + + +class TestGet: ... + + +class TestGetItem: + def test_getitem_node(self): + folder1: DataTree = DataTree(name="folder1") + results: DataTree = DataTree(name="results", parent=folder1) + highres: DataTree = DataTree(name="highres", parent=results) + assert folder1["results"] is results + assert folder1["results/highres"] is highres + + def test_getitem_self(self): + dt: DataTree = DataTree() + assert dt["."] is dt + + def test_getitem_single_data_variable(self): + data = xr.Dataset({"temp": [0, 50]}) + results: DataTree = DataTree(name="results", data=data) + assert_identical(results["temp"], data["temp"]) + + def test_getitem_single_data_variable_from_node(self): + data = xr.Dataset({"temp": [0, 50]}) + folder1: DataTree = DataTree(name="folder1") + results: DataTree = DataTree(name="results", parent=folder1) + DataTree(name="highres", parent=results, data=data) + assert_identical(folder1["results/highres/temp"], data["temp"]) + + def test_getitem_nonexistent_node(self): + folder1: DataTree = DataTree(name="folder1") + DataTree(name="results", parent=folder1) + with pytest.raises(KeyError): + folder1["results/highres"] + + def test_getitem_nonexistent_variable(self): + data = xr.Dataset({"temp": [0, 50]}) + results: DataTree = DataTree(name="results", data=data) + with pytest.raises(KeyError): + results["pressure"] + + @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") + def test_getitem_multiple_data_variables(self): + data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) + results: DataTree = DataTree(name="results", data=data) + assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] + + @pytest.mark.xfail( + reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" + ) + def test_getitem_dict_like_selection_access_to_dataset(self): + data = xr.Dataset({"temp": [0, 50]}) + results: DataTree = DataTree(name="results", data=data) + assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] + + +class TestUpdate: + def test_update(self): + dt: DataTree = DataTree() + dt.update({"foo": xr.DataArray(0), "a": DataTree()}) + expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) + print(dt) + print(dt.children) + print(dt._children) + print(dt["a"]) + print(expected) + assert_equal(dt, expected) + + def test_update_new_named_dataarray(self): + da = xr.DataArray(name="temp", data=[0, 50]) + folder1: DataTree = DataTree(name="folder1") + folder1.update({"results": da}) + expected = da.rename("results") + assert_equal(folder1["results"], expected) + + def test_update_doesnt_alter_child_name(self): + dt: DataTree = DataTree() + dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")}) + assert "a" in dt.children + child = dt["a"] + assert child.name == "a" + + def test_update_overwrite(self): + actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))}) + actual.update({"a": DataTree(xr.Dataset({"x": 2}))}) + + expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))}) + + print(actual) + print(expected) + + assert_equal(actual, expected) + + +class TestCopy: + def test_copy(self, create_test_datatree): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=False), copy(dt)]: + assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. + for k in node.data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is copied_node.attrs["Test"] + + def test_copy_subtree(self): + dt = DataTree.from_dict({"/level1/level2/level3": xr.Dataset()}) + + actual = dt["/level1/level2"].copy() + expected = DataTree.from_dict({"/level3": xr.Dataset()}, name="level2") + + assert_identical(actual, expected) + + def test_deepcopy(self, create_test_datatree): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=True), deepcopy(dt)]: + assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. + for k in node.data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is not source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is not copied_node.attrs["Test"] + + @pytest.mark.xfail(reason="data argument not yet implemented") + def test_copy_with_data(self, create_test_datatree): + orig = create_test_datatree() + # TODO use .data_vars once that property is available + data_vars = { + k: v for k, v in orig.variables.items() if k not in orig._coord_names + } + new_data = {k: np.random.randn(*v.shape) for k, v in data_vars.items()} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + assert_identical(expected, actual) + + # TODO test parents and children? + + +class TestSetItem: + def test_setitem_new_child_node(self): + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary") + john["mary"] = mary + + grafted_mary = john["mary"] + assert grafted_mary.parent is john + assert grafted_mary.name == "mary" + + def test_setitem_unnamed_child_node_becomes_named(self): + john2: DataTree = DataTree(name="john2") + john2["sonny"] = DataTree() + assert john2["sonny"].name == "sonny" + + def test_setitem_new_grandchild_node(self): + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary", parent=john) + rose: DataTree = DataTree(name="rose") + john["mary/rose"] = rose + + grafted_rose = john["mary/rose"] + assert grafted_rose.parent is mary + assert grafted_rose.name == "rose" + + def test_grafted_subtree_retains_name(self): + subtree: DataTree = DataTree(name="original_subtree_name") + root: DataTree = DataTree(name="root") + root["new_subtree_name"] = subtree # noqa + assert subtree.name == "original_subtree_name" + + def test_setitem_new_empty_node(self): + john: DataTree = DataTree(name="john") + john["mary"] = DataTree() + mary = john["mary"] + assert isinstance(mary, DataTree) + assert_identical(mary.to_dataset(), xr.Dataset()) + + def test_setitem_overwrite_data_in_node_with_none(self): + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) + john["mary"] = DataTree() + assert_identical(mary.to_dataset(), xr.Dataset()) + + john.ds = xr.Dataset() # type: ignore[assignment] + with pytest.raises(ValueError, match="has no name"): + john["."] = DataTree() + + @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") + def test_setitem_dataset_on_this_node(self): + data = xr.Dataset({"temp": [0, 50]}) + results: DataTree = DataTree(name="results") + results["."] = data + assert_identical(results.to_dataset(), data) + + @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") + def test_setitem_dataset_as_new_node(self): + data = xr.Dataset({"temp": [0, 50]}) + folder1: DataTree = DataTree(name="folder1") + folder1["results"] = data + assert_identical(folder1["results"].to_dataset(), data) + + @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") + def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): + data = xr.Dataset({"temp": [0, 50]}) + folder1: DataTree = DataTree(name="folder1") + folder1["results/highres"] = data + assert_identical(folder1["results/highres"].to_dataset(), data) + + def test_setitem_named_dataarray(self): + da = xr.DataArray(name="temp", data=[0, 50]) + folder1: DataTree = DataTree(name="folder1") + folder1["results"] = da + expected = da.rename("results") + assert_equal(folder1["results"], expected) + + def test_setitem_unnamed_dataarray(self): + data = xr.DataArray([0, 50]) + folder1: DataTree = DataTree(name="folder1") + folder1["results"] = data + assert_equal(folder1["results"], data) + + def test_setitem_variable(self): + var = xr.Variable(data=[0, 50], dims="x") + folder1: DataTree = DataTree(name="folder1") + folder1["results"] = var + assert_equal(folder1["results"], xr.DataArray(var)) + + def test_setitem_coerce_to_dataarray(self): + folder1: DataTree = DataTree(name="folder1") + folder1["results"] = 0 + assert_equal(folder1["results"], xr.DataArray(0)) + + def test_setitem_add_new_variable_to_empty_node(self): + results: DataTree = DataTree(name="results") + results["pressure"] = xr.DataArray(data=[2, 3]) + assert "pressure" in results.ds + results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) + assert "temp" in results.ds + + # What if there is a path to traverse first? + results_with_path: DataTree = DataTree(name="results") + results_with_path["highres/pressure"] = xr.DataArray(data=[2, 3]) + assert "pressure" in results_with_path["highres"].ds + results_with_path["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) + assert "temp" in results_with_path["highres"].ds + + def test_setitem_dataarray_replace_existing_node(self): + t = xr.Dataset({"temp": [0, 50]}) + results: DataTree = DataTree(name="results", data=t) + p = xr.DataArray(data=[2, 3]) + results["pressure"] = p + expected = t.assign(pressure=p) + assert_identical(results.to_dataset(), expected) + + +class TestDictionaryInterface: ... + + +class TestTreeFromDict: + def test_data_in_root(self): + dat = xr.Dataset() + dt = DataTree.from_dict({"/": dat}) + assert dt.name is None + assert dt.parent is None + assert dt.children == {} + assert_identical(dt.to_dataset(), dat) + + def test_one_layer(self): + dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) + dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) + assert_identical(dt.to_dataset(), xr.Dataset()) + assert dt.name is None + assert_identical(dt["run1"].to_dataset(), dat1) + assert dt["run1"].children == {} + assert_identical(dt["run2"].to_dataset(), dat2) + assert dt["run2"].children == {} + + def test_two_layers(self): + dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"a": [1, 2]}) + dt = DataTree.from_dict({"highres/run": dat1, "lowres/run": dat2}) + assert "highres" in dt.children + assert "lowres" in dt.children + highres_run = dt["highres/run"] + assert_identical(highres_run.to_dataset(), dat1) + + def test_nones(self): + dt = DataTree.from_dict({"d": None, "d/e": None}) + assert [node.name for node in dt.subtree] == [None, "d", "e"] + assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] + assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) + + def test_full(self, simple_datatree): + dt = simple_datatree + paths = list(node.path for node in dt.subtree) + assert paths == [ + "/", + "/set1", + "/set2", + "/set3", + "/set1/set1", + "/set1/set2", + "/set2/set1", + ] + + def test_datatree_values(self): + dat1: DataTree = DataTree(data=xr.Dataset({"a": 1})) + expected: DataTree = DataTree() + expected["a"] = dat1 + + actual = DataTree.from_dict({"a": dat1}) + + assert_identical(actual, expected) + + def test_roundtrip(self, simple_datatree): + dt = simple_datatree + roundtrip = DataTree.from_dict(dt.to_dict()) + assert roundtrip.equals(dt) + + @pytest.mark.xfail + def test_roundtrip_unnamed_root(self, simple_datatree): + # See GH81 + + dt = simple_datatree + dt.name = "root" + roundtrip = DataTree.from_dict(dt.to_dict()) + assert roundtrip.equals(dt) + + +class TestDatasetView: + def test_view_contents(self): + ds = create_test_data() + dt: DataTree = DataTree(data=ds) + assert ds.identical( + dt.ds + ) # this only works because Dataset.identical doesn't check types + assert isinstance(dt.ds, xr.Dataset) + + def test_immutability(self): + # See issue https://github.com/xarray-contrib/datatree/issues/38 + dt: DataTree = DataTree(name="root", data=None) + DataTree(name="a", data=None, parent=dt) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds["a"] = xr.DataArray(0) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds.update({"a": 0}) + + # TODO are there any other ways you can normally modify state (in-place)? + # (not attribute-like assignment because that doesn't work on Dataset anyway) + + def test_methods(self): + ds = create_test_data() + dt: DataTree = DataTree(data=ds) + assert ds.mean().identical(dt.ds.mean()) + assert type(dt.ds.mean()) == xr.Dataset + + def test_arithmetic(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + result = 10.0 * dt["set1"].ds + assert result.identical(expected) + + def test_init_via_type(self): + # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188 + # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray + + a = xr.DataArray( + np.random.rand(3, 4, 10), + dims=["x", "y", "time"], + coords={"area": (["x", "y"], np.random.rand(3, 4))}, + ).to_dataset(name="data") + dt: DataTree = DataTree(data=a) + + def weighted_mean(ds): + return ds.weighted(ds.area).mean(["x", "y"]) + + weighted_mean(dt.ds) + + +class TestAccess: + def test_attribute_access(self, create_test_datatree): + dt = create_test_datatree() + + # vars / coords + for key in ["a", "set0"]: + assert_equal(dt[key], getattr(dt, key)) + assert key in dir(dt) + + # dims + assert_equal(dt["a"]["y"], getattr(dt.a, "y")) + assert "y" in dir(dt["a"]) + + # children + for key in ["set1", "set2", "set3"]: + assert_equal(dt[key], getattr(dt, key)) + assert key in dir(dt) + + # attrs + dt.attrs["meta"] = "NASA" + assert dt.attrs["meta"] == "NASA" + assert "meta" in dir(dt) + + def test_ipython_key_completions(self, create_test_datatree): + dt = create_test_datatree() + key_completions = dt._ipython_key_completions_() + + node_keys = [node.path[1:] for node in dt.subtree] + assert all(node_key in key_completions for node_key in node_keys) + + var_keys = list(dt.variables.keys()) + assert all(var_key in key_completions for var_key in var_keys) + + def test_operation_with_attrs_but_no_data(self): + # tests bug from xarray-datatree GH262 + xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))}) + dt = DataTree.from_dict({"node1": xs, "node2": xs}) + dt.attrs["test_key"] = 1 # sel works fine without this line + dt.sel(dim_0=0) + + +class TestRestructuring: + def test_drop_nodes(self): + sue = DataTree.from_dict({"Mary": None, "Kate": None, "Ashley": None}) + + # test drop just one node + dropped_one = sue.drop_nodes(names="Mary") + assert "Mary" not in dropped_one.children + + # test drop multiple nodes + dropped = sue.drop_nodes(names=["Mary", "Kate"]) + assert not set(["Mary", "Kate"]).intersection(set(dropped.children)) + assert "Ashley" in dropped.children + + # test raise + with pytest.raises(KeyError, match="nodes {'Mary'} not present"): + dropped.drop_nodes(names=["Mary", "Ashley"]) + + # test ignore + childless = dropped.drop_nodes(names=["Mary", "Ashley"], errors="ignore") + assert childless.children == {} + + def test_assign(self): + dt: DataTree = DataTree() + expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) + + # kwargs form + result = dt.assign(foo=xr.DataArray(0), a=DataTree()) + assert_equal(result, expected) + + # dict form + result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()}) + assert_equal(result, expected) + + +class TestPipe: + def test_noop(self, create_test_datatree): + dt = create_test_datatree() + + actual = dt.pipe(lambda tree: tree) + assert actual.identical(dt) + + def test_params(self, create_test_datatree): + dt = create_test_datatree() + + def f(tree, **attrs): + return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs)) + + attrs = {"x": 1, "y": 2, "z": 3} + + actual = dt.pipe(f, **attrs) + assert actual["arr_with_attrs"].attrs == attrs + + def test_named_self(self, create_test_datatree): + dt = create_test_datatree() + + def f(x, tree, y): + tree.attrs.update({"x": x, "y": y}) + return tree + + attrs = {"x": 1, "y": 2} + + actual = dt.pipe((f, "tree"), **attrs) + + assert actual is dt and actual.attrs == attrs + + +class TestSubset: + def test_match(self): + # TODO is this example going to cause problems with case sensitivity? + dt: DataTree = DataTree.from_dict( + { + "/a/A": None, + "/a/B": None, + "/b/A": None, + "/b/B": None, + } + ) + result = dt.match("*/B") + expected = DataTree.from_dict( + { + "/a/B": None, + "/b/B": None, + } + ) + assert_identical(result, expected) + + def test_filter(self): + simpsons: DataTree = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + expected = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + }, + name="Abe", + ) + elders = simpsons.filter(lambda node: node["age"].item() > 18) + assert_identical(elders, expected) + + +class TestDSMethodInheritance: + def test_dataset_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.isel(x=1)) + DataTree(name="results", parent=expected, data=ds.isel(x=1)) + + result = dt.isel(x=1) + assert_equal(result, expected) + + def test_reduce_method(self): + ds = xr.Dataset({"a": ("x", [False, True, False])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.any()) + DataTree(name="results", parent=expected, data=ds.any()) + + result = dt.any() + assert_equal(result, expected) + + def test_nan_reduce_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.mean()) + DataTree(name="results", parent=expected, data=ds.mean()) + + result = dt.mean() + assert_equal(result, expected) + + def test_cum_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt: DataTree = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected: DataTree = DataTree(data=ds.cumsum()) + DataTree(name="results", parent=expected, data=ds.cumsum()) + + result = dt.cumsum() + assert_equal(result, expected) + + +class TestOps: + def test_binary_op_on_int(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected: DataTree = DataTree(data=ds1 * 5) + DataTree(name="subnode", data=ds2 * 5, parent=expected) + + # TODO: Remove ignore when ops.py is migrated? + result: DataTree = dt * 5 # type: ignore[assignment,operator] + assert_equal(result, expected) + + def test_binary_op_on_dataset(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) + + expected: DataTree = DataTree(data=ds1 * other_ds) + DataTree(name="subnode", data=ds2 * other_ds, parent=expected) + + result = dt * other_ds + assert_equal(result, expected) + + def test_binary_op_on_datatree(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt: DataTree = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected: DataTree = DataTree(data=ds1 * ds1) + DataTree(name="subnode", data=ds2 * ds2, parent=expected) + + # TODO: Remove ignore when ops.py is migrated? + result: DataTree = dt * dt # type: ignore[operator] + assert_equal(result, expected) + + +class TestUFuncs: + def test_tree(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: np.sin(ds)) + result_tree = np.sin(dt) + assert_equal(result_tree, expected) + + +class TestDocInsertion: + """Tests map_over_subtree docstring injection.""" + + def test_standard_doc(self): + + dataset_doc = dedent( + """\ + Manually trigger loading and/or computation of this dataset's data + from disk or a remote source into memory and return this dataset. + Unlike compute, the original dataset is modified and returned. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + dask.compute""" + ) + + expected_doc = dedent( + """\ + Manually trigger loading and/or computation of this dataset's data + from disk or a remote source into memory and return this dataset. + Unlike compute, the original dataset is modified and returned. + + .. note:: + This method was copied from xarray.Dataset, but has been altered to + call the method on the Datasets stored in every node of the + subtree. See the `map_over_subtree` function for more details. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + dask.compute""" + ) + + wrapped_doc = insert_doc_addendum(dataset_doc, _MAPPED_DOCSTRING_ADDENDUM) + + assert expected_doc == wrapped_doc + + def test_one_liner(self): + mixin_doc = "Same as abs(a)." + + expected_doc = dedent( + """\ + Same as abs(a). + + This method was copied from xarray.Dataset, but has been altered to call the + method on the Datasets stored in every node of the subtree. See the + `map_over_subtree` function for more details.""" + ) + + actual_doc = insert_doc_addendum(mixin_doc, _MAPPED_DOCSTRING_ADDENDUM) + assert expected_doc == actual_doc + + def test_none(self): + actual_doc = insert_doc_addendum(None, _MAPPED_DOCSTRING_ADDENDUM) + assert actual_doc is None diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_datatree_mapping.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_datatree_mapping.py new file mode 100644 index 0000000..b8b5561 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_datatree_mapping.py @@ -0,0 +1,347 @@ +import numpy as np +import pytest + +import xarray as xr +from xarray.core.datatree import DataTree +from xarray.core.datatree_mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) +from xarray.testing import assert_equal + +empty = xr.Dataset() + + +class TestCheckTreesIsomorphic: + def test_not_a_tree(self): + with pytest.raises(TypeError, match="not a tree"): + check_isomorphic("s", 1) # type: ignore[arg-type] + + def test_different_widths(self): + dt1 = DataTree.from_dict(d={"a": empty}) + dt2 = DataTree.from_dict(d={"b": empty, "c": empty}) + expected_err_str = ( + "Number of children on node '/' of the left object: 1\n" + "Number of children on node '/' of the right object: 2" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + check_isomorphic(dt1, dt2) + + def test_different_heights(self): + dt1 = DataTree.from_dict({"a": empty}) + dt2 = DataTree.from_dict({"b": empty, "b/c": empty}) + expected_err_str = ( + "Number of children on node '/a' of the left object: 0\n" + "Number of children on node '/b' of the right object: 1" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + check_isomorphic(dt1, dt2) + + def test_names_different(self): + dt1 = DataTree.from_dict({"a": xr.Dataset()}) + dt2 = DataTree.from_dict({"b": empty}) + expected_err_str = ( + "Node '/a' in the left object has name 'a'\n" + "Node '/b' in the right object has name 'b'" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_names_equal(self): + dt1 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + dt2 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_ordering(self): + dt1 = DataTree.from_dict({"a": empty, "b": empty, "b/d": empty, "b/c": empty}) + dt2 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + check_isomorphic(dt1, dt2, require_names_equal=False) + + def test_isomorphic_names_not_equal(self): + dt1 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + dt2 = DataTree.from_dict({"A": empty, "B": empty, "B/C": empty, "B/D": empty}) + check_isomorphic(dt1, dt2) + + def test_not_isomorphic_complex_tree(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + dt2["set1/set2/extra"] = DataTree(name="extra") + with pytest.raises(TreeIsomorphismError, match="/set1/set2"): + check_isomorphic(dt1, dt2) + + def test_checking_from_root(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + real_root: DataTree = DataTree(name="real root") + dt2.name = "not_real_root" + dt2.parent = real_root + with pytest.raises(TreeIsomorphismError): + check_isomorphic(dt1, dt2, check_from_root=True) + + +class TestMapOverSubTree: + def test_no_trees_passed(self): + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + with pytest.raises(TypeError, match="Must pass at least one tree"): + times_ten("dt") + + def test_not_isomorphic(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + dt2["set1/set2/extra"] = DataTree(name="extra") + + @map_over_subtree + def times_ten(ds1, ds2): + return ds1 * ds2 + + with pytest.raises(TreeIsomorphismError): + times_ten(dt1, dt2) + + def test_no_trees_returned(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + + @map_over_subtree + def bad_func(ds1, ds2): + return None + + with pytest.raises(TypeError, match="return value of None"): + bad_func(dt1, dt2) + + def test_single_dt_arg(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + expected = create_test_datatree(modify=lambda ds: 10.0 * ds) + result_tree = times_ten(dt) + assert_equal(result_tree, expected) + + def test_single_dt_arg_plus_args_and_kwargs(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def multiply_then_add(ds, times, add=0.0): + return (times * ds) + add + + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + result_tree = multiply_then_add(dt, 10.0, add=2.0) + assert_equal(result_tree, expected) + + def test_multiple_dt_args(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + + @map_over_subtree + def add(ds1, ds2): + return ds1 + ds2 + + expected = create_test_datatree(modify=lambda ds: 2.0 * ds) + result = add(dt1, dt2) + assert_equal(result, expected) + + def test_dt_as_kwarg(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + + @map_over_subtree + def add(ds1, value=0.0): + return ds1 + value + + expected = create_test_datatree(modify=lambda ds: 2.0 * ds) + result = add(dt1, value=dt2) + assert_equal(result, expected) + + def test_return_multiple_dts(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def minmax(ds): + return ds.min(), ds.max() + + dt_min, dt_max = minmax(dt) + expected_min = create_test_datatree(modify=lambda ds: ds.min()) + assert_equal(dt_min, expected_min) + expected_max = create_test_datatree(modify=lambda ds: ds.max()) + assert_equal(dt_max, expected_max) + + def test_return_wrong_type(self, simple_datatree): + dt1 = simple_datatree + + @map_over_subtree + def bad_func(ds1): + return "string" + + with pytest.raises(TypeError, match="not Dataset or DataArray"): + bad_func(dt1) + + def test_return_tuple_of_wrong_types(self, simple_datatree): + dt1 = simple_datatree + + @map_over_subtree + def bad_func(ds1): + return xr.Dataset(), "string" + + with pytest.raises(TypeError, match="not Dataset or DataArray"): + bad_func(dt1) + + @pytest.mark.xfail + def test_return_inconsistent_number_of_results(self, simple_datatree): + dt1 = simple_datatree + + @map_over_subtree + def bad_func(ds): + # Datasets in simple_datatree have different numbers of dims + # TODO need to instead return different numbers of Dataset objects for this test to catch the intended error + return tuple(ds.dims) + + with pytest.raises(TypeError, match="instead returns"): + bad_func(dt1) + + def test_wrong_number_of_arguments_for_func(self, simple_datatree): + dt = simple_datatree + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + times_ten(dt, dt) + + def test_map_single_dataset_against_whole_tree(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def nodewise_merge(node_ds, fixed_ds): + return xr.merge([node_ds, fixed_ds]) + + other_ds = xr.Dataset({"z": ("z", [0])}) + expected = create_test_datatree(modify=lambda ds: xr.merge([ds, other_ds])) + result_tree = nodewise_merge(dt, other_ds) + assert_equal(result_tree, expected) + + @pytest.mark.xfail + def test_trees_with_different_node_names(self): + # TODO test this after I've got good tests for renaming nodes + raise NotImplementedError + + def test_dt_method(self, create_test_datatree): + dt = create_test_datatree() + + def multiply_then_add(ds, times, add=0.0): + return times * ds + add + + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) + assert_equal(result_tree, expected) + + def test_discard_ancestry(self, create_test_datatree): + # Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48 + dt = create_test_datatree() + subtree = dt["set1"] + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + result_tree = times_ten(subtree) + assert_equal(result_tree, expected, from_root=False) + + def test_skip_empty_nodes_with_attrs(self, create_test_datatree): + # inspired by xarray-datatree GH262 + dt = create_test_datatree() + dt["set1/set2"].attrs["foo"] = "bar" + + def check_for_data(ds): + # fails if run on a node that has no data + assert len(ds.variables) != 0 + return ds + + dt.map_over_subtree(check_for_data) + + def test_keep_attrs_on_empty_nodes(self, create_test_datatree): + # GH278 + dt = create_test_datatree() + dt["set1/set2"].attrs["foo"] = "bar" + + def empty_func(ds): + return ds + + result = dt.map_over_subtree(empty_func) + assert result["set1/set2"].attrs == dt["set1/set2"].attrs + + @pytest.mark.xfail( + reason="probably some bug in pytests handling of exception notes" + ) + def test_error_contains_path_of_offending_node(self, create_test_datatree): + dt = create_test_datatree() + dt["set1"]["bad_var"] = 0 + print(dt) + + def fail_on_specific_node(ds): + if "bad_var" in ds: + raise ValueError("Failed because 'bar_var' present in dataset") + + with pytest.raises( + ValueError, match="Raised whilst mapping function over node /set1" + ): + dt.map_over_subtree(fail_on_specific_node) + + +class TestMutableOperations: + def test_construct_using_type(self): + # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188 + # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray + + a = xr.DataArray( + np.random.rand(3, 4, 10), + dims=["x", "y", "time"], + coords={"area": (["x", "y"], np.random.rand(3, 4))}, + ).to_dataset(name="data") + b = xr.DataArray( + np.random.rand(2, 6, 14), + dims=["x", "y", "time"], + coords={"area": (["x", "y"], np.random.rand(2, 6))}, + ).to_dataset(name="data") + dt = DataTree.from_dict({"a": a, "b": b}) + + def weighted_mean(ds): + return ds.weighted(ds.area).mean(["x", "y"]) + + dt.map_over_subtree(weighted_mean) + + def test_alter_inplace_forbidden(self): + simpsons = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + + def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset: + """Add some years to the age, but by altering the given dataset""" + ds["age"] = ds["age"] + years + return ds + + with pytest.raises(AttributeError): + simpsons.map_over_subtree(fast_forward, years=10) + + +@pytest.mark.xfail +class TestMapOverSubTreeInplace: + def test_map_over_subtree_inplace(self): + raise NotImplementedError diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_deprecation_helpers.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_deprecation_helpers.py new file mode 100644 index 0000000..f21c809 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_deprecation_helpers.py @@ -0,0 +1,140 @@ +import pytest + +from xarray.util.deprecation_helpers import _deprecate_positional_args + + +def test_deprecate_positional_args_warns_for_function(): + @_deprecate_positional_args("v0.1") + def f1(a, b, *, c="c", d="d"): + return a, b, c, d + + result = f1(1, 2) + assert result == (1, 2, "c", "d") + + result = f1(1, 2, c=3, d=4) + assert result == (1, 2, 3, 4) + + with pytest.warns(FutureWarning, match=r".*v0.1"): + result = f1(1, 2, 3) # type: ignore[misc] + assert result == (1, 2, 3, "d") + + with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): + result = f1(1, 2, 3) # type: ignore[misc] + assert result == (1, 2, 3, "d") + + with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): + result = f1(1, 2, 3, 4) # type: ignore[misc] + assert result == (1, 2, 3, 4) + + @_deprecate_positional_args("v0.1") + def f2(a="a", *, b="b", c="c", d="d"): + return a, b, c, d + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = f2(1, 2) # type: ignore[misc] + assert result == (1, 2, "c", "d") + + @_deprecate_positional_args("v0.1") + def f3(a, *, b="b", **kwargs): + return a, b, kwargs + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = f3(1, 2) # type: ignore[misc] + assert result == (1, 2, {}) + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = f3(1, 2, f="f") # type: ignore[misc] + assert result == (1, 2, {"f": "f"}) + + @_deprecate_positional_args("v0.1") + def f4(a, /, *, b="b", **kwargs): + return a, b, kwargs + + result = f4(1) + assert result == (1, "b", {}) + + result = f4(1, b=2, f="f") + assert result == (1, 2, {"f": "f"}) + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = f4(1, 2, f="f") # type: ignore[misc] + assert result == (1, 2, {"f": "f"}) + + with pytest.raises(TypeError, match=r"Keyword-only param without default"): + + @_deprecate_positional_args("v0.1") + def f5(a, *, b, c=3, **kwargs): + pass + + +def test_deprecate_positional_args_warns_for_class(): + class A1: + @_deprecate_positional_args("v0.1") + def method(self, a, b, *, c="c", d="d"): + return a, b, c, d + + result = A1().method(1, 2) + assert result == (1, 2, "c", "d") + + result = A1().method(1, 2, c=3, d=4) + assert result == (1, 2, 3, 4) + + with pytest.warns(FutureWarning, match=r".*v0.1"): + result = A1().method(1, 2, 3) # type: ignore[misc] + assert result == (1, 2, 3, "d") + + with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): + result = A1().method(1, 2, 3) # type: ignore[misc] + assert result == (1, 2, 3, "d") + + with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): + result = A1().method(1, 2, 3, 4) # type: ignore[misc] + assert result == (1, 2, 3, 4) + + class A2: + @_deprecate_positional_args("v0.1") + def method(self, a=1, b=1, *, c="c", d="d"): + return a, b, c, d + + with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): + result = A2().method(1, 2, 3) # type: ignore[misc] + assert result == (1, 2, 3, "d") + + with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): + result = A2().method(1, 2, 3, 4) # type: ignore[misc] + assert result == (1, 2, 3, 4) + + class A3: + @_deprecate_positional_args("v0.1") + def method(self, a, *, b="b", **kwargs): + return a, b, kwargs + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = A3().method(1, 2) # type: ignore[misc] + assert result == (1, 2, {}) + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = A3().method(1, 2, f="f") # type: ignore[misc] + assert result == (1, 2, {"f": "f"}) + + class A4: + @_deprecate_positional_args("v0.1") + def method(self, a, /, *, b="b", **kwargs): + return a, b, kwargs + + result = A4().method(1) + assert result == (1, "b", {}) + + result = A4().method(1, b=2, f="f") + assert result == (1, 2, {"f": "f"}) + + with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): + result = A4().method(1, 2, f="f") # type: ignore[misc] + assert result == (1, 2, {"f": "f"}) + + with pytest.raises(TypeError, match=r"Keyword-only param without default"): + + class A5: + @_deprecate_positional_args("v0.1") + def __init__(self, a, *, b, c=3, **kwargs): + pass diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_distributed.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_distributed.py new file mode 100644 index 0000000..d223bce --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_distributed.py @@ -0,0 +1,298 @@ +""" isort:skip_file """ + +from __future__ import annotations + +import pickle +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +if TYPE_CHECKING: + import dask + import dask.array as da + import distributed +else: + dask = pytest.importorskip("dask") + da = pytest.importorskip("dask.array") + distributed = pytest.importorskip("distributed") + +from dask.distributed import Client, Lock +from distributed.client import futures_of +from distributed.utils_test import ( # noqa: F401 + cleanup, + cluster, + gen_cluster, + loop, + loop_in_thread, +) + +import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock +from xarray.tests import ( + assert_allclose, + assert_identical, + has_h5netcdf, + has_netCDF4, + has_scipy, + requires_cftime, + requires_netCDF4, + requires_zarr, +) +from xarray.tests.test_backends import ( + ON_WINDOWS, + create_tmp_file, +) +from xarray.tests.test_dataset import create_test_data + +loop = loop # loop is an imported fixture, which flake8 has issues ack-ing + + +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join("testfile.nc")) + + +ENGINES = [] +if has_scipy: + ENGINES.append("scipy") +if has_netCDF4: + ENGINES.append("netcdf4") +if has_h5netcdf: + ENGINES.append("h5netcdf") + +NC_FORMATS = { + "netcdf4": [ + "NETCDF3_CLASSIC", + "NETCDF3_64BIT_OFFSET", + "NETCDF3_64BIT_DATA", + "NETCDF4_CLASSIC", + "NETCDF4", + ], + "scipy": ["NETCDF3_CLASSIC", "NETCDF3_64BIT"], + "h5netcdf": ["NETCDF4"], +} + +ENGINES_AND_FORMATS = [ + ("netcdf4", "NETCDF3_CLASSIC"), + ("netcdf4", "NETCDF4_CLASSIC"), + ("netcdf4", "NETCDF4"), + ("h5netcdf", "NETCDF4"), + ("scipy", "NETCDF3_64BIT"), +] + + +@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + loop, tmp_netcdf_filename, engine, nc_format +): + if engine not in ENGINES: + pytest.skip("engine not available") + + chunks = {"dim1": 4, "dim2": 3, "dim3": 6} + + with cluster() as (s, [a, b]): + with Client(s["address"], loop=loop): + original = create_test_data().chunk(chunks) + + if engine == "scipy": + with pytest.raises(NotImplementedError): + original.to_netcdf( + tmp_netcdf_filename, engine=engine, format=nc_format + ) + return + + original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) + + with xr.open_dataset( + tmp_netcdf_filename, chunks=chunks, engine=engine + ) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) + + +@requires_netCDF4 +def test_dask_distributed_write_netcdf_with_dimensionless_variables( + loop, tmp_netcdf_filename +): + with cluster() as (s, [a, b]): + with Client(s["address"], loop=loop): + original = xr.Dataset({"x": da.zeros(())}) + original.to_netcdf(tmp_netcdf_filename) + + with xr.open_dataset(tmp_netcdf_filename) as actual: + assert actual.x.shape == () + + +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path): + T = xr.cftime_range("20010101", "20010501", calendar="360_day") + Lon = np.arange(100) + data = np.random.random((T.size, Lon.size)) + da = xr.DataArray(data, coords={"time": T, "Lon": Lon}, name="test") + file_path = tmp_path / "test.nc" + da.to_netcdf(file_path) + with cluster() as (s, [a, b]): + with Client(s["address"]): + with xr.open_mfdataset(file_path, parallel=parallel) as tf: + assert_identical(tf["test"], da) + + +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + with cluster() as (s, [a, b]): + with Client(s["address"]): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) + + +# TODO: move this to test_backends.py +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + if parallel: + pytest.skip( + "Flaky in CI. Would be a welcome contribution to make a similar test reliable." + ) + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with dask.config.set(scheduler=get): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) + + +@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format +): + if engine not in ENGINES: + pytest.skip("engine not available") + + chunks = {"dim1": 4, "dim2": 3, "dim3": 6} + + with cluster() as (s, [a, b]): + with Client(s["address"], loop=loop): + original = create_test_data() + original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) + + with xr.open_dataset( + tmp_netcdf_filename, chunks=chunks, engine=engine + ) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) + + +@requires_zarr +@pytest.mark.parametrize("consolidated", [True, False]) +@pytest.mark.parametrize("compute", [True, False]) +def test_dask_distributed_zarr_integration_test( + loop, consolidated: bool, compute: bool +) -> None: + if consolidated: + write_kwargs: dict[str, Any] = {"consolidated": True} + read_kwargs: dict[str, Any] = {"backend_kwargs": {"consolidated": True}} + else: + write_kwargs = read_kwargs = {} + chunks = {"dim1": 4, "dim2": 3, "dim3": 5} + with cluster() as (s, [a, b]): + with Client(s["address"], loop=loop): + original = create_test_data().chunk(chunks) + with create_tmp_file( + allow_cleanup_failure=ON_WINDOWS, suffix=".zarrc" + ) as filename: + maybe_futures = original.to_zarr( # type: ignore[call-overload] #mypy bug? + filename, compute=compute, **write_kwargs + ) + if not compute: + maybe_futures.compute() + with xr.open_dataset( + filename, chunks="auto", engine="zarr", **read_kwargs + ) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) + + +@gen_cluster(client=True) +async def test_async(c, s, a, b) -> None: + x = create_test_data() + assert not dask.is_dask_collection(x) + y = x.chunk({"dim2": 4}) + 10 + assert dask.is_dask_collection(y) + assert dask.is_dask_collection(y.var1) + assert dask.is_dask_collection(y.var2) + + z = y.persist() + assert str(z) + + assert dask.is_dask_collection(z) + assert dask.is_dask_collection(z.var1) + assert dask.is_dask_collection(z.var2) + assert len(y.__dask_graph__()) > len(z.__dask_graph__()) + + assert not futures_of(y) + assert futures_of(z) + + future = c.compute(z) + w = await future + assert not dask.is_dask_collection(w) + assert_allclose(x + 10, w) + + assert s.tasks + + +def test_hdf5_lock() -> None: + assert isinstance(HDF5_LOCK, SerializableLock) + + +@gen_cluster(client=True) +async def test_serializable_locks(c, s, a, b) -> None: + def f(x, lock=None): + with lock: + return x + 1 + + # note, the creation of Lock needs to be done inside a cluster + for lock in [ + HDF5_LOCK, + Lock(), + Lock("filename.nc"), + CombinedLock([HDF5_LOCK]), + CombinedLock([HDF5_LOCK, Lock("filename.nc")]), + ]: + futures = c.map(f, list(range(10)), lock=lock) + await c.gather(futures) + + lock2 = pickle.loads(pickle.dumps(lock)) + assert type(lock) == type(lock2) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_dtypes.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_dtypes.py new file mode 100644 index 0000000..e817bfd --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_dtypes.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from xarray.core import dtypes +from xarray.tests import requires_array_api_strict + +try: + import array_api_strict +except ImportError: + + class DummyArrayAPINamespace: + bool = None + int32 = None + float64 = None + + array_api_strict = DummyArrayAPINamespace + + +@pytest.mark.parametrize( + "args, expected", + [ + ([bool], bool), + ([bool, np.bytes_], np.object_), + ([np.float32, np.float64], np.float64), + ([np.float32, np.bytes_], np.object_), + ([np.str_, np.int64], np.object_), + ([np.str_, np.str_], np.str_), + ([np.bytes_, np.str_], np.object_), + ], +) +def test_result_type(args, expected) -> None: + actual = dtypes.result_type(*args) + assert actual == expected + + +@pytest.mark.parametrize( + ["values", "expected"], + ( + ([np.arange(3, dtype="float32"), np.nan], np.float32), + ([np.arange(3, dtype="int8"), 1], np.int8), + ([np.array(["a", "b"], dtype=str), np.nan], object), + ([np.array([b"a", b"b"], dtype=bytes), True], object), + ([np.array([b"a", b"b"], dtype=bytes), "c"], object), + ([np.array(["a", "b"], dtype=str), "c"], np.dtype(str)), + ([np.array(["a", "b"], dtype=str), None], object), + ([0, 1], np.dtype("int")), + ), +) +def test_result_type_scalars(values, expected) -> None: + actual = dtypes.result_type(*values) + + assert np.issubdtype(actual, expected) + + +def test_result_type_dask_array() -> None: + # verify it works without evaluating dask arrays + da = pytest.importorskip("dask.array") + dask = pytest.importorskip("dask") + + def error(): + raise RuntimeError + + array = da.from_delayed(dask.delayed(error)(), (), np.float64) + with pytest.raises(RuntimeError): + array.compute() + + actual = dtypes.result_type(array) + assert actual == np.float64 + + # note that this differs from the behavior for scalar numpy arrays, which + # would get promoted to float32 + actual = dtypes.result_type(array, np.array([0.5, 1.0], dtype=np.float32)) + assert actual == np.float64 + + +@pytest.mark.parametrize("obj", [1.0, np.inf, "ab", 1.0 + 1.0j, True]) +def test_inf(obj) -> None: + assert dtypes.INF > obj + assert dtypes.NINF < obj + + +@pytest.mark.parametrize( + "kind, expected", + [ + ("b", (np.float32, "nan")), # dtype('int8') + ("B", (np.float32, "nan")), # dtype('uint8') + ("c", (np.dtype("O"), "nan")), # dtype('S1') + ("D", (np.complex128, "(nan+nanj)")), # dtype('complex128') + ("d", (np.float64, "nan")), # dtype('float64') + ("e", (np.float16, "nan")), # dtype('float16') + ("F", (np.complex64, "(nan+nanj)")), # dtype('complex64') + ("f", (np.float32, "nan")), # dtype('float32') + ("h", (np.float32, "nan")), # dtype('int16') + ("H", (np.float32, "nan")), # dtype('uint16') + ("i", (np.float64, "nan")), # dtype('int32') + ("I", (np.float64, "nan")), # dtype('uint32') + ("l", (np.float64, "nan")), # dtype('int64') + ("L", (np.float64, "nan")), # dtype('uint64') + ("m", (np.timedelta64, "NaT")), # dtype(' None: + # 'g': np.float128 is not tested : not available on all platforms + # 'G': np.complex256 is not tested : not available on all platforms + + actual = dtypes.maybe_promote(np.dtype(kind)) + assert actual[0] == expected[0] + assert str(actual[1]) == expected[1] + + +def test_nat_types_membership() -> None: + assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES + assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES + assert np.float64 not in dtypes.NAT_TYPES + + +@pytest.mark.parametrize( + ["dtype", "kinds", "xp", "expected"], + ( + (np.dtype("int32"), "integral", np, True), + (np.dtype("float16"), "real floating", np, True), + (np.dtype("complex128"), "complex floating", np, True), + (np.dtype("U"), "numeric", np, False), + pytest.param( + array_api_strict.int32, + "integral", + array_api_strict, + True, + marks=requires_array_api_strict, + id="array_api-int", + ), + pytest.param( + array_api_strict.float64, + "real floating", + array_api_strict, + True, + marks=requires_array_api_strict, + id="array_api-float", + ), + pytest.param( + array_api_strict.bool, + "numeric", + array_api_strict, + False, + marks=requires_array_api_strict, + id="array_api-bool", + ), + ), +) +def test_isdtype(dtype, kinds, xp, expected) -> None: + actual = dtypes.isdtype(dtype, kinds, xp=xp) + assert actual == expected + + +@pytest.mark.parametrize( + ["dtype", "kinds", "xp", "error", "pattern"], + ( + (np.dtype("int32"), "foo", np, (TypeError, ValueError), "kind"), + (np.dtype("int32"), np.signedinteger, np, TypeError, "kind"), + (np.dtype("float16"), 1, np, TypeError, "kind"), + ), +) +def test_isdtype_error(dtype, kinds, xp, error, pattern): + with pytest.raises(error, match=pattern): + dtypes.isdtype(dtype, kinds, xp=xp) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_duck_array_ops.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_duck_array_ops.py new file mode 100644 index 0000000..afcf10e --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_duck_array_ops.py @@ -0,0 +1,1047 @@ +from __future__ import annotations + +import datetime as dt +import warnings + +import numpy as np +import pandas as pd +import pytest +from numpy import array, nan + +from xarray import DataArray, Dataset, cftime_range, concat +from xarray.core import dtypes, duck_array_ops +from xarray.core.duck_array_ops import ( + array_notnull_equiv, + concatenate, + count, + first, + gradient, + last, + least_squares, + mean, + np_timedelta64_to_float, + pd_timedelta_to_float, + push, + py_timedelta_to_float, + stack, + timedelta_to_numeric, + where, +) +from xarray.core.extension_array import PandasExtensionArray +from xarray.namedarray.pycompat import array_type +from xarray.testing import assert_allclose, assert_equal, assert_identical +from xarray.tests import ( + arm_xfail, + assert_array_equal, + has_dask, + has_scipy, + raise_if_dask_computes, + requires_bottleneck, + requires_cftime, + requires_dask, + requires_pyarrow, +) + +dask_array_type = array_type("dask") + + +@pytest.fixture +def categorical1(): + return pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"]) + + +@pytest.fixture +def categorical2(): + return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) + + +try: + import pyarrow as pa + + @pytest.fixture + def arrow1(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}, {"x": 2, "y": False}]) + ) + + @pytest.fixture + def arrow2(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) + ) + +except ImportError: + pass + + +@pytest.fixture +def int1(): + return pd.arrays.IntegerArray( + np.array([1, 2, 3, 4, 5]), np.array([True, False, False, True, True]) + ) + + +@pytest.fixture +def int2(): + return pd.arrays.IntegerArray( + np.array([6, 7, 8, 9, 10]), np.array([True, True, False, True, False]) + ) + + +class TestOps: + @pytest.fixture(autouse=True) + def setUp(self): + self.x = array( + [ + [ + [nan, nan, 2.0, nan], + [nan, 5.0, 6.0, nan], + [8.0, 9.0, 10.0, nan], + ], + [ + [nan, 13.0, 14.0, 15.0], + [nan, 17.0, 18.0, nan], + [nan, 21.0, nan, nan], + ], + ] + ) + + def test_first(self): + expected_results = [ + array([[nan, 13, 2, 15], [nan, 5, 6, nan], [8, 9, 10, nan]]), + array([[8, 5, 2, nan], [nan, 13, 14, 15]]), + array([[2, 5, 8], [13, 17, 21]]), + ] + for axis, expected in zip([0, 1, 2, -3, -2, -1], 2 * expected_results): + actual = first(self.x, axis) + assert_array_equal(expected, actual) + + expected = self.x[0] + actual = first(self.x, axis=0, skipna=False) + assert_array_equal(expected, actual) + + expected = self.x[..., 0] + actual = first(self.x, axis=-1, skipna=False) + assert_array_equal(expected, actual) + + with pytest.raises(IndexError, match=r"out of bounds"): + first(self.x, 3) + + def test_last(self): + expected_results = [ + array([[nan, 13, 14, 15], [nan, 17, 18, nan], [8, 21, 10, nan]]), + array([[8, 9, 10, nan], [nan, 21, 18, 15]]), + array([[2, 6, 10], [15, 18, 21]]), + ] + for axis, expected in zip([0, 1, 2, -3, -2, -1], 2 * expected_results): + actual = last(self.x, axis) + assert_array_equal(expected, actual) + + expected = self.x[-1] + actual = last(self.x, axis=0, skipna=False) + assert_array_equal(expected, actual) + + expected = self.x[..., -1] + actual = last(self.x, axis=-1, skipna=False) + assert_array_equal(expected, actual) + + with pytest.raises(IndexError, match=r"out of bounds"): + last(self.x, 3) + + def test_count(self): + assert 12 == count(self.x) + + expected = array([[1, 2, 3], [3, 2, 1]]) + assert_array_equal(expected, count(self.x, axis=-1)) + + assert 1 == count(np.datetime64("2000-01-01")) + + def test_where_type_promotion(self): + result = where(np.array([True, False]), np.array([1, 2]), np.array(["a", "b"])) + assert_array_equal(result, np.array([1, "b"], dtype=object)) + + result = where([True, False], np.array([1, 2], np.float32), np.nan) + assert result.dtype == np.float32 + assert_array_equal(result, np.array([1, np.nan], dtype=np.float32)) + + def test_where_extension_duck_array(self, categorical1, categorical2): + where_res = where( + np.array([True, False, True, False, False]), + PandasExtensionArray(categorical1), + PandasExtensionArray(categorical2), + ) + assert isinstance(where_res, PandasExtensionArray) + assert ( + where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) + ).all() + + def test_concatenate_extension_duck_array(self, categorical1, categorical2): + concate_res = concatenate( + [PandasExtensionArray(categorical1), PandasExtensionArray(categorical2)] + ) + assert isinstance(concate_res, PandasExtensionArray) + assert ( + concate_res + == type(categorical1)._concat_same_type((categorical1, categorical2)) + ).all() + + @requires_pyarrow + def test_extension_array_pyarrow_concatenate(self, arrow1, arrow2): + concatenated = concatenate( + (PandasExtensionArray(arrow1), PandasExtensionArray(arrow2)) + ) + assert concatenated[2]["x"] == 3 + assert concatenated[3]["y"] + + def test___getitem__extension_duck_array(self, categorical1): + extension_duck_array = PandasExtensionArray(categorical1) + assert (extension_duck_array[0:2] == categorical1[0:2]).all() + assert isinstance(extension_duck_array[0:2], PandasExtensionArray) + assert extension_duck_array[0] == categorical1[0] + assert isinstance(extension_duck_array[0], PandasExtensionArray) + mask = [True, False, True, False, True] + assert (extension_duck_array[mask] == categorical1[mask]).all() + + def test__setitem__extension_duck_array(self, categorical1): + extension_duck_array = PandasExtensionArray(categorical1) + extension_duck_array[2] = "cat1" # already existing category + assert extension_duck_array[2] == "cat1" + with pytest.raises(TypeError, match="Cannot setitem on a Categorical"): + extension_duck_array[2] = "cat4" # new category + + def test_stack_type_promotion(self): + result = stack([1, "b"]) + assert_array_equal(result, np.array([1, "b"], dtype=object)) + + def test_concatenate_type_promotion(self): + result = concatenate([np.array([1]), np.array(["b"])]) + assert_array_equal(result, np.array([1, "b"], dtype=object)) + + @pytest.mark.filterwarnings("error") + def test_all_nan_arrays(self): + assert np.isnan(mean([np.nan, np.nan])) + + +@requires_dask +class TestDaskOps(TestOps): + @pytest.fixture(autouse=True) + def setUp(self): + import dask.array + + self.x = dask.array.from_array( + [ + [ + [nan, nan, 2.0, nan], + [nan, 5.0, 6.0, nan], + [8.0, 9.0, 10.0, nan], + ], + [ + [nan, 13.0, 14.0, 15.0], + [nan, 17.0, 18.0, nan], + [nan, 21.0, nan, nan], + ], + ], + chunks=(2, 1, 2), + ) + + +def test_cumsum_1d(): + inputs = np.array([0, 1, 2, 3]) + expected = np.array([0, 1, 3, 6]) + actual = duck_array_ops.cumsum(inputs) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=0) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=-1) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=(0,)) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=()) + assert_array_equal(inputs, actual) + + +def test_cumsum_2d(): + inputs = np.array([[1, 2], [3, 4]]) + + expected = np.array([[1, 3], [4, 10]]) + actual = duck_array_ops.cumsum(inputs) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=(0, 1)) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=()) + assert_array_equal(inputs, actual) + + +def test_cumprod_2d(): + inputs = np.array([[1, 2], [3, 4]]) + + expected = np.array([[1, 2], [3, 2 * 3 * 4]]) + actual = duck_array_ops.cumprod(inputs) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumprod(inputs, axis=(0, 1)) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumprod(inputs, axis=()) + assert_array_equal(inputs, actual) + + +class TestArrayNotNullEquiv: + @pytest.mark.parametrize( + "arr1, arr2", + [ + (np.array([1, 2, 3]), np.array([1, 2, 3])), + (np.array([1, 2, np.nan]), np.array([1, np.nan, 3])), + (np.array([np.nan, 2, np.nan]), np.array([1, np.nan, np.nan])), + ], + ) + def test_equal(self, arr1, arr2): + assert array_notnull_equiv(arr1, arr2) + + def test_some_not_equal(self): + a = np.array([1, 2, 4]) + b = np.array([1, np.nan, 3]) + assert not array_notnull_equiv(a, b) + + def test_wrong_shape(self): + a = np.array([[1, np.nan, np.nan, 4]]) + b = np.array([[1, 2], [np.nan, 4]]) + assert not array_notnull_equiv(a, b) + + @pytest.mark.parametrize( + "val1, val2, val3, null", + [ + ( + np.datetime64("2000"), + np.datetime64("2001"), + np.datetime64("2002"), + np.datetime64("NaT"), + ), + (1.0, 2.0, 3.0, np.nan), + ("foo", "bar", "baz", None), + ("foo", "bar", "baz", np.nan), + ], + ) + def test_types(self, val1, val2, val3, null): + dtype = object if isinstance(val1, str) else None + arr1 = np.array([val1, null, val3, null], dtype=dtype) + arr2 = np.array([val1, val2, null, null], dtype=dtype) + assert array_notnull_equiv(arr1, arr2) + + +def construct_dataarray(dim_num, dtype, contains_nan, dask): + # dimnum <= 3 + rng = np.random.RandomState(0) + shapes = [16, 8, 4][:dim_num] + dims = ("x", "y", "z")[:dim_num] + + if np.issubdtype(dtype, np.floating): + array = rng.randn(*shapes).astype(dtype) + elif np.issubdtype(dtype, np.integer): + array = rng.randint(0, 10, size=shapes).astype(dtype) + elif np.issubdtype(dtype, np.bool_): + array = rng.randint(0, 1, size=shapes).astype(dtype) + elif dtype == str: + array = rng.choice(["a", "b", "c", "d"], size=shapes) + else: + raise ValueError + + if contains_nan: + inds = rng.choice(range(array.size), int(array.size * 0.2)) + dtype, fill_value = dtypes.maybe_promote(array.dtype) + array = array.astype(dtype) + array.flat[inds] = fill_value + + da = DataArray(array, dims=dims, coords={"x": np.arange(16)}, name="da") + + if dask and has_dask: + chunks = {d: 4 for d in dims} + da = da.chunk(chunks) + + return da + + +def from_series_or_scalar(se): + if isinstance(se, pd.Series): + return DataArray.from_series(se) + else: # scalar case + return DataArray(se) + + +def series_reduce(da, func, dim, **kwargs): + """convert DataArray to pd.Series, apply pd.func, then convert back to + a DataArray. Multiple dims cannot be specified.""" + + # pd no longer accepts skipna=None https://github.com/pandas-dev/pandas/issues/44178 + if kwargs.get("skipna", True) is None: + kwargs["skipna"] = True + + if dim is None or da.ndim == 1: + se = da.to_series() + return from_series_or_scalar(getattr(se, func)(**kwargs)) + else: + da1 = [] + dims = list(da.dims) + dims.remove(dim) + d = dims[0] + for i in range(len(da[d])): + da1.append(series_reduce(da.isel(**{d: i}), func, dim, **kwargs)) + + if d in da.coords: + return concat(da1, dim=da[d]) + return concat(da1, dim=d) + + +def assert_dask_array(da, dask): + if dask and da.ndim > 0: + assert isinstance(da.data, dask_array_type) + + +@arm_xfail +@pytest.mark.filterwarnings("ignore:All-NaN .* encountered:RuntimeWarning") +@pytest.mark.parametrize("dask", [False, True] if has_dask else [False]) +def test_datetime_mean(dask: bool) -> None: + # Note: only testing numpy, as dask is broken upstream + da = DataArray( + np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype="M8[ns]"), + dims=["time"], + ) + if dask: + # Trigger use case where a chunk is full of NaT + da = da.chunk({"time": 3}) + + expect = DataArray(np.array("2010-01-02", dtype="M8[ns]")) + expect_nat = DataArray(np.array("NaT", dtype="M8[ns]")) + + actual = da.mean() + if dask: + assert actual.chunks is not None + assert_equal(actual, expect) + + actual = da.mean(skipna=False) + if dask: + assert actual.chunks is not None + assert_equal(actual, expect_nat) + + # tests for 1d array full of NaT + assert_equal(da[[1]].mean(), expect_nat) + assert_equal(da[[1]].mean(skipna=False), expect_nat) + + # tests for a 0d array + assert_equal(da[0].mean(), da[0]) + assert_equal(da[0].mean(skipna=False), da[0]) + assert_equal(da[1].mean(), expect_nat) + assert_equal(da[1].mean(skipna=False), expect_nat) + + +@requires_cftime +@pytest.mark.parametrize("dask", [False, True]) +def test_cftime_datetime_mean(dask): + if dask and not has_dask: + pytest.skip("requires dask") + + times = cftime_range("2000", periods=4) + da = DataArray(times, dims=["time"]) + da_2d = DataArray(times.values.reshape(2, 2)) + + if dask: + da = da.chunk({"time": 2}) + da_2d = da_2d.chunk({"dim_0": 2}) + + expected = da.isel(time=0) + # one compute needed to check the array contains cftime datetimes + with raise_if_dask_computes(max_computes=1): + result = da.isel(time=0).mean() + assert_dask_array(result, dask) + assert_equal(result, expected) + + expected = DataArray(times.date_type(2000, 1, 2, 12)) + with raise_if_dask_computes(max_computes=1): + result = da.mean() + assert_dask_array(result, dask) + assert_equal(result, expected) + + with raise_if_dask_computes(max_computes=1): + result = da_2d.mean() + assert_dask_array(result, dask) + assert_equal(result, expected) + + +@requires_cftime +@requires_dask +def test_mean_over_non_time_dim_of_dataset_with_dask_backed_cftime_data(): + # Regression test for part two of GH issue 5897: averaging over a non-time + # dimension still fails if the time variable is dask-backed. + ds = Dataset( + { + "var1": (("time",), cftime_range("2021-10-31", periods=10, freq="D")), + "var2": (("x",), list(range(10))), + } + ) + expected = ds.mean("x") + result = ds.chunk({}).mean("x") + assert_equal(result, expected) + + +@requires_cftime +def test_cftime_datetime_mean_long_time_period(): + import cftime + + times = np.array( + [ + [ + cftime.DatetimeNoLeap(400, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(520, 12, 31, 0, 0, 0, 0), + ], + [ + cftime.DatetimeNoLeap(520, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(640, 12, 31, 0, 0, 0, 0), + ], + [ + cftime.DatetimeNoLeap(640, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(760, 12, 31, 0, 0, 0, 0), + ], + ] + ) + + da = DataArray(times, dims=["time", "d2"]) + result = da.mean("d2") + expected = DataArray( + [ + cftime.DatetimeNoLeap(460, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(580, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(700, 12, 31, 0, 0, 0, 0), + ], + dims=["time"], + ) + assert_equal(result, expected) + + +def test_empty_axis_dtype(): + ds = Dataset() + ds["pos"] = [1, 2, 3] + ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + ds["var"] = "pos", [2, 3, 4] + assert_identical(ds.mean(dim="time")["var"], ds["var"]) + assert_identical(ds.max(dim="time")["var"], ds["var"]) + assert_identical(ds.min(dim="time")["var"], ds["var"]) + assert_identical(ds.sum(dim="time")["var"], ds["var"]) + + +@pytest.mark.parametrize("dim_num", [1, 2]) +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "min", "max", "mean", "var"]) +# TODO test cumsum, cumprod +@pytest.mark.parametrize("skipna", [False, True]) +@pytest.mark.parametrize("aggdim", [None, "x"]) +def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): + if aggdim == "y" and dim_num < 2: + pytest.skip("dim not in this test") + + if dtype == np.bool_ and func == "mean": + pytest.skip("numpy does not support this") + + if dask and not has_dask: + pytest.skip("requires dask") + + if dask and skipna is False and dtype in [np.bool_]: + pytest.skip("dask does not compute object-typed array") + + rtol = 1e-04 if dtype == np.float32 else 1e-05 + + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + axis = None if aggdim is None else da.get_axis_num(aggdim) + + # TODO: remove these after resolving + # https://github.com/dask/dask/issues/3245 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Mean of empty slice") + warnings.filterwarnings("ignore", "All-NaN slice") + warnings.filterwarnings("ignore", "invalid value encountered in") + + if da.dtype.kind == "O" and skipna: + # Numpy < 1.13 does not handle object-type array. + try: + if skipna: + expected = getattr(np, f"nan{func}")(da.values, axis=axis) + else: + expected = getattr(np, func)(da.values, axis=axis) + + actual = getattr(da, func)(skipna=skipna, dim=aggdim) + assert_dask_array(actual, dask) + np.testing.assert_allclose( + actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True + ) + except (TypeError, AttributeError, ZeroDivisionError): + # TODO currently, numpy does not support some methods such as + # nanmean for object dtype + pass + + actual = getattr(da, func)(skipna=skipna, dim=aggdim) + + # for dask case, make sure the result is the same for numpy backend + expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the compatibility with pandas' results. + if func in ["var", "std"]: + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=0) + assert_allclose(actual, expected, rtol=rtol) + # also check ddof!=0 case + actual = getattr(da, func)(skipna=skipna, dim=aggdim, ddof=5) + if dask: + assert isinstance(da.data, dask_array_type) + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=5) + assert_allclose(actual, expected, rtol=rtol) + else: + expected = series_reduce(da, func, skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the dtype argument + if func not in ["max", "min"]: + actual = getattr(da, func)(skipna=skipna, dim=aggdim, dtype=float) + assert_dask_array(actual, dask) + assert actual.dtype == float + + # without nan + da = construct_dataarray(dim_num, dtype, contains_nan=False, dask=dask) + actual = getattr(da, func)(skipna=skipna) + if dask: + assert isinstance(da.data, dask_array_type) + expected = getattr(np, f"nan{func}")(da.values) + if actual.dtype == object: + assert actual.values == np.array(expected) + else: + assert np.allclose(actual.values, np.array(expected), rtol=rtol) + + +@pytest.mark.parametrize("dim_num", [1, 2]) +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_, str]) +@pytest.mark.parametrize("contains_nan", [True, False]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["min", "max"]) +@pytest.mark.parametrize("skipna", [False, True]) +@pytest.mark.parametrize("aggdim", ["x", "y"]) +def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): + # pandas-dev/pandas#16830, we do not check consistency with pandas but + # just make sure da[da.argmin()] == da.min() + + if aggdim == "y" and dim_num < 2: + pytest.skip("dim not in this test") + + if dask and not has_dask: + pytest.skip("requires dask") + + if contains_nan: + if not skipna: + pytest.skip("numpy's argmin (not nanargmin) does not handle object-dtype") + if skipna and np.dtype(dtype).kind in "iufc": + pytest.skip("numpy's nanargmin raises ValueError for all nan axis") + da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "All-NaN slice") + + actual = da.isel( + **{aggdim: getattr(da, "arg" + func)(dim=aggdim, skipna=skipna).compute()} + ) + expected = getattr(da, func)(dim=aggdim, skipna=skipna) + assert_allclose( + actual.drop_vars(list(actual.coords)), + expected.drop_vars(list(expected.coords)), + ) + + +def test_argmin_max_error(): + da = construct_dataarray(2, np.bool_, contains_nan=True, dask=False) + da[0] = np.nan + with pytest.raises(ValueError): + da.argmin(dim="y") + + +@pytest.mark.parametrize( + ["array", "expected"], + [ + ( + np.array([np.datetime64("2000-01-01"), np.datetime64("NaT")]), + np.array([False, True]), + ), + ( + np.array([np.timedelta64(1, "h"), np.timedelta64("NaT")]), + np.array([False, True]), + ), + ( + np.array([0.0, np.nan]), + np.array([False, True]), + ), + ( + np.array([1j, np.nan]), + np.array([False, True]), + ), + ( + np.array(["foo", np.nan], dtype=object), + np.array([False, True]), + ), + ( + np.array([1, 2], dtype=int), + np.array([False, False]), + ), + ( + np.array([True, False], dtype=bool), + np.array([False, False]), + ), + ], +) +def test_isnull(array, expected): + actual = duck_array_ops.isnull(array) + np.testing.assert_equal(expected, actual) + + +@requires_dask +def test_isnull_with_dask(): + da = construct_dataarray(2, np.float32, contains_nan=True, dask=True) + assert isinstance(da.isnull().data, dask_array_type) + assert_equal(da.isnull().load(), da.load().isnull()) + + +@pytest.mark.skipif(not has_dask, reason="This is for dask.") +@pytest.mark.parametrize("axis", [0, -1, 1]) +@pytest.mark.parametrize("edge_order", [1, 2]) +def test_dask_gradient(axis, edge_order): + import dask.array as da + + array = np.array(np.random.randn(100, 5, 40)) + x = np.exp(np.linspace(0, 1, array.shape[axis])) + + darray = da.from_array(array, chunks=[(6, 30, 30, 20, 14), 5, 8]) + expected = gradient(array, x, axis=axis, edge_order=edge_order) + actual = gradient(darray, x, axis=axis, edge_order=edge_order) + + assert isinstance(actual, da.Array) + assert_array_equal(actual, expected) + + +@pytest.mark.parametrize("dim_num", [1, 2]) +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +@pytest.mark.parametrize("aggdim", [None, "x"]) +@pytest.mark.parametrize("contains_nan", [True, False]) +@pytest.mark.parametrize("skipna", [True, False, None]) +def test_min_count(dim_num, dtype, dask, func, aggdim, contains_nan, skipna): + if dask and not has_dask: + pytest.skip("requires dask") + + da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask) + min_count = 3 + + # If using Dask, the function call should be lazy. + with raise_if_dask_computes(): + actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count) + + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, min_count=min_count) + assert_allclose(actual, expected) + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +def test_min_count_nd(dtype, dask, func): + if dask and not has_dask: + pytest.skip("requires dask") + + min_count = 3 + dim_num = 3 + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + + # If using Dask, the function call should be lazy. + with raise_if_dask_computes(): + actual = getattr(da, func)( + dim=["x", "y", "z"], skipna=True, min_count=min_count + ) + + # Supplying all dims is equivalent to supplying `...` or `None` + expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count) + + assert_allclose(actual, expected) + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +@pytest.mark.parametrize("dim", [None, "a", "b"]) +def test_min_count_specific(dask, func, dim): + if dask and not has_dask: + pytest.skip("requires dask") + + # Simple array with four non-NaN values. + da = DataArray(np.ones((6, 6), dtype=np.float64) * np.nan, dims=("a", "b")) + da[0][0] = 2 + da[0][3] = 2 + da[3][0] = 2 + da[3][3] = 2 + if dask: + da = da.chunk({"a": 3, "b": 3}) + + # Expected result if we set min_count to the number of non-NaNs in a + # row/column/the entire array. + if dim: + min_count = 2 + expected = DataArray( + [4.0, np.nan, np.nan] * 2, dims=("a" if dim == "b" else "b",) + ) + else: + min_count = 4 + expected = DataArray(8.0 if func == "sum" else 16.0) + + # Check for that min_count. + with raise_if_dask_computes(): + actual = getattr(da, func)(dim, skipna=True, min_count=min_count) + assert_dask_array(actual, dask) + assert_allclose(actual, expected) + + # With min_count being one higher, should get all NaN. + min_count += 1 + expected *= np.nan + with raise_if_dask_computes(): + actual = getattr(da, func)(dim, skipna=True, min_count=min_count) + assert_dask_array(actual, dask) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("func", ["sum", "prod"]) +def test_min_count_dataset(func): + da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False) + ds = Dataset({"var1": da}, coords={"scalar": 0}) + actual = getattr(ds, func)(dim="x", skipna=True, min_count=3)["var1"] + expected = getattr(ds["var1"], func)(dim="x", skipna=True, min_count=3) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("skipna", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +def test_multiple_dims(dtype, dask, skipna, func): + if dask and not has_dask: + pytest.skip("requires dask") + da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) + + actual = getattr(da, func)(("x", "y"), skipna=skipna) + expected = getattr(getattr(da, func)("x", skipna=skipna), func)("y", skipna=skipna) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("dask", [True, False]) +def test_datetime_to_numeric_datetime64(dask): + if dask and not has_dask: + pytest.skip("requires dask") + + times = pd.date_range("2000", periods=5, freq="7D").values + if dask: + import dask.array + + times = dask.array.from_array(times, chunks=-1) + + with raise_if_dask_computes(): + result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h") + expected = 24 * np.arange(0, 35, 7) + np.testing.assert_array_equal(result, expected) + + offset = times[1] + with raise_if_dask_computes(): + result = duck_array_ops.datetime_to_numeric( + times, offset=offset, datetime_unit="h" + ) + expected = 24 * np.arange(-7, 28, 7) + np.testing.assert_array_equal(result, expected) + + dtype = np.float32 + with raise_if_dask_computes(): + result = duck_array_ops.datetime_to_numeric( + times, datetime_unit="h", dtype=dtype + ) + expected = 24 * np.arange(0, 35, 7).astype(dtype) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +@pytest.mark.parametrize("dask", [True, False]) +def test_datetime_to_numeric_cftime(dask): + if dask and not has_dask: + pytest.skip("requires dask") + + times = cftime_range("2000", periods=5, freq="7D", calendar="standard").values + if dask: + import dask.array + + times = dask.array.from_array(times, chunks=-1) + with raise_if_dask_computes(): + result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=int) + expected = 24 * np.arange(0, 35, 7) + np.testing.assert_array_equal(result, expected) + + offset = times[1] + with raise_if_dask_computes(): + result = duck_array_ops.datetime_to_numeric( + times, offset=offset, datetime_unit="h", dtype=int + ) + expected = 24 * np.arange(-7, 28, 7) + np.testing.assert_array_equal(result, expected) + + dtype = np.float32 + with raise_if_dask_computes(): + result = duck_array_ops.datetime_to_numeric( + times, datetime_unit="h", dtype=dtype + ) + expected = 24 * np.arange(0, 35, 7).astype(dtype) + np.testing.assert_array_equal(result, expected) + + with raise_if_dask_computes(): + if dask: + time = dask.array.asarray(times[1]) + else: + time = np.asarray(times[1]) + result = duck_array_ops.datetime_to_numeric( + time, offset=times[0], datetime_unit="h", dtype=int + ) + expected = np.array(24 * 7).astype(int) + np.testing.assert_array_equal(result, expected) + + +@requires_cftime +def test_datetime_to_numeric_potential_overflow(): + import cftime + + times = pd.date_range("2000", periods=5, freq="7D").values.astype("datetime64[us]") + cftimes = cftime_range( + "2000", periods=5, freq="7D", calendar="proleptic_gregorian" + ).values + + offset = np.datetime64("0001-01-01") + cfoffset = cftime.DatetimeProlepticGregorian(1, 1, 1) + + result = duck_array_ops.datetime_to_numeric( + times, offset=offset, datetime_unit="D", dtype=int + ) + cfresult = duck_array_ops.datetime_to_numeric( + cftimes, offset=cfoffset, datetime_unit="D", dtype=int + ) + + expected = 730119 + np.arange(0, 35, 7) + + np.testing.assert_array_equal(result, expected) + np.testing.assert_array_equal(cfresult, expected) + + +def test_py_timedelta_to_float(): + assert py_timedelta_to_float(dt.timedelta(days=1), "ns") == 86400 * 1e9 + assert py_timedelta_to_float(dt.timedelta(days=1e6), "ps") == 86400 * 1e18 + assert py_timedelta_to_float(dt.timedelta(days=1e6), "ns") == 86400 * 1e15 + assert py_timedelta_to_float(dt.timedelta(days=1e6), "us") == 86400 * 1e12 + assert py_timedelta_to_float(dt.timedelta(days=1e6), "ms") == 86400 * 1e9 + assert py_timedelta_to_float(dt.timedelta(days=1e6), "s") == 86400 * 1e6 + assert py_timedelta_to_float(dt.timedelta(days=1e6), "D") == 1e6 + + +@pytest.mark.parametrize( + "td, expected", + ([np.timedelta64(1, "D"), 86400 * 1e9], [np.timedelta64(1, "ns"), 1.0]), +) +def test_np_timedelta64_to_float(td, expected): + out = np_timedelta64_to_float(td, datetime_unit="ns") + np.testing.assert_allclose(out, expected) + assert isinstance(out, float) + + out = np_timedelta64_to_float(np.atleast_1d(td), datetime_unit="ns") + np.testing.assert_allclose(out, expected) + + +@pytest.mark.parametrize( + "td, expected", ([pd.Timedelta(1, "D"), 86400 * 1e9], [pd.Timedelta(1, "ns"), 1.0]) +) +def test_pd_timedelta_to_float(td, expected): + out = pd_timedelta_to_float(td, datetime_unit="ns") + np.testing.assert_allclose(out, expected) + assert isinstance(out, float) + + +@pytest.mark.parametrize( + "td", [dt.timedelta(days=1), np.timedelta64(1, "D"), pd.Timedelta(1, "D"), "1 day"] +) +def test_timedelta_to_numeric(td): + # Scalar input + out = timedelta_to_numeric(td, "ns") + np.testing.assert_allclose(out, 86400 * 1e9) + assert isinstance(out, float) + + +@pytest.mark.parametrize("use_dask", [True, False]) +@pytest.mark.parametrize("skipna", [True, False]) +def test_least_squares(use_dask, skipna): + if use_dask and (not has_dask or not has_scipy): + pytest.skip("requires dask and scipy") + lhs = np.array([[1, 2], [1, 2], [3, 2]]) + rhs = DataArray(np.array([3, 5, 7]), dims=("y",)) + + if use_dask: + rhs = rhs.chunk({"y": 1}) + + coeffs, residuals = least_squares(lhs, rhs.data, skipna=skipna) + + np.testing.assert_allclose(coeffs, [1.5, 1.25]) + np.testing.assert_allclose(residuals, [2.0]) + + +@requires_dask +@requires_bottleneck +def test_push_dask(): + import bottleneck + import dask.array + + array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6]) + + for n in [None, 1, 2, 3, 4, 5, 11]: + expected = bottleneck.push(array, axis=0, n=n) + for c in range(1, 11): + with raise_if_dask_computes(): + actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n) + np.testing.assert_equal(actual, expected) + + # some chunks of size-1 with NaN + with raise_if_dask_computes(): + actual = push( + dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n + ) + np.testing.assert_equal(actual, expected) + + +def test_extension_array_equality(categorical1, int1): + int_duck_array = PandasExtensionArray(int1) + categorical_duck_array = PandasExtensionArray(categorical1) + assert (int_duck_array != categorical_duck_array).all() + assert (categorical_duck_array == categorical1).all() + assert (int_duck_array[0:2] == int1[0:2]).all() + + +def test_extension_array_singleton_equality(categorical1): + categorical_duck_array = PandasExtensionArray(categorical1) + assert (categorical_duck_array != "cat3").all() + + +def test_extension_array_repr(int1): + int_duck_array = PandasExtensionArray(int1) + assert repr(int1) in repr(int_duck_array) + + +def test_extension_array_attr(int1): + int_duck_array = PandasExtensionArray(int1) + assert (~int_duck_array.fillna(10)).all() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_error_messages.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_error_messages.py new file mode 100644 index 0000000..b5840aa --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_error_messages.py @@ -0,0 +1,17 @@ +""" +This new file is intended to test the quality & friendliness of error messages that are +raised by xarray. It's currently separate from the standard tests, which are more +focused on the functions working (though we could consider integrating them.). +""" + +import pytest + + +def test_no_var_in_dataset(ds): + with pytest.raises( + KeyError, + match=( + r"No variable named 'foo'. Variables on the dataset include \['z1', 'z2', 'x', 'time', 'c', 'y'\]" + ), + ): + ds["foo"] diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_extensions.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_extensions.py new file mode 100644 index 0000000..7cfffd6 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_extensions.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import pickle + +import pytest + +import xarray as xr + +# TODO: Remove imports in favour of xr.DataTree etc, once part of public API +from xarray.core.datatree import DataTree +from xarray.core.extensions import register_datatree_accessor +from xarray.tests import assert_identical + + +@register_datatree_accessor("example_accessor") +@xr.register_dataset_accessor("example_accessor") +@xr.register_dataarray_accessor("example_accessor") +class ExampleAccessor: + """For the pickling tests below.""" + + def __init__(self, xarray_obj): + self.obj = xarray_obj + + +class TestAccessor: + def test_register(self) -> None: + @register_datatree_accessor("demo") + @xr.register_dataset_accessor("demo") + @xr.register_dataarray_accessor("demo") + class DemoAccessor: + """Demo accessor.""" + + def __init__(self, xarray_obj): + self._obj = xarray_obj + + @property + def foo(self): + return "bar" + + dt: DataTree = DataTree() + assert dt.demo.foo == "bar" + + ds = xr.Dataset() + assert ds.demo.foo == "bar" + + da = xr.DataArray(0) + assert da.demo.foo == "bar" + # accessor is cached + assert ds.demo is ds.demo + + # check descriptor + assert ds.demo.__doc__ == "Demo accessor." + # TODO: typing doesn't seem to work with accessors + assert xr.Dataset.demo.__doc__ == "Demo accessor." # type: ignore + assert isinstance(ds.demo, DemoAccessor) + assert xr.Dataset.demo is DemoAccessor # type: ignore + + # ensure we can remove it + del xr.Dataset.demo # type: ignore + assert not hasattr(xr.Dataset, "demo") + + with pytest.warns(Warning, match="overriding a preexisting attribute"): + + @xr.register_dataarray_accessor("demo") + class Foo: + pass + + # it didn't get registered again + assert not hasattr(xr.Dataset, "demo") + + def test_pickle_dataset(self) -> None: + ds = xr.Dataset() + ds_restored = pickle.loads(pickle.dumps(ds)) + assert_identical(ds, ds_restored) + + # state save on the accessor is restored + assert ds.example_accessor is ds.example_accessor + ds.example_accessor.value = "foo" + ds_restored = pickle.loads(pickle.dumps(ds)) + assert_identical(ds, ds_restored) + assert ds_restored.example_accessor.value == "foo" + + def test_pickle_dataarray(self) -> None: + array = xr.Dataset() + assert array.example_accessor is array.example_accessor + array_restored = pickle.loads(pickle.dumps(array)) + assert_identical(array, array_restored) + + def test_broken_accessor(self) -> None: + # regression test for GH933 + + @xr.register_dataset_accessor("stupid_accessor") + class BrokenAccessor: + def __init__(self, xarray_obj): + raise AttributeError("broken") + + with pytest.raises(RuntimeError, match=r"error initializing"): + xr.Dataset().stupid_accessor diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_formatting.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_formatting.py new file mode 100644 index 0000000..2c40ac8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_formatting.py @@ -0,0 +1,1144 @@ +from __future__ import annotations + +import sys +from textwrap import dedent + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.core import formatting +from xarray.core.datatree import DataTree # TODO: Remove when can do xr.DataTree +from xarray.tests import requires_cftime, requires_dask, requires_netCDF4 + +ON_WINDOWS = sys.platform == "win32" + + +class TestFormatting: + def test_get_indexer_at_least_n_items(self) -> None: + cases = [ + ((20,), (slice(10),), (slice(-10, None),)), + ((3, 20), (0, slice(10)), (-1, slice(-10, None))), + ((2, 10), (0, slice(10)), (-1, slice(-10, None))), + ((2, 5), (slice(2), slice(None)), (slice(-2, None), slice(None))), + ((1, 2, 5), (0, slice(2), slice(None)), (-1, slice(-2, None), slice(None))), + ((2, 3, 5), (0, slice(2), slice(None)), (-1, slice(-2, None), slice(None))), + ( + (1, 10, 1), + (0, slice(10), slice(None)), + (-1, slice(-10, None), slice(None)), + ), + ( + (2, 5, 1), + (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None)), + ), + ((2, 5, 3), (0, slice(4), slice(None)), (-1, slice(-4, None), slice(None))), + ( + (2, 3, 3), + (slice(2), slice(None), slice(None)), + (slice(-2, None), slice(None), slice(None)), + ), + ] + for shape, start_expected, end_expected in cases: + actual = formatting._get_indexer_at_least_n_items(shape, 10, from_end=False) + assert start_expected == actual + actual = formatting._get_indexer_at_least_n_items(shape, 10, from_end=True) + assert end_expected == actual + + def test_first_n_items(self) -> None: + array = np.arange(100).reshape(10, 5, 2) + for n in [3, 10, 13, 100, 200]: + actual = formatting.first_n_items(array, n) + expected = array.flat[:n] + assert (expected == actual).all() + + with pytest.raises(ValueError, match=r"at least one item"): + formatting.first_n_items(array, 0) + + def test_last_n_items(self) -> None: + array = np.arange(100).reshape(10, 5, 2) + for n in [3, 10, 13, 100, 200]: + actual = formatting.last_n_items(array, n) + expected = array.flat[-n:] + assert (expected == actual).all() + + with pytest.raises(ValueError, match=r"at least one item"): + formatting.first_n_items(array, 0) + + def test_last_item(self) -> None: + array = np.arange(100) + + reshape = ((10, 10), (1, 100), (2, 2, 5, 5)) + expected = np.array([99]) + + for r in reshape: + result = formatting.last_item(array.reshape(r)) + assert result == expected + + def test_format_item(self) -> None: + cases = [ + (pd.Timestamp("2000-01-01T12"), "2000-01-01T12:00:00"), + (pd.Timestamp("2000-01-01"), "2000-01-01"), + (pd.Timestamp("NaT"), "NaT"), + (pd.Timedelta("10 days 1 hour"), "10 days 01:00:00"), + (pd.Timedelta("-3 days"), "-3 days +00:00:00"), + (pd.Timedelta("3 hours"), "0 days 03:00:00"), + (pd.Timedelta("NaT"), "NaT"), + ("foo", "'foo'"), + (b"foo", "b'foo'"), + (1, "1"), + (1.0, "1.0"), + (np.float16(1.1234), "1.123"), + (np.float32(1.0111111), "1.011"), + (np.float64(22.222222), "22.22"), + ] + for item, expected in cases: + actual = formatting.format_item(item) + assert expected == actual + + def test_format_items(self) -> None: + cases = [ + (np.arange(4) * np.timedelta64(1, "D"), "0 days 1 days 2 days 3 days"), + ( + np.arange(4) * np.timedelta64(3, "h"), + "00:00:00 03:00:00 06:00:00 09:00:00", + ), + ( + np.arange(4) * np.timedelta64(500, "ms"), + "00:00:00 00:00:00.500000 00:00:01 00:00:01.500000", + ), + (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), + ( + pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), + "1 days 01:00:00 1 days 00:00:00 0 days 00:00:00", + ), + ([1, 2, 3], "1 2 3"), + ] + for item, expected in cases: + actual = " ".join(formatting.format_items(item)) + assert expected == actual + + def test_format_array_flat(self) -> None: + actual = formatting.format_array_flat(np.arange(100), 2) + expected = "..." + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 9) + expected = "0 ... 99" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 10) + expected = "0 1 ... 99" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 13) + expected = "0 1 ... 98 99" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100), 15) + expected = "0 1 2 ... 98 99" + assert expected == actual + + # NB: Probably not ideal; an alternative would be cutting after the + # first ellipsis + actual = formatting.format_array_flat(np.arange(100.0), 11) + expected = "0.0 ... ..." + assert expected == actual + + actual = formatting.format_array_flat(np.arange(100.0), 12) + expected = "0.0 ... 99.0" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(3), 5) + expected = "0 1 2" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(4.0), 11) + expected = "0.0 ... 3.0" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(0), 0) + expected = "" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(1), 1) + expected = "0" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(2), 3) + expected = "0 1" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(4), 7) + expected = "0 1 2 3" + assert expected == actual + + actual = formatting.format_array_flat(np.arange(5), 7) + expected = "0 ... 4" + assert expected == actual + + long_str = [" ".join(["hello world" for _ in range(100)])] + actual = formatting.format_array_flat(np.asarray([long_str]), 21) + expected = "'hello world hello..." + assert expected == actual + + def test_pretty_print(self) -> None: + assert formatting.pretty_print("abcdefghij", 8) == "abcde..." + assert formatting.pretty_print("ß", 1) == "ß" + + def test_maybe_truncate(self) -> None: + assert formatting.maybe_truncate("ß", 10) == "ß" + + def test_format_timestamp_invalid_pandas_format(self) -> None: + expected = "2021-12-06 17:00:00 00" + with pytest.raises(ValueError): + formatting.format_timestamp(expected) + + def test_format_timestamp_out_of_bounds(self) -> None: + from datetime import datetime + + date = datetime(1300, 12, 1) + expected = "1300-12-01" + result = formatting.format_timestamp(date) + assert result == expected + + date = datetime(2300, 12, 1) + expected = "2300-12-01" + result = formatting.format_timestamp(date) + assert result == expected + + def test_attribute_repr(self) -> None: + short = formatting.summarize_attr("key", "Short string") + long = formatting.summarize_attr("key", 100 * "Very long string ") + newlines = formatting.summarize_attr("key", "\n\n\n") + tabs = formatting.summarize_attr("key", "\t\t\t") + assert short == " key: Short string" + assert len(long) <= 80 + assert long.endswith("...") + assert "\n" not in newlines + assert "\t" not in tabs + + def test_index_repr(self) -> None: + from xarray.core.indexes import Index + + class CustomIndex(Index): + names: tuple[str, ...] + + def __init__(self, names: tuple[str, ...]): + self.names = names + + def __repr__(self): + return f"CustomIndex(coords={self.names})" + + coord_names = ("x", "y") + index = CustomIndex(coord_names) + names = ("x",) + + normal = formatting.summarize_index(names, index, col_width=20) + assert names[0] in normal + assert len(normal.splitlines()) == len(names) + assert "CustomIndex" in normal + + class IndexWithInlineRepr(CustomIndex): + def _repr_inline_(self, max_width: int): + return f"CustomIndex[{', '.join(self.names)}]" + + index = IndexWithInlineRepr(coord_names) + inline = formatting.summarize_index(names, index, col_width=20) + assert names[0] in inline + assert index._repr_inline_(max_width=40) in inline + + @pytest.mark.parametrize( + "names", + ( + ("x",), + ("x", "y"), + ("x", "y", "z"), + ("x", "y", "z", "a"), + ), + ) + def test_index_repr_grouping(self, names) -> None: + from xarray.core.indexes import Index + + class CustomIndex(Index): + def __init__(self, names): + self.names = names + + def __repr__(self): + return f"CustomIndex(coords={self.names})" + + index = CustomIndex(names) + + normal = formatting.summarize_index(names, index, col_width=20) + assert all(name in normal for name in names) + assert len(normal.splitlines()) == len(names) + assert "CustomIndex" in normal + + hint_chars = [line[2] for line in normal.splitlines()] + + if len(names) <= 1: + assert hint_chars == [" "] + else: + assert hint_chars[0] == "┌" and hint_chars[-1] == "└" + assert len(names) == 2 or hint_chars[1:-1] == ["│"] * (len(names) - 2) + + def test_diff_array_repr(self) -> None: + da_a = xr.DataArray( + np.array([[1, 2, 3], [4, 5, 6]], dtype="int64"), + dims=("x", "y"), + coords={ + "x": np.array(["a", "b"], dtype="U1"), + "y": np.array([1, 2, 3], dtype="int64"), + }, + attrs={"units": "m", "description": "desc"}, + ) + + da_b = xr.DataArray( + np.array([1, 2], dtype="int64"), + dims="x", + coords={ + "x": np.array(["a", "c"], dtype="U1"), + "label": ("x", np.array([1, 2], dtype="int64")), + }, + attrs={"units": "kg"}, + ) + + byteorder = "<" if sys.byteorder == "little" else ">" + expected = dedent( + """\ + Left and right DataArray objects are not identical + Differing dimensions: + (x: 2, y: 3) != (x: 2) + Differing values: + L + array([[1, 2, 3], + [4, 5, 6]], dtype=int64) + R + array([1, 2], dtype=int64) + Differing coordinates: + L * x (x) %cU1 8B 'a' 'b' + R * x (x) %cU1 8B 'a' 'c' + Coordinates only on the left object: + * y (y) int64 24B 1 2 3 + Coordinates only on the right object: + label (x) int64 16B 1 2 + Differing attributes: + L units: m + R units: kg + Attributes only on the left object: + description: desc""" + % (byteorder, byteorder) + ) + + actual = formatting.diff_array_repr(da_a, da_b, "identical") + try: + assert actual == expected + except AssertionError: + # depending on platform, dtype may not be shown in numpy array repr + assert actual == expected.replace(", dtype=int64", "") + + va = xr.Variable( + "x", np.array([1, 2, 3], dtype="int64"), {"title": "test Variable"} + ) + vb = xr.Variable(("x", "y"), np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")) + + expected = dedent( + """\ + Left and right Variable objects are not equal + Differing dimensions: + (x: 3) != (x: 2, y: 3) + Differing values: + L + array([1, 2, 3], dtype=int64) + R + array([[1, 2, 3], + [4, 5, 6]], dtype=int64)""" + ) + + actual = formatting.diff_array_repr(va, vb, "equals") + try: + assert actual == expected + except AssertionError: + assert actual == expected.replace(", dtype=int64", "") + + @pytest.mark.filterwarnings("error") + def test_diff_attrs_repr_with_array(self) -> None: + attrs_a = {"attr": np.array([0, 1])} + + attrs_b = {"attr": 1} + expected = dedent( + """\ + Differing attributes: + L attr: [0 1] + R attr: 1 + """ + ).strip() + actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals") + assert expected == actual + + attrs_c = {"attr": np.array([-3, 5])} + expected = dedent( + """\ + Differing attributes: + L attr: [0 1] + R attr: [-3 5] + """ + ).strip() + actual = formatting.diff_attrs_repr(attrs_a, attrs_c, "equals") + assert expected == actual + + # should not raise a warning + attrs_c = {"attr": np.array([0, 1, 2])} + expected = dedent( + """\ + Differing attributes: + L attr: [0 1] + R attr: [0 1 2] + """ + ).strip() + actual = formatting.diff_attrs_repr(attrs_a, attrs_c, "equals") + assert expected == actual + + def test_diff_dataset_repr(self) -> None: + ds_a = xr.Dataset( + data_vars={ + "var1": (("x", "y"), np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")), + "var2": ("x", np.array([3, 4], dtype="int64")), + }, + coords={ + "x": ( + "x", + np.array(["a", "b"], dtype="U1"), + {"foo": "bar", "same": "same"}, + ), + "y": np.array([1, 2, 3], dtype="int64"), + }, + attrs={"title": "mytitle", "description": "desc"}, + ) + + ds_b = xr.Dataset( + data_vars={"var1": ("x", np.array([1, 2], dtype="int64"))}, + coords={ + "x": ( + "x", + np.array(["a", "c"], dtype="U1"), + {"source": 0, "foo": "baz", "same": "same"}, + ), + "label": ("x", np.array([1, 2], dtype="int64")), + }, + attrs={"title": "newtitle"}, + ) + + byteorder = "<" if sys.byteorder == "little" else ">" + expected = dedent( + """\ + Left and right Dataset objects are not identical + Differing dimensions: + (x: 2, y: 3) != (x: 2) + Differing coordinates: + L * x (x) %cU1 8B 'a' 'b' + Differing variable attributes: + foo: bar + R * x (x) %cU1 8B 'a' 'c' + Differing variable attributes: + source: 0 + foo: baz + Coordinates only on the left object: + * y (y) int64 24B 1 2 3 + Coordinates only on the right object: + label (x) int64 16B 1 2 + Differing data variables: + L var1 (x, y) int64 48B 1 2 3 4 5 6 + R var1 (x) int64 16B 1 2 + Data variables only on the left object: + var2 (x) int64 16B 3 4 + Differing attributes: + L title: mytitle + R title: newtitle + Attributes only on the left object: + description: desc""" + % (byteorder, byteorder) + ) + + actual = formatting.diff_dataset_repr(ds_a, ds_b, "identical") + assert actual == expected + + def test_array_repr(self) -> None: + ds = xr.Dataset( + coords={ + "foo": np.array([1, 2, 3], dtype=np.uint64), + "bar": np.array([1, 2, 3], dtype=np.uint64), + } + ) + ds[(1, 2)] = xr.DataArray(np.array([0], dtype=np.uint64), dims="test") + ds_12 = ds[(1, 2)] + + # Test repr function behaves correctly: + actual = formatting.array_repr(ds_12) + + expected = dedent( + """\ + Size: 8B + array([0], dtype=uint64) + Dimensions without coordinates: test""" + ) + + assert actual == expected + + # Test repr, str prints returns correctly as well: + assert repr(ds_12) == expected + assert str(ds_12) == expected + + # f-strings (aka format(...)) by default should use the repr: + actual = f"{ds_12}" + assert actual == expected + + with xr.set_options(display_expand_data=False): + actual = formatting.array_repr(ds[(1, 2)]) + expected = dedent( + """\ + Size: 8B + 0 + Dimensions without coordinates: test""" + ) + + assert actual == expected + + def test_array_repr_variable(self) -> None: + var = xr.Variable("x", [0, 1]) + + formatting.array_repr(var) + + with xr.set_options(display_expand_data=False): + formatting.array_repr(var) + + def test_array_repr_recursive(self) -> None: + # GH:issue:7111 + + # direct recursion + var = xr.Variable("x", [0, 1]) + var.attrs["x"] = var + formatting.array_repr(var) + + da = xr.DataArray([0, 1], dims=["x"]) + da.attrs["x"] = da + formatting.array_repr(da) + + # indirect recursion + var.attrs["x"] = da + da.attrs["x"] = var + formatting.array_repr(var) + formatting.array_repr(da) + + @requires_dask + def test_array_scalar_format(self) -> None: + # Test numpy scalars: + var = xr.DataArray(np.array(0)) + assert format(var, "") == repr(var) + assert format(var, "d") == "0" + assert format(var, ".2f") == "0.00" + + # Test dask scalars, not supported however: + import dask.array as da + + var = xr.DataArray(da.array(0)) + assert format(var, "") == repr(var) + with pytest.raises(TypeError) as excinfo: + format(var, ".2f") + assert "unsupported format string passed to" in str(excinfo.value) + + # Test numpy arrays raises: + var = xr.DataArray([0.1, 0.2]) + with pytest.raises(NotImplementedError) as excinfo: # type: ignore + format(var, ".2f") + assert "Using format_spec is only supported" in str(excinfo.value) + + def test_datatree_print_empty_node(self): + dt: DataTree = DataTree(name="root") + printout = dt.__str__() + assert printout == "DataTree('root', parent=None)" + + def test_datatree_print_empty_node_with_attrs(self): + dat = xr.Dataset(attrs={"note": "has attrs"}) + dt: DataTree = DataTree(name="root", data=dat) + printout = dt.__str__() + assert printout == dedent( + """\ + DataTree('root', parent=None) + Dimensions: () + Data variables: + *empty* + Attributes: + note: has attrs""" + ) + + def test_datatree_print_node_with_data(self): + dat = xr.Dataset({"a": [0, 2]}) + dt: DataTree = DataTree(name="root", data=dat) + printout = dt.__str__() + expected = [ + "DataTree('root', parent=None)", + "Dimensions", + "Coordinates", + "a", + "Data variables", + "*empty*", + ] + for expected_line, printed_line in zip(expected, printout.splitlines()): + assert expected_line in printed_line + + def test_datatree_printout_nested_node(self): + dat = xr.Dataset({"a": [0, 2]}) + root: DataTree = DataTree(name="root") + DataTree(name="results", data=dat, parent=root) + printout = root.__str__() + assert printout.splitlines()[2].startswith(" ") + + def test_datatree_repr_of_node_with_data(self): + dat = xr.Dataset({"a": [0, 2]}) + dt: DataTree = DataTree(name="root", data=dat) + assert "Coordinates" in repr(dt) + + def test_diff_datatree_repr_structure(self): + dt_1: DataTree = DataTree.from_dict({"a": None, "a/b": None, "a/c": None}) + dt_2: DataTree = DataTree.from_dict({"d": None, "d/e": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not isomorphic + + Number of children on node '/a' of the left object: 2 + Number of children on node '/d' of the right object: 1""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "isomorphic") + assert actual == expected + + def test_diff_datatree_repr_node_names(self): + dt_1: DataTree = DataTree.from_dict({"a": None}) + dt_2: DataTree = DataTree.from_dict({"b": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not identical + + Node '/a' in the left object has name 'a' + Node '/b' in the right object has name 'b'""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "identical") + assert actual == expected + + def test_diff_datatree_repr_node_data(self): + # casting to int64 explicitly ensures that int64s are created on all architectures + ds1 = xr.Dataset({"u": np.int64(0), "v": np.int64(1)}) + ds3 = xr.Dataset({"w": np.int64(5)}) + dt_1: DataTree = DataTree.from_dict({"a": ds1, "a/b": ds3}) + ds2 = xr.Dataset({"u": np.int64(0)}) + ds4 = xr.Dataset({"w": np.int64(6)}) + dt_2: DataTree = DataTree.from_dict({"a": ds2, "a/b": ds4}) + + expected = dedent( + """\ + Left and right DataTree objects are not equal + + + Data in nodes at position '/a' do not match: + + Data variables only on the left object: + v int64 8B 1 + + Data in nodes at position '/a/b' do not match: + + Differing data variables: + L w int64 8B 5 + R w int64 8B 6""" + ) + actual = formatting.diff_datatree_repr(dt_1, dt_2, "equals") + assert actual == expected + + +def test_inline_variable_array_repr_custom_repr() -> None: + class CustomArray: + def __init__(self, value, attr): + self.value = value + self.attr = attr + + def _repr_inline_(self, width): + formatted = f"({self.attr}) {self.value}" + if len(formatted) > width: + formatted = f"({self.attr}) ..." + + return formatted + + def __array_namespace__(self, *args, **kwargs): + return NotImplemented + + @property + def shape(self) -> tuple[int, ...]: + return self.value.shape + + @property + def dtype(self): + return self.value.dtype + + @property + def ndim(self): + return self.value.ndim + + value = CustomArray(np.array([20, 40]), "m") + variable = xr.Variable("x", value) + + max_width = 10 + actual = formatting.inline_variable_array_repr(variable, max_width=10) + + assert actual == value._repr_inline_(max_width) + + +def test_set_numpy_options() -> None: + original_options = np.get_printoptions() + with formatting.set_numpy_options(threshold=10): + assert len(repr(np.arange(500))) < 200 + # original options are restored + assert np.get_printoptions() == original_options + + +def test_short_array_repr() -> None: + cases = [ + np.random.randn(500), + np.random.randn(20, 20), + np.random.randn(5, 10, 15), + np.random.randn(5, 10, 15, 3), + np.random.randn(100, 5, 1), + ] + # number of lines: + # for default numpy repr: 167, 140, 254, 248, 599 + # for short_array_repr: 1, 7, 24, 19, 25 + for array in cases: + num_lines = formatting.short_array_repr(array).count("\n") + 1 + assert num_lines < 30 + + # threshold option (default: 200) + array2 = np.arange(100) + assert "..." not in formatting.short_array_repr(array2) + with xr.set_options(display_values_threshold=10): + assert "..." in formatting.short_array_repr(array2) + + +def test_large_array_repr_length() -> None: + da = xr.DataArray(np.random.randn(100, 5, 1)) + + result = repr(da).splitlines() + assert len(result) < 50 + + +@requires_netCDF4 +def test_repr_file_collapsed(tmp_path) -> None: + arr_to_store = xr.DataArray(np.arange(300, dtype=np.int64), dims="test") + arr_to_store.to_netcdf(tmp_path / "test.nc", engine="netcdf4") + + with ( + xr.open_dataarray(tmp_path / "test.nc") as arr, + xr.set_options(display_expand_data=False), + ): + actual = repr(arr) + expected = dedent( + """\ + Size: 2kB + [300 values with dtype=int64] + Dimensions without coordinates: test""" + ) + + assert actual == expected + + arr_loaded = arr.compute() + actual = arr_loaded.__repr__() + expected = dedent( + """\ + Size: 2kB + 0 1 2 3 4 5 6 7 8 9 10 11 12 ... 288 289 290 291 292 293 294 295 296 297 298 299 + Dimensions without coordinates: test""" + ) + + assert actual == expected + + +@pytest.mark.parametrize( + "display_max_rows, n_vars, n_attr", + [(50, 40, 30), (35, 40, 30), (11, 40, 30), (1, 40, 30)], +) +def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: + long_name = "long_name" + a = np.char.add(long_name, np.arange(0, n_vars).astype(str)) + b = np.char.add("attr_", np.arange(0, n_attr).astype(str)) + c = np.char.add("coord", np.arange(0, n_vars).astype(str)) + attrs = {k: 2 for k in b} + coords = {_c: np.array([0, 1], dtype=np.uint64) for _c in c} + data_vars = dict() + for v, _c in zip(a, coords.items()): + data_vars[v] = xr.DataArray( + name=v, + data=np.array([3, 4], dtype=np.uint64), + dims=[_c[0]], + coords=dict([_c]), + ) + + ds = xr.Dataset(data_vars) + ds.attrs = attrs + + with xr.set_options(display_max_rows=display_max_rows): + # Parse the data_vars print and show only data_vars rows: + summary = formatting.dataset_repr(ds).split("\n") + summary = [v for v in summary if long_name in v] + # The length should be less than or equal to display_max_rows: + len_summary = len(summary) + data_vars_print_size = min(display_max_rows, len_summary) + assert len_summary == data_vars_print_size + + summary = formatting.data_vars_repr(ds.data_vars).split("\n") + summary = [v for v in summary if long_name in v] + # The length should be equal to the number of data variables + len_summary = len(summary) + assert len_summary == n_vars + + summary = formatting.coords_repr(ds.coords).split("\n") + summary = [v for v in summary if "coord" in v] + # The length should be equal to the number of data variables + len_summary = len(summary) + assert len_summary == n_vars + + with xr.set_options( + display_max_rows=display_max_rows, + display_expand_coords=False, + display_expand_data_vars=False, + display_expand_attrs=False, + ): + actual = formatting.dataset_repr(ds) + col_width = formatting._calculate_col_width(ds.variables) + dims_start = formatting.pretty_print("Dimensions:", col_width) + dims_values = formatting.dim_summary_limited( + ds, col_width=col_width + 1, max_rows=display_max_rows + ) + expected_size = "1kB" + expected = f"""\ + Size: {expected_size} +{dims_start}({dims_values}) +Coordinates: ({n_vars}) +Data variables: ({n_vars}) +Attributes: ({n_attr})""" + expected = dedent(expected) + assert actual == expected + + +def test__mapping_repr_recursive() -> None: + # GH:issue:7111 + + # direct recursion + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + ds.attrs["ds"] = ds + formatting.dataset_repr(ds) + + # indirect recursion + ds2 = xr.Dataset({"b": ("y", [1, 2, 3])}) + ds.attrs["ds"] = ds2 + ds2.attrs["ds"] = ds + formatting.dataset_repr(ds2) + + +def test__element_formatter(n_elements: int = 100) -> None: + expected = """\ + Dimensions without coordinates: dim_0: 3, dim_1: 3, dim_2: 3, dim_3: 3, + dim_4: 3, dim_5: 3, dim_6: 3, dim_7: 3, + dim_8: 3, dim_9: 3, dim_10: 3, dim_11: 3, + dim_12: 3, dim_13: 3, dim_14: 3, dim_15: 3, + dim_16: 3, dim_17: 3, dim_18: 3, dim_19: 3, + dim_20: 3, dim_21: 3, dim_22: 3, dim_23: 3, + ... + dim_76: 3, dim_77: 3, dim_78: 3, dim_79: 3, + dim_80: 3, dim_81: 3, dim_82: 3, dim_83: 3, + dim_84: 3, dim_85: 3, dim_86: 3, dim_87: 3, + dim_88: 3, dim_89: 3, dim_90: 3, dim_91: 3, + dim_92: 3, dim_93: 3, dim_94: 3, dim_95: 3, + dim_96: 3, dim_97: 3, dim_98: 3, dim_99: 3""" + expected = dedent(expected) + + intro = "Dimensions without coordinates: " + elements = [ + f"{k}: {v}" for k, v in {f"dim_{k}": 3 for k in np.arange(n_elements)}.items() + ] + values = xr.core.formatting._element_formatter( + elements, col_width=len(intro), max_rows=12 + ) + actual = intro + values + assert expected == actual + + +def test_lazy_array_wont_compute() -> None: + from xarray.core.indexing import LazilyIndexedArray + + class LazilyIndexedArrayNotComputable(LazilyIndexedArray): + def __array__(self, dtype=None, copy=None): + raise NotImplementedError("Computing this array is not possible.") + + arr = LazilyIndexedArrayNotComputable(np.array([1, 2])) + var = xr.DataArray(arr) + + # These will crash if var.data are converted to numpy arrays: + var.__repr__() + var._repr_html_() + + +@pytest.mark.parametrize("as_dataset", (False, True)) +def test_format_xindexes_none(as_dataset: bool) -> None: + # ensure repr for empty xindexes can be displayed #8367 + + expected = """\ + Indexes: + *empty*""" + expected = dedent(expected) + + obj: xr.DataArray | xr.Dataset = xr.DataArray() + obj = obj._to_temp_dataset() if as_dataset else obj + + actual = repr(obj.xindexes) + assert actual == expected + + +@pytest.mark.parametrize("as_dataset", (False, True)) +def test_format_xindexes(as_dataset: bool) -> None: + expected = """\ + Indexes: + x PandasIndex""" + expected = dedent(expected) + + obj: xr.DataArray | xr.Dataset = xr.DataArray([1], coords={"x": [1]}) + obj = obj._to_temp_dataset() if as_dataset else obj + + actual = repr(obj.xindexes) + assert actual == expected + + +@requires_cftime +def test_empty_cftimeindex_repr() -> None: + index = xr.coding.cftimeindex.CFTimeIndex([]) + + expected = """\ + Indexes: + time CFTimeIndex([], dtype='object', length=0, calendar=None, freq=None)""" + expected = dedent(expected) + + da = xr.DataArray([], coords={"time": index}) + + actual = repr(da.indexes) + assert actual == expected + + +def test_display_nbytes() -> None: + xds = xr.Dataset( + { + "foo": np.arange(1200, dtype=np.int16), + "bar": np.arange(111, dtype=np.int16), + } + ) + + # Note: int16 is used to ensure that dtype is shown in the + # numpy array representation for all OSes included Windows + + actual = repr(xds) + expected = """ + Size: 3kB +Dimensions: (foo: 1200, bar: 111) +Coordinates: + * foo (foo) int16 2kB 0 1 2 3 4 5 6 ... 1194 1195 1196 1197 1198 1199 + * bar (bar) int16 222B 0 1 2 3 4 5 6 7 ... 104 105 106 107 108 109 110 +Data variables: + *empty* + """.strip() + assert actual == expected + + actual = repr(xds["foo"]) + expected = """ + Size: 2kB +array([ 0, 1, 2, ..., 1197, 1198, 1199], dtype=int16) +Coordinates: + * foo (foo) int16 2kB 0 1 2 3 4 5 6 ... 1194 1195 1196 1197 1198 1199 +""".strip() + assert actual == expected + + +def test_array_repr_dtypes(): + + # These dtypes are expected to be represented similarly + # on Ubuntu, macOS and Windows environments of the CI. + # Unsigned integer could be used as easy replacements + # for tests where the data-type does not matter, + # but the repr does, including the size + # (size of a int == size of an uint) + + # Signed integer dtypes + + ds = xr.DataArray(np.array([0], dtype="int8"), dims="x") + actual = repr(ds) + expected = """ + Size: 1B +array([0], dtype=int8) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int16"), dims="x") + actual = repr(ds) + expected = """ + Size: 2B +array([0], dtype=int16) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + # Unsigned integer dtypes + + ds = xr.DataArray(np.array([0], dtype="uint8"), dims="x") + actual = repr(ds) + expected = """ + Size: 1B +array([0], dtype=uint8) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="uint16"), dims="x") + actual = repr(ds) + expected = """ + Size: 2B +array([0], dtype=uint16) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="uint32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0], dtype=uint32) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="uint64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0], dtype=uint64) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + # Float dtypes + + ds = xr.DataArray(np.array([0.0]), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0.]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="float16"), dims="x") + actual = repr(ds) + expected = """ + Size: 2B +array([0.], dtype=float16) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="float32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0.], dtype=float32) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="float64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0.]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + +@pytest.mark.skipif( + ON_WINDOWS, + reason="Default numpy's dtypes vary according to OS", +) +def test_array_repr_dtypes_unix() -> None: + + # Signed integer dtypes + + ds = xr.DataArray(np.array([0]), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0], dtype=int32) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + +@pytest.mark.skipif( + not ON_WINDOWS, + reason="Default numpy's dtypes vary according to OS", +) +def test_array_repr_dtypes_on_windows() -> None: + + # Integer dtypes + + ds = xr.DataArray(np.array([0]), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0], dtype=int64) +Dimensions without coordinates: x + """.strip() + assert actual == expected diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_formatting_html.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_formatting_html.py new file mode 100644 index 0000000..ada7f75 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_formatting_html.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.core import formatting_html as fh +from xarray.core.coordinates import Coordinates +from xarray.core.datatree import DataTree + + +@pytest.fixture +def dataarray() -> xr.DataArray: + return xr.DataArray(np.random.RandomState(0).randn(4, 6)) + + +@pytest.fixture +def dask_dataarray(dataarray: xr.DataArray) -> xr.DataArray: + pytest.importorskip("dask") + return dataarray.chunk() + + +@pytest.fixture +def multiindex() -> xr.Dataset: + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + return xr.Dataset({}, midx_coords) + + +@pytest.fixture +def dataset() -> xr.Dataset: + times = pd.date_range("2000-01-01", "2001-12-31", name="time") + annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) + + base = 10 + 15 * annual_cycle.reshape(-1, 1) + tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3) + tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3) + + return xr.Dataset( + { + "tmin": (("time", "location"), tmin_values), + "tmax": (("time", "location"), tmax_values), + }, + {"time": times, "location": ["", "IN", "IL"]}, + attrs={"description": "Test data."}, + ) + + +def test_short_data_repr_html(dataarray: xr.DataArray) -> None: + data_repr = fh.short_data_repr_html(dataarray) + assert data_repr.startswith("
    array")
    +
    +
    +def test_short_data_repr_html_non_str_keys(dataset: xr.Dataset) -> None:
    +    ds = dataset.assign({2: lambda x: x["tmin"]})
    +    fh.dataset_repr(ds)
    +
    +
    +def test_short_data_repr_html_dask(dask_dataarray: xr.DataArray) -> None:
    +    assert hasattr(dask_dataarray.data, "_repr_html_")
    +    data_repr = fh.short_data_repr_html(dask_dataarray)
    +    assert data_repr == dask_dataarray.data._repr_html_()
    +
    +
    +def test_format_dims_no_dims() -> None:
    +    dims: dict = {}
    +    dims_with_index: list = []
    +    formatted = fh.format_dims(dims, dims_with_index)
    +    assert formatted == ""
    +
    +
    +def test_format_dims_unsafe_dim_name() -> None:
    +    dims = {"": 3, "y": 2}
    +    dims_with_index: list = []
    +    formatted = fh.format_dims(dims, dims_with_index)
    +    assert "<x>" in formatted
    +
    +
    +def test_format_dims_non_index() -> None:
    +    dims, dims_with_index = {"x": 3, "y": 2}, ["time"]
    +    formatted = fh.format_dims(dims, dims_with_index)
    +    assert "class='xr-has-index'" not in formatted
    +
    +
    +def test_format_dims_index() -> None:
    +    dims, dims_with_index = {"x": 3, "y": 2}, ["x"]
    +    formatted = fh.format_dims(dims, dims_with_index)
    +    assert "class='xr-has-index'" in formatted
    +
    +
    +def test_summarize_attrs_with_unsafe_attr_name_and_value() -> None:
    +    attrs = {"": 3, "y": ""}
    +    formatted = fh.summarize_attrs(attrs)
    +    assert "
    <x> :
    " in formatted + assert "
    y :
    " in formatted + assert "
    3
    " in formatted + assert "
    <pd.DataFrame>
    " in formatted + + +def test_repr_of_dataarray(dataarray: xr.DataArray) -> None: + formatted = fh.array_repr(dataarray) + assert "dim_0" in formatted + # has an expanded data section + assert formatted.count("class='xr-array-in' type='checkbox' checked>") == 1 + # coords, indexes and attrs don't have an items so they'll be be disabled and collapsed + assert ( + formatted.count("class='xr-section-summary-in' type='checkbox' disabled >") == 3 + ) + + with xr.set_options(display_expand_data=False): + formatted = fh.array_repr(dataarray) + assert "dim_0" in formatted + # has a collapsed data section + assert formatted.count("class='xr-array-in' type='checkbox' checked>") == 0 + # coords, indexes and attrs don't have an items so they'll be be disabled and collapsed + assert ( + formatted.count("class='xr-section-summary-in' type='checkbox' disabled >") + == 3 + ) + + +def test_repr_of_multiindex(multiindex: xr.Dataset) -> None: + formatted = fh.dataset_repr(multiindex) + assert "(x)" in formatted + + +def test_repr_of_dataset(dataset: xr.Dataset) -> None: + formatted = fh.dataset_repr(dataset) + # coords, attrs, and data_vars are expanded + assert ( + formatted.count("class='xr-section-summary-in' type='checkbox' checked>") == 3 + ) + # indexes is collapsed + assert formatted.count("class='xr-section-summary-in' type='checkbox' >") == 1 + assert "<U4" in formatted or ">U4" in formatted + assert "<IA>" in formatted + + with xr.set_options( + display_expand_coords=False, + display_expand_data_vars=False, + display_expand_attrs=False, + display_expand_indexes=True, + ): + formatted = fh.dataset_repr(dataset) + # coords, attrs, and data_vars are collapsed, indexes is expanded + assert ( + formatted.count("class='xr-section-summary-in' type='checkbox' checked>") + == 1 + ) + assert "<U4" in formatted or ">U4" in formatted + assert "<IA>" in formatted + + +def test_repr_text_fallback(dataset: xr.Dataset) -> None: + formatted = fh.dataset_repr(dataset) + + # Just test that the "pre" block used for fallback to plain text is present. + assert "
    " in formatted
    +
    +
    +def test_variable_repr_html() -> None:
    +    v = xr.Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"})
    +    assert hasattr(v, "_repr_html_")
    +    with xr.set_options(display_style="html"):
    +        html = v._repr_html_().strip()
    +    # We don't do a complete string identity since
    +    # html output is probably subject to change, is long and... reasons.
    +    # Just test that something reasonable was produced.
    +    assert html.startswith("")
    +    assert "xarray.Variable" in html
    +
    +
    +def test_repr_of_nonstr_dataset(dataset: xr.Dataset) -> None:
    +    ds = dataset.copy()
    +    ds.attrs[1] = "Test value"
    +    ds[2] = ds["tmin"]
    +    formatted = fh.dataset_repr(ds)
    +    assert "
    1 :
    Test value
    " in formatted + assert "
    2" in formatted + + +def test_repr_of_nonstr_dataarray(dataarray: xr.DataArray) -> None: + da = dataarray.rename(dim_0=15) + da.attrs[1] = "value" + formatted = fh.array_repr(da) + assert "
    1 :
    value
    " in formatted + assert "
  • 15: 4
  • " in formatted + + +def test_nonstr_variable_repr_html() -> None: + v = xr.Variable(["time", 10], [[1, 2, 3], [4, 5, 6]], {22: "bar"}) + assert hasattr(v, "_repr_html_") + with xr.set_options(display_style="html"): + html = v._repr_html_().strip() + assert "
    22 :
    bar
    " in html + assert "
  • 10: 3
  • " in html + + +@pytest.fixture(scope="module", params=["some html", "some other html"]) +def repr(request): + return request.param + + +class Test_summarize_datatree_children: + """ + Unit tests for summarize_datatree_children. + """ + + func = staticmethod(fh.summarize_datatree_children) + + @pytest.fixture(scope="class") + def childfree_tree_factory(self): + """ + Fixture for a child-free DataTree factory. + """ + from random import randint + + def _childfree_tree_factory(): + return DataTree( + data=xr.Dataset({"z": ("y", [randint(1, 100) for _ in range(3)])}) + ) + + return _childfree_tree_factory + + @pytest.fixture(scope="class") + def childfree_tree(self, childfree_tree_factory): + """ + Fixture for a child-free DataTree. + """ + return childfree_tree_factory() + + @pytest.fixture(scope="function") + def mock_datatree_node_repr(self, monkeypatch): + """ + Apply mocking for datatree_node_repr. + """ + + def mock(group_title, dt): + """ + Mock with a simple result + """ + return group_title + " " + str(id(dt)) + + monkeypatch.setattr(fh, "datatree_node_repr", mock) + + @pytest.fixture(scope="function") + def mock_wrap_datatree_repr(self, monkeypatch): + """ + Apply mocking for _wrap_datatree_repr. + """ + + def mock(r, *, end, **kwargs): + """ + Mock by appending "end" or "not end". + """ + return r + " " + ("end" if end else "not end") + "//" + + monkeypatch.setattr(fh, "_wrap_datatree_repr", mock) + + def test_empty_mapping(self): + """ + Test with an empty mapping of children. + """ + children: dict[str, DataTree] = {} + assert self.func(children) == ( + "
    " + "
    " + ) + + def test_one_child( + self, childfree_tree, mock_wrap_datatree_repr, mock_datatree_node_repr + ): + """ + Test with one child. + + Uses a mock of _wrap_datatree_repr and _datatree_node_repr to essentially mock + the inline lambda function "lines_callback". + """ + # Create mapping of children + children = {"a": childfree_tree} + + # Expect first line to be produced from the first child, and + # wrapped as the last child + first_line = f"a {id(children['a'])} end//" + + assert self.func(children) == ( + "
    " + f"{first_line}" + "
    " + ) + + def test_two_children( + self, childfree_tree_factory, mock_wrap_datatree_repr, mock_datatree_node_repr + ): + """ + Test with two level deep children. + + Uses a mock of _wrap_datatree_repr and datatree_node_repr to essentially mock + the inline lambda function "lines_callback". + """ + + # Create mapping of children + children = {"a": childfree_tree_factory(), "b": childfree_tree_factory()} + + # Expect first line to be produced from the first child, and + # wrapped as _not_ the last child + first_line = f"a {id(children['a'])} not end//" + + # Expect second line to be produced from the second child, and + # wrapped as the last child + second_line = f"b {id(children['b'])} end//" + + assert self.func(children) == ( + "
    " + f"{first_line}" + f"{second_line}" + "
    " + ) + + +class Test__wrap_datatree_repr: + """ + Unit tests for _wrap_datatree_repr. + """ + + func = staticmethod(fh._wrap_datatree_repr) + + def test_end(self, repr): + """ + Test with end=True. + """ + r = self.func(repr, end=True) + assert r == ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " + f"{repr}" + "
    " + "
    " + ) + + def test_not_end(self, repr): + """ + Test with end=False. + """ + r = self.func(repr, end=False) + assert r == ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " + f"{repr}" + "
    " + "
    " + ) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_groupby.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_groupby.py new file mode 100644 index 0000000..47cda06 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_groupby.py @@ -0,0 +1,2608 @@ +from __future__ import annotations + +import datetime +import operator +import warnings +from unittest import mock + +import numpy as np +import pandas as pd +import pytest +from packaging.version import Version + +import xarray as xr +from xarray import DataArray, Dataset, Variable +from xarray.core.groupby import _consolidate_slices +from xarray.core.types import InterpOptions +from xarray.tests import ( + InaccessibleArray, + assert_allclose, + assert_equal, + assert_identical, + create_test_data, + has_cftime, + has_flox, + requires_dask, + requires_flox, + requires_scipy, +) + + +@pytest.fixture +def dataset() -> xr.Dataset: + ds = xr.Dataset( + { + "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), + "baz": ("x", ["e", "f", "g"]), + "cat": ("y", pd.Categorical(["cat1", "cat2", "cat2", "cat1"])), + }, + {"x": ("x", ["a", "b", "c"], {"name": "x"}), "y": [1, 2, 3, 4], "z": [1, 2]}, + ) + ds["boo"] = (("z", "y"), [["f", "g", "h", "j"]] * 2) + + return ds + + +@pytest.fixture +def array(dataset) -> xr.DataArray: + return dataset["foo"] + + +def test_consolidate_slices() -> None: + assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)] + assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)] + assert _consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)]) == [slice(2, 6, 1)] + + slices = [slice(2, 3), slice(5, 6)] + assert _consolidate_slices(slices) == slices + + # ignore type because we're checking for an error anyway + with pytest.raises(ValueError): + _consolidate_slices([slice(3), 4]) # type: ignore[list-item] + + +@pytest.mark.filterwarnings("ignore:return type") +def test_groupby_dims_property(dataset, recwarn) -> None: + # dims is sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("x").dims == dataset.isel(x=1).dims + assert dataset.groupby("y").dims == dataset.isel(y=1).dims + # in pytest-8, pytest.warns() no longer clears all warnings + recwarn.clear() + + # when squeeze=False, no warning should be raised + assert tuple(dataset.groupby("x", squeeze=False).dims) == tuple( + dataset.isel(x=slice(1, 2)).dims + ) + assert tuple(dataset.groupby("y", squeeze=False).dims) == tuple( + dataset.isel(y=slice(1, 2)).dims + ) + assert len(recwarn) == 0 + + dataset = dataset.drop_vars(["cat"]) + stacked = dataset.stack({"xy": ("x", "y")}) + assert tuple(stacked.groupby("xy", squeeze=False).dims) == tuple( + stacked.isel(xy=[0]).dims + ) + assert len(recwarn) == 0 + + +def test_groupby_sizes_property(dataset) -> None: + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes + dataset = dataset.drop_vars("cat") + stacked = dataset.stack({"xy": ("x", "y")}) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes + + +def test_multi_index_groupby_map(dataset) -> None: + # regression test for GH873 + ds = dataset.isel(z=1, drop=True)[["foo"]] + expected = 2 * ds + # The function in `map` may be sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = ( + ds.stack(space=["x", "y"]) + .groupby("space") + .map(lambda x: 2 * x) + .unstack("space") + ) + assert_equal(expected, actual) + + +def test_reduce_numeric_only(dataset) -> None: + gb = dataset.groupby("x", squeeze=False) + with xr.set_options(use_flox=False): + expected = gb.sum() + with xr.set_options(use_flox=True): + actual = gb.sum() + assert_identical(expected, actual) + + +def test_multi_index_groupby_sum() -> None: + # regression test for GH873 + ds = xr.Dataset( + {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))}, + {"x": ["a", "b", "c"], "y": [1, 2, 3, 4]}, + ) + expected = ds.sum("z") + actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space") + assert_equal(expected, actual) + + +def test_groupby_da_datetime() -> None: + # test groupby with a DataArray of dtype datetime for GH1132 + # create test data + times = pd.date_range("2000-01-01", periods=4) + foo = xr.DataArray([1, 2, 3, 4], coords=dict(time=times), dims="time") + # create test index + reference_dates = [times[0], times[2]] + labels = reference_dates[0:1] * 2 + reference_dates[1:2] * 2 + ind = xr.DataArray( + labels, coords=dict(time=times), dims="time", name="reference_date" + ) + g = foo.groupby(ind) + actual = g.sum(dim="time") + expected = xr.DataArray( + [3, 7], coords=dict(reference_date=reference_dates), dims="reference_date" + ) + assert_equal(expected, actual) + + +def test_groupby_duplicate_coordinate_labels() -> None: + # fix for http://stackoverflow.com/questions/38065129 + array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])]) + expected = xr.DataArray([3, 3], [("x", [1, 2])]) + actual = array.groupby("x").sum() + assert_equal(expected, actual) + + +def test_groupby_input_mutation() -> None: + # regression test for GH2153 + array = xr.DataArray([1, 2, 3], [("x", [2, 2, 1])]) + array_copy = array.copy() + expected = xr.DataArray([3, 3], [("x", [1, 2])]) + actual = array.groupby("x").sum() + assert_identical(expected, actual) + assert_identical(array, array_copy) # should not modify inputs + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_groupby_indexvariable(use_flox: bool) -> None: + # regression test for GH7919 + array = xr.DataArray([1, 2, 3], [("x", [2, 2, 1])]) + iv = xr.IndexVariable(dims="x", data=pd.Index(array.x.values)) + with xr.set_options(use_flox=use_flox): + actual = array.groupby(iv).sum() + actual = array.groupby(iv).sum() + expected = xr.DataArray([3, 3], [("x", [1, 2])]) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "obj", + [ + xr.DataArray([1, 2, 3, 4, 5, 6], [("x", [1, 1, 1, 2, 2, 2])]), + xr.Dataset({"foo": ("x", [1, 2, 3, 4, 5, 6])}, {"x": [1, 1, 1, 2, 2, 2]}), + ], +) +def test_groupby_map_shrink_groups(obj) -> None: + expected = obj.isel(x=[0, 1, 3, 4]) + actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1])) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "obj", + [ + xr.DataArray([1, 2, 3], [("x", [1, 2, 2])]), + xr.Dataset({"foo": ("x", [1, 2, 3])}, {"x": [1, 2, 2]}), + ], +) +def test_groupby_map_change_group_size(obj) -> None: + def func(group): + if group.sizes["x"] == 1: + result = group.isel(x=[0, 0]) + else: + result = group.isel(x=[0]) + return result + + expected = obj.isel(x=[0, 0, 1]) + actual = obj.groupby("x").map(func) + assert_identical(expected, actual) + + +def test_da_groupby_map_func_args() -> None: + def func(arg1, arg2, arg3=0): + return arg1 + arg2 + arg3 + + array = xr.DataArray([1, 1, 1], [("x", [1, 2, 3])]) + expected = xr.DataArray([3, 3, 3], [("x", [1, 2, 3])]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = array.groupby("x").map(func, args=(1,), arg3=1) + assert_identical(expected, actual) + + +def test_ds_groupby_map_func_args() -> None: + def func(arg1, arg2, arg3=0): + return arg1 + arg2 + arg3 + + dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) + expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]}) + # The function in `map` may be sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = dataset.groupby("x").map(func, args=(1,), arg3=1) + assert_identical(expected, actual) + + +def test_da_groupby_empty() -> None: + empty_array = xr.DataArray([], dims="dim") + + with pytest.raises(ValueError): + empty_array.groupby("dim") + + +@requires_dask +def test_dask_da_groupby_quantile() -> None: + # Only works when the grouped reduction can run blockwise + # Scalar quantile + expected = xr.DataArray( + data=[2, 5], coords={"x": [1, 2], "quantile": 0.5}, dims="x" + ) + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + with pytest.raises(ValueError): + array.chunk(x=1).groupby("x").quantile(0.5) + + # will work blockwise with flox + actual = array.chunk(x=3).groupby("x").quantile(0.5) + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=-1).groupby("x").quantile(0.5) + assert_identical(expected, actual) + + +@requires_dask +def test_dask_da_groupby_median() -> None: + expected = xr.DataArray(data=[2, 5], coords={"x": [1, 2]}, dims="x") + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + with xr.set_options(use_flox=False): + actual = array.chunk(x=1).groupby("x").median() + assert_identical(expected, actual) + + with xr.set_options(use_flox=True): + actual = array.chunk(x=1).groupby("x").median() + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=3).groupby("x").median() + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=-1).groupby("x").median() + assert_identical(expected, actual) + + +def test_da_groupby_quantile() -> None: + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + + # Scalar quantile + expected = xr.DataArray( + data=[2, 5], coords={"x": [1, 2], "quantile": 0.5}, dims="x" + ) + actual = array.groupby("x").quantile(0.5) + assert_identical(expected, actual) + + # Vector quantile + expected = xr.DataArray( + data=[[1, 3], [4, 6]], + coords={"x": [1, 2], "quantile": [0, 1]}, + dims=("x", "quantile"), + ) + actual = array.groupby("x").quantile([0, 1]) + assert_identical(expected, actual) + + array = xr.DataArray( + data=[np.nan, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + + for skipna in (True, False, None): + e = [np.nan, 5] if skipna is False else [2.5, 5] + + expected = xr.DataArray(data=e, coords={"x": [1, 2], "quantile": 0.5}, dims="x") + actual = array.groupby("x").quantile(0.5, skipna=skipna) + assert_identical(expected, actual) + + # Multiple dimensions + array = xr.DataArray( + data=[[1, 11, 26], [2, 12, 22], [3, 13, 23], [4, 16, 24], [5, 15, 25]], + coords={"x": [1, 1, 1, 2, 2], "y": [0, 0, 1]}, + dims=("x", "y"), + ) + + actual_x = array.groupby("x").quantile(0, dim=...) + expected_x = xr.DataArray( + data=[1, 4], coords={"x": [1, 2], "quantile": 0}, dims="x" + ) + assert_identical(expected_x, actual_x) + + actual_y = array.groupby("y").quantile(0, dim=...) + expected_y = xr.DataArray( + data=[1, 22], coords={"y": [0, 1], "quantile": 0}, dims="y" + ) + assert_identical(expected_y, actual_y) + + actual_xx = array.groupby("x").quantile(0) + expected_xx = xr.DataArray( + data=[[1, 11, 22], [4, 15, 24]], + coords={"x": [1, 2], "y": [0, 0, 1], "quantile": 0}, + dims=("x", "y"), + ) + assert_identical(expected_xx, actual_xx) + + actual_yy = array.groupby("y").quantile(0) + expected_yy = xr.DataArray( + data=[[1, 26], [2, 22], [3, 23], [4, 24], [5, 25]], + coords={"x": [1, 1, 1, 2, 2], "y": [0, 1], "quantile": 0}, + dims=("x", "y"), + ) + assert_identical(expected_yy, actual_yy) + + times = pd.date_range("2000-01-01", periods=365) + x = [0, 1] + foo = xr.DataArray( + np.reshape(np.arange(365 * 2), (365, 2)), + coords={"time": times, "x": x}, + dims=("time", "x"), + ) + g = foo.groupby(foo.time.dt.month) + + actual = g.quantile(0, dim=...) + expected = xr.DataArray( + data=[ + 0.0, + 62.0, + 120.0, + 182.0, + 242.0, + 304.0, + 364.0, + 426.0, + 488.0, + 548.0, + 610.0, + 670.0, + ], + coords={"month": np.arange(1, 13), "quantile": 0}, + dims="month", + ) + assert_identical(expected, actual) + + actual = g.quantile(0, dim="time")[:2] + expected = xr.DataArray( + data=[[0.0, 1], [62.0, 63]], + coords={"month": [1, 2], "x": [0, 1], "quantile": 0}, + dims=("month", "x"), + ) + assert_identical(expected, actual) + + # method keyword + array = xr.DataArray(data=[1, 2, 3, 4], coords={"x": [1, 1, 2, 2]}, dims="x") + + expected = xr.DataArray( + data=[1, 3], coords={"x": [1, 2], "quantile": 0.5}, dims="x" + ) + actual = array.groupby("x").quantile(0.5, method="lower") + assert_identical(expected, actual) + + +def test_ds_groupby_quantile() -> None: + ds = xr.Dataset( + data_vars={"a": ("x", [1, 2, 3, 4, 5, 6])}, coords={"x": [1, 1, 1, 2, 2, 2]} + ) + + # Scalar quantile + expected = xr.Dataset( + data_vars={"a": ("x", [2, 5])}, coords={"quantile": 0.5, "x": [1, 2]} + ) + actual = ds.groupby("x").quantile(0.5) + assert_identical(expected, actual) + + # Vector quantile + expected = xr.Dataset( + data_vars={"a": (("x", "quantile"), [[1, 3], [4, 6]])}, + coords={"x": [1, 2], "quantile": [0, 1]}, + ) + actual = ds.groupby("x").quantile([0, 1]) + assert_identical(expected, actual) + + ds = xr.Dataset( + data_vars={"a": ("x", [np.nan, 2, 3, 4, 5, 6])}, + coords={"x": [1, 1, 1, 2, 2, 2]}, + ) + + for skipna in (True, False, None): + e = [np.nan, 5] if skipna is False else [2.5, 5] + + expected = xr.Dataset( + data_vars={"a": ("x", e)}, coords={"quantile": 0.5, "x": [1, 2]} + ) + actual = ds.groupby("x").quantile(0.5, skipna=skipna) + assert_identical(expected, actual) + + # Multiple dimensions + ds = xr.Dataset( + data_vars={ + "a": ( + ("x", "y"), + [[1, 11, 26], [2, 12, 22], [3, 13, 23], [4, 16, 24], [5, 15, 25]], + ) + }, + coords={"x": [1, 1, 1, 2, 2], "y": [0, 0, 1]}, + ) + + actual_x = ds.groupby("x").quantile(0, dim=...) + expected_x = xr.Dataset({"a": ("x", [1, 4])}, coords={"x": [1, 2], "quantile": 0}) + assert_identical(expected_x, actual_x) + + actual_y = ds.groupby("y").quantile(0, dim=...) + expected_y = xr.Dataset({"a": ("y", [1, 22])}, coords={"y": [0, 1], "quantile": 0}) + assert_identical(expected_y, actual_y) + + actual_xx = ds.groupby("x").quantile(0) + expected_xx = xr.Dataset( + {"a": (("x", "y"), [[1, 11, 22], [4, 15, 24]])}, + coords={"x": [1, 2], "y": [0, 0, 1], "quantile": 0}, + ) + assert_identical(expected_xx, actual_xx) + + actual_yy = ds.groupby("y").quantile(0) + expected_yy = xr.Dataset( + {"a": (("x", "y"), [[1, 26], [2, 22], [3, 23], [4, 24], [5, 25]])}, + coords={"x": [1, 1, 1, 2, 2], "y": [0, 1], "quantile": 0}, + ).transpose() + assert_identical(expected_yy, actual_yy) + + times = pd.date_range("2000-01-01", periods=365) + x = [0, 1] + foo = xr.Dataset( + {"a": (("time", "x"), np.reshape(np.arange(365 * 2), (365, 2)))}, + coords=dict(time=times, x=x), + ) + g = foo.groupby(foo.time.dt.month) + + actual = g.quantile(0, dim=...) + expected = xr.Dataset( + { + "a": ( + "month", + [ + 0.0, + 62.0, + 120.0, + 182.0, + 242.0, + 304.0, + 364.0, + 426.0, + 488.0, + 548.0, + 610.0, + 670.0, + ], + ) + }, + coords={"month": np.arange(1, 13), "quantile": 0}, + ) + assert_identical(expected, actual) + + actual = g.quantile(0, dim="time").isel(month=slice(None, 2)) + expected = xr.Dataset( + data_vars={"a": (("month", "x"), [[0.0, 1], [62.0, 63]])}, + coords={"month": [1, 2], "x": [0, 1], "quantile": 0}, + ) + assert_identical(expected, actual) + + ds = xr.Dataset(data_vars={"a": ("x", [1, 2, 3, 4])}, coords={"x": [1, 1, 2, 2]}) + + # method keyword + expected = xr.Dataset( + data_vars={"a": ("x", [1, 3])}, coords={"quantile": 0.5, "x": [1, 2]} + ) + actual = ds.groupby("x").quantile(0.5, method="lower") + assert_identical(expected, actual) + + +@pytest.mark.parametrize("as_dataset", [False, True]) +def test_groupby_quantile_interpolation_deprecated(as_dataset: bool) -> None: + array = xr.DataArray(data=[1, 2, 3, 4], coords={"x": [1, 1, 2, 2]}, dims="x") + + arr: xr.DataArray | xr.Dataset + arr = array.to_dataset(name="name") if as_dataset else array + + with pytest.warns( + FutureWarning, + match="`interpolation` argument to quantile was renamed to `method`", + ): + actual = arr.quantile(0.5, interpolation="lower") + + expected = arr.quantile(0.5, method="lower") + + assert_identical(actual, expected) + + with warnings.catch_warnings(record=True): + with pytest.raises(TypeError, match="interpolation and method keywords"): + arr.quantile(0.5, method="lower", interpolation="lower") + + +def test_da_groupby_assign_coords() -> None: + actual = xr.DataArray( + [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": range(2), "x": range(3)} + ) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual1 = actual.groupby("x").assign_coords({"y": [-1, -2]}) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual2 = actual.groupby("x").assign_coords(y=[-1, -2]) + expected = xr.DataArray( + [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": [-1, -2], "x": range(3)} + ) + assert_identical(expected, actual1) + assert_identical(expected, actual2) + + +repr_da = xr.DataArray( + np.random.randn(10, 20, 6, 24), + dims=["x", "y", "z", "t"], + coords={ + "z": ["a", "b", "c", "a", "b", "c"], + "x": [1, 1, 1, 2, 2, 3, 4, 5, 3, 4], + "t": xr.date_range("2001-01-01", freq="ME", periods=24, use_cftime=False), + "month": ("t", list(range(1, 13)) * 2), + }, +) + + +@pytest.mark.parametrize("dim", ["x", "y", "z", "month"]) +@pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) +def test_groupby_repr(obj, dim) -> None: + actual = repr(obj.groupby(dim)) + expected = f"{obj.__class__.__name__}GroupBy" + expected += f", grouped over {dim!r}" + expected += f"\n{len(np.unique(obj[dim]))!r} groups with labels " + if dim == "x": + expected += "1, 2, 3, 4, 5." + elif dim == "y": + expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19." + elif dim == "z": + expected += "'a', 'b', 'c'." + elif dim == "month": + expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12." + assert actual == expected + + +@pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) +def test_groupby_repr_datetime(obj) -> None: + actual = repr(obj.groupby("t.month")) + expected = f"{obj.__class__.__name__}GroupBy" + expected += ", grouped over 'month'" + expected += f"\n{len(np.unique(obj.t.dt.month))!r} groups with labels " + expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12." + assert actual == expected + + +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") +@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning") +def test_groupby_drops_nans() -> None: + # GH2383 + # nan in 2D data variable (requires stacking) + ds = xr.Dataset( + { + "variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))), + "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))), + }, + coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)}, + ) + + ds["id"].values[0, 0] = np.nan + ds["id"].values[3, 0] = np.nan + ds["id"].values[-1, -1] = np.nan + + grouped = ds.groupby(ds.id) + + # non reduction operation + expected1 = ds.copy() + expected1.variable.values[0, 0, :] = np.nan + expected1.variable.values[-1, -1, :] = np.nan + expected1.variable.values[3, 0, :] = np.nan + actual1 = grouped.map(lambda x: x).transpose(*ds.variable.dims) + assert_identical(actual1, expected1) + + # reduction along grouped dimension + actual2 = grouped.mean() + stacked = ds.stack({"xy": ["lat", "lon"]}) + expected2 = ( + stacked.variable.where(stacked.id.notnull()) + .rename({"xy": "id"}) + .to_dataset() + .reset_index("id", drop=True) + .assign(id=stacked.id.values) + .dropna("id") + .transpose(*actual2.variable.dims) + ) + assert_identical(actual2, expected2) + + # reduction operation along a different dimension + actual3 = grouped.mean("time") + expected3 = ds.mean("time").where(ds.id.notnull()) + assert_identical(actual3, expected3) + + # NaN in non-dimensional coordinate + array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])]) + array["x1"] = ("x", [1, 1, np.nan]) + expected4 = xr.DataArray(3, [("x1", [1])]) + actual4 = array.groupby("x1").sum() + assert_equal(expected4, actual4) + + # NaT in non-dimensional coordinate + array["t"] = ( + "x", + [ + np.datetime64("2001-01-01"), + np.datetime64("2001-01-01"), + np.datetime64("NaT"), + ], + ) + expected5 = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])]) + actual5 = array.groupby("t").sum() + assert_equal(expected5, actual5) + + # test for repeated coordinate labels + array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])]) + expected6 = xr.DataArray([3, 3], [("x", [1, 2])]) + actual6 = array.groupby("x").sum() + assert_equal(expected6, actual6) + + +def test_groupby_grouping_errors() -> None: + dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) + with pytest.raises( + ValueError, match=r"None of the data falls within bins with edges" + ): + dataset.groupby_bins("x", bins=[0.1, 0.2, 0.3]) + + with pytest.raises( + ValueError, match=r"None of the data falls within bins with edges" + ): + dataset.to_dataarray().groupby_bins("x", bins=[0.1, 0.2, 0.3]) + + with pytest.raises(ValueError, match=r"All bin edges are NaN."): + dataset.groupby_bins("x", bins=[np.nan, np.nan, np.nan]) + + with pytest.raises(ValueError, match=r"All bin edges are NaN."): + dataset.to_dataarray().groupby_bins("x", bins=[np.nan, np.nan, np.nan]) + + with pytest.raises(ValueError, match=r"Failed to group data."): + dataset.groupby(dataset.foo * np.nan) + + with pytest.raises(ValueError, match=r"Failed to group data."): + dataset.to_dataarray().groupby(dataset.foo * np.nan) + + +def test_groupby_reduce_dimension_error(array) -> None: + grouped = array.groupby("y") + # assert_identical(array, grouped.mean()) + + with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): + grouped.mean("huh") + + with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): + grouped.mean(("x", "y", "asd")) + + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(array.mean("x"), grouped.reduce(np.mean, "x")) + assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"])) + + grouped = array.groupby("y", squeeze=False) + assert_identical(array, grouped.mean()) + + assert_identical(array.mean("x"), grouped.reduce(np.mean, "x")) + assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"])) + + +def test_groupby_multiple_string_args(array) -> None: + with pytest.raises(TypeError): + array.groupby("x", "y") + + +def test_groupby_bins_timeseries() -> None: + ds = xr.Dataset() + ds["time"] = xr.DataArray( + pd.date_range("2010-08-01", "2010-08-15", freq="15min"), dims="time" + ) + ds["val"] = xr.DataArray(np.ones(ds["time"].shape), dims="time") + time_bins = pd.date_range(start="2010-08-01", end="2010-08-15", freq="24h") + actual = ds.groupby_bins("time", time_bins).sum() + expected = xr.DataArray( + 96 * np.ones((14,)), + dims=["time_bins"], + coords={"time_bins": pd.cut(time_bins, time_bins).categories}, + ).to_dataset(name="val") + assert_identical(actual, expected) + + +def test_groupby_none_group_name() -> None: + # GH158 + # xarray should not fail if a DataArray's name attribute is None + + data = np.arange(10) + 10 + da = xr.DataArray(data) # da.name = None + key = xr.DataArray(np.floor_divide(data, 2)) + + mean = da.groupby(key).mean() + assert "group" in mean.dims + + +def test_groupby_getitem(dataset) -> None: + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.sel(x="a"), dataset.groupby("x")["a"]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.sel(z=1), dataset.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.cat.sel(y=1), dataset.cat.groupby("y")[1]) + + assert_identical(dataset.sel(x=["a"]), dataset.groupby("x", squeeze=False)["a"]) + assert_identical(dataset.sel(z=[1]), dataset.groupby("z", squeeze=False)[1]) + + assert_identical( + dataset.foo.sel(x=["a"]), dataset.foo.groupby("x", squeeze=False)["a"] + ) + assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z", squeeze=False)[1]) + + assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y", squeeze=False)[1]) + with pytest.raises( + NotImplementedError, match="Cannot broadcast 1d-only pandas categorical array." + ): + dataset.groupby("boo", squeeze=False) + dataset = dataset.drop_vars(["cat"]) + actual = ( + dataset.groupby("boo", squeeze=False)["f"].unstack().transpose("x", "y", "z") + ) + expected = dataset.sel(y=[1], z=[1, 2]).transpose("x", "y", "z") + assert_identical(expected, actual) + + +def test_groupby_dataset() -> None: + data = Dataset( + {"z": (["x", "y"], np.random.randn(3, 5))}, + {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, + ) + groupby = data.groupby("x", squeeze=False) + assert len(groupby) == 3 + expected_groups = {"a": slice(0, 1), "b": slice(1, 2), "c": slice(2, 3)} + assert groupby.groups == expected_groups + expected_items = [ + ("a", data.isel(x=[0])), + ("b", data.isel(x=[1])), + ("c", data.isel(x=[2])), + ] + for actual1, expected1 in zip(groupby, expected_items): + assert actual1[0] == expected1[0] + assert_equal(actual1[1], expected1[1]) + + def identity(x): + return x + + for k in ["x", "c", "y"]: + actual2 = data.groupby(k, squeeze=False).map(identity) + assert_equal(data, actual2) + + +def test_groupby_dataset_squeeze_None() -> None: + """Delete when removing squeeze.""" + data = Dataset( + {"z": (["x", "y"], np.random.randn(3, 5))}, + {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, + ) + groupby = data.groupby("x") + assert len(groupby) == 3 + expected_groups = {"a": 0, "b": 1, "c": 2} + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert groupby.groups == expected_groups + expected_items = [ + ("a", data.isel(x=0)), + ("b", data.isel(x=1)), + ("c", data.isel(x=2)), + ] + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + for actual1, expected1 in zip(groupby, expected_items): + assert actual1[0] == expected1[0] + assert_equal(actual1[1], expected1[1]) + + def identity(x): + return x + + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + for k in ["x", "c"]: + actual2 = data.groupby(k).map(identity) + assert_equal(data, actual2) + + +def test_groupby_dataset_returns_new_type() -> None: + data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))}) + + actual1 = data.groupby("x", squeeze=False).map(lambda ds: ds["z"]) + expected1 = data["z"] + assert_identical(expected1, actual1) + + actual2 = data["z"].groupby("x", squeeze=False).map(lambda x: x.to_dataset()) + expected2 = data + assert_identical(expected2, actual2) + + +def test_groupby_dataset_iter() -> None: + data = create_test_data() + for n, (t, sub) in enumerate(list(data.groupby("dim1", squeeze=False))[:3]): + assert data["dim1"][n] == t + assert_equal(data["var1"][[n]], sub["var1"]) + assert_equal(data["var2"][[n]], sub["var2"]) + assert_equal(data["var3"][:, [n]], sub["var3"]) + + +def test_groupby_dataset_errors() -> None: + data = create_test_data() + with pytest.raises(TypeError, match=r"`group` must be"): + data.groupby(np.arange(10)) # type: ignore[arg-type,unused-ignore] + with pytest.raises(ValueError, match=r"length does not match"): + data.groupby(data["dim1"][:3]) + with pytest.raises(TypeError, match=r"`group` must be"): + data.groupby(data.coords["dim1"].to_index()) + + +def test_groupby_dataset_reduce() -> None: + data = Dataset( + { + "xy": (["x", "y"], np.random.randn(3, 4)), + "xonly": ("x", np.random.randn(3)), + "yonly": ("y", np.random.randn(4)), + "letters": ("y", ["a", "a", "b", "b"]), + } + ) + + expected = data.mean("y") + expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3}) + actual = data.groupby("x").mean(...) + assert_allclose(expected, actual) + + actual = data.groupby("x").mean("y") + assert_allclose(expected, actual) + + letters = data["letters"] + expected = Dataset( + { + "xy": data["xy"].groupby(letters).mean(...), + "xonly": (data["xonly"].mean().variable.set_dims({"letters": 2})), + "yonly": data["yonly"].groupby(letters).mean(), + } + ) + actual = data.groupby("letters").mean(...) + assert_allclose(expected, actual) + + +@pytest.mark.parametrize("squeeze", [True, False]) +def test_groupby_dataset_math(squeeze: bool) -> None: + def reorder_dims(x): + return x.transpose("dim1", "dim2", "dim3", "time") + + ds = create_test_data() + ds["dim1"] = ds["dim1"] + grouped = ds.groupby("dim1", squeeze=squeeze) + + expected = reorder_dims(ds + ds.coords["dim1"]) + actual = grouped + ds.coords["dim1"] + assert_identical(expected, reorder_dims(actual)) + + actual = ds.coords["dim1"] + grouped + assert_identical(expected, reorder_dims(actual)) + + ds2 = 2 * ds + expected = reorder_dims(ds + ds2) + actual = grouped + ds2 + assert_identical(expected, reorder_dims(actual)) + + actual = ds2 + grouped + assert_identical(expected, reorder_dims(actual)) + + +def test_groupby_math_more() -> None: + ds = create_test_data() + grouped = ds.groupby("numbers") + zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))]) + expected = (ds + Variable("dim3", np.zeros(10))).transpose( + "dim3", "dim1", "dim2", "time" + ) + actual = grouped + zeros + assert_equal(expected, actual) + + actual = zeros + grouped + assert_equal(expected, actual) + + with pytest.raises(ValueError, match=r"incompat.* grouped binary"): + grouped + ds + with pytest.raises(ValueError, match=r"incompat.* grouped binary"): + ds + grouped + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + 1 # type: ignore[operator] + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + grouped # type: ignore[operator] + with pytest.raises(TypeError, match=r"in-place operations"): + ds += grouped # type: ignore[arg-type] + + ds = Dataset( + { + "x": ("time", np.arange(100)), + "time": pd.date_range("2000-01-01", periods=100), + } + ) + with pytest.raises(ValueError, match=r"incompat.* grouped binary"): + ds + ds.groupby("time.month") + + +def test_groupby_math_bitshift() -> None: + # create new dataset of int's only + ds = Dataset( + { + "x": ("index", np.ones(4, dtype=int)), + "y": ("index", np.ones(4, dtype=int) * -1), + "level": ("index", [1, 1, 2, 2]), + "index": [0, 1, 2, 3], + } + ) + shift = DataArray([1, 2, 1], [("level", [1, 2, 8])]) + + left_expected = Dataset( + { + "x": ("index", [2, 2, 4, 4]), + "y": ("index", [-2, -2, -4, -4]), + "level": ("index", [2, 2, 8, 8]), + "index": [0, 1, 2, 3], + } + ) + + left_manual = [] + for lev, group in ds.groupby("level"): + shifter = shift.sel(level=lev) + left_manual.append(group << shifter) + left_actual = xr.concat(left_manual, dim="index").reset_coords(names="level") + assert_equal(left_expected, left_actual) + + left_actual = (ds.groupby("level") << shift).reset_coords(names="level") + assert_equal(left_expected, left_actual) + + right_expected = Dataset( + { + "x": ("index", [0, 0, 2, 2]), + "y": ("index", [-1, -1, -2, -2]), + "level": ("index", [0, 0, 4, 4]), + "index": [0, 1, 2, 3], + } + ) + right_manual = [] + for lev, group in left_expected.groupby("level"): + shifter = shift.sel(level=lev) + right_manual.append(group >> shifter) + right_actual = xr.concat(right_manual, dim="index").reset_coords(names="level") + assert_equal(right_expected, right_actual) + + right_actual = (left_expected.groupby("level") >> shift).reset_coords(names="level") + assert_equal(right_expected, right_actual) + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_groupby_bins_cut_kwargs(use_flox: bool) -> None: + da = xr.DataArray(np.arange(12).reshape(6, 2), dims=("x", "y")) + x_bins = (0, 2, 4, 6) + + with xr.set_options(use_flox=use_flox): + actual = da.groupby_bins( + "x", bins=x_bins, include_lowest=True, right=False, squeeze=False + ).mean() + expected = xr.DataArray( + np.array([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]), + dims=("x_bins", "y"), + coords={ + "x_bins": ("x_bins", pd.IntervalIndex.from_breaks(x_bins, closed="left")) + }, + ) + assert_identical(expected, actual) + + +@pytest.mark.parametrize("indexed_coord", [True, False]) +def test_groupby_bins_math(indexed_coord) -> None: + N = 7 + da = DataArray(np.random.random((N, N)), dims=("x", "y")) + if indexed_coord: + da["x"] = np.arange(N) + da["y"] = np.arange(N) + g = da.groupby_bins("x", np.arange(0, N + 1, 3)) + mean = g.mean() + expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1])) + actual = g - mean + assert_identical(expected, actual) + + +def test_groupby_math_nD_group() -> None: + N = 40 + da = DataArray( + np.random.random((N, N)), + dims=("x", "y"), + coords={ + "labels": ( + "x", + np.repeat(["a", "b", "c", "d", "e", "f", "g", "h"], repeats=N // 8), + ), + }, + ) + da["labels2d"] = xr.broadcast(da.labels, da)[0] + + g = da.groupby("labels2d") + mean = g.mean() + expected = da - mean.sel(labels2d=da.labels2d) + expected["labels"] = expected.labels.broadcast_like(expected.labels2d) + actual = g - mean + assert_identical(expected, actual) + + da["num"] = ( + "x", + np.repeat([1, 2, 3, 4, 5, 6, 7, 8], repeats=N // 8), + ) + da["num2d"] = xr.broadcast(da.num, da)[0] + g = da.groupby_bins("num2d", bins=[0, 4, 6]) + mean = g.mean() + idxr = np.digitize(da.num2d, bins=(0, 4, 6), right=True)[:30, :] - 1 + expanded_mean = mean.drop_vars("num2d_bins").isel(num2d_bins=(("x", "y"), idxr)) + expected = da.isel(x=slice(30)) - expanded_mean + expected["labels"] = expected.labels.broadcast_like(expected.labels2d) + expected["num"] = expected.num.broadcast_like(expected.num2d) + expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data[idxr]) + actual = g - mean + assert_identical(expected, actual) + + +def test_groupby_dataset_math_virtual() -> None: + ds = Dataset({"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)}) + grouped = ds.groupby("t.day") + actual = grouped - grouped.mean(...) + expected = Dataset({"x": ("t", [0, 0, 0])}, ds[["t", "t.day"]]) + assert_identical(actual, expected) + + +def test_groupby_math_dim_order() -> None: + da = DataArray( + np.ones((10, 10, 12)), + dims=("x", "y", "time"), + coords={"time": pd.date_range("2001-01-01", periods=12, freq="6h")}, + ) + grouped = da.groupby("time.day") + result = grouped - grouped.mean() + assert result.dims == da.dims + + +def test_groupby_dataset_nan() -> None: + # nan should be excluded from groupby + ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])}) + actual = ds.groupby("bar").mean(...) + expected = Dataset({"foo": ("bar", [1.5, 3]), "bar": [1, 2]}) + assert_identical(actual, expected) + + +def test_groupby_dataset_order() -> None: + # groupby should preserve variables order + ds = Dataset() + for vn in ["a", "b", "c"]: + ds[vn] = DataArray(np.arange(10), dims=["t"]) + data_vars_ref = list(ds.data_vars.keys()) + ds = ds.groupby("t").mean(...) + data_vars = list(ds.data_vars.keys()) + assert data_vars == data_vars_ref + # coords are now at the end of the list, so the test below fails + # all_vars = list(ds.variables.keys()) + # all_vars_ref = list(ds.variables.keys()) + # .assertEqual(all_vars, all_vars_ref) + + +def test_groupby_dataset_fillna() -> None: + ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) + expected = Dataset({"a": ("x", range(4))}, {"x": [0, 1, 2, 3]}) + for target in [ds, expected]: + target.coords["b"] = ("x", [0, 0, 1, 1]) + actual = ds.groupby("b").fillna(DataArray([0, 2], dims="b")) + assert_identical(expected, actual) + + actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) + assert_identical(expected, actual) + + # attrs with groupby + ds.attrs["attr"] = "ds" + ds.a.attrs["attr"] = "da" + actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) + assert actual.attrs == ds.attrs + assert actual.a.name == "a" + assert actual.a.attrs == ds.a.attrs + + +def test_groupby_dataset_where() -> None: + # groupby + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + cond = Dataset({"a": ("c", [True, False])}) + expected = ds.copy(deep=True) + expected["a"].values = np.array([0, 1] + [np.nan] * 3) + actual = ds.groupby("c").where(cond) + assert_identical(expected, actual) + + # attrs with groupby + ds.attrs["attr"] = "ds" + ds.a.attrs["attr"] = "da" + actual = ds.groupby("c").where(cond) + assert actual.attrs == ds.attrs + assert actual.a.name == "a" + assert actual.a.attrs == ds.a.attrs + + +def test_groupby_dataset_assign() -> None: + ds = Dataset({"a": ("x", range(3))}, {"b": ("x", ["A"] * 2 + ["B"])}) + actual = ds.groupby("b").assign(c=lambda ds: 2 * ds.a) + expected = ds.merge({"c": ("x", [0, 2, 4])}) + assert_identical(actual, expected) + + actual = ds.groupby("b").assign(c=lambda ds: ds.a.sum()) + expected = ds.merge({"c": ("x", [1, 1, 2])}) + assert_identical(actual, expected) + + actual = ds.groupby("b").assign_coords(c=lambda ds: ds.a.sum()) + expected = expected.set_coords("c") + assert_identical(actual, expected) + + +def test_groupby_dataset_map_dataarray_func() -> None: + # regression GH6379 + ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, coords={"x": [0, 0, 1, 1]}) + actual = ds.groupby("x").map(lambda grp: grp.foo.mean()) + expected = DataArray([1.5, 3.5], coords={"x": [0, 1]}, dims="x", name="foo") + assert_identical(actual, expected) + + +def test_groupby_dataarray_map_dataset_func() -> None: + # regression GH6379 + da = DataArray([1, 2, 3, 4], coords={"x": [0, 0, 1, 1]}, dims="x", name="foo") + actual = da.groupby("x").map(lambda grp: grp.mean().to_dataset()) + expected = xr.Dataset({"foo": ("x", [1.5, 3.5])}, coords={"x": [0, 1]}) + assert_identical(actual, expected) + + +@requires_flox +@pytest.mark.parametrize("kwargs", [{"method": "map-reduce"}, {"engine": "numpy"}]) +def test_groupby_flox_kwargs(kwargs) -> None: + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + with xr.set_options(use_flox=False): + expected = ds.groupby("c").mean() + with xr.set_options(use_flox=True): + actual = ds.groupby("c").mean(**kwargs) + assert_identical(expected, actual) + + +class TestDataArrayGroupBy: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.attrs = {"attr1": "value1", "attr2": 2929} + self.x = np.random.random((10, 20)) + self.v = Variable(["x", "y"], self.x) + self.va = Variable(["x", "y"], self.x, self.attrs) + self.ds = Dataset({"foo": self.v}) + self.dv = self.ds["foo"] + + self.mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + self.mda = DataArray([0, 1, 2, 3], coords={"x": self.mindex}, dims="x") + + self.da = self.dv.copy() + self.da.coords["abc"] = ("y", np.array(["a"] * 9 + ["c"] + ["b"] * 10)) + self.da.coords["y"] = 20 + 100 * self.da["y"] + + def test_stack_groupby_unsorted_coord(self) -> None: + data = [[0, 1], [2, 3]] + data_flat = [0, 1, 2, 3] + dims = ["x", "y"] + y_vals = [2, 3] + + arr = xr.DataArray(data, dims=dims, coords={"y": y_vals}) + actual1 = arr.stack(z=dims).groupby("z").first() + midx1 = pd.MultiIndex.from_product([[0, 1], [2, 3]], names=dims) + expected1 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx1}) + assert_equal(actual1, expected1) + + # GH: 3287. Note that y coord values are not in sorted order. + arr = xr.DataArray(data, dims=dims, coords={"y": y_vals[::-1]}) + actual2 = arr.stack(z=dims).groupby("z").first() + midx2 = pd.MultiIndex.from_product([[0, 1], [3, 2]], names=dims) + expected2 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx2}) + assert_equal(actual2, expected2) + + def test_groupby_iter(self) -> None: + for (act_x, act_dv), (exp_x, exp_ds) in zip( + self.dv.groupby("y", squeeze=False), self.ds.groupby("y", squeeze=False) + ): + assert exp_x == act_x + assert_identical(exp_ds["foo"], act_dv) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + for (_, exp_dv), (_, act_dv) in zip( + self.dv.groupby("x"), self.dv.groupby("x") + ): + assert_identical(exp_dv, act_dv) + + def test_groupby_properties(self) -> None: + grouped = self.da.groupby("abc") + expected_groups = {"a": range(0, 9), "c": [9], "b": range(10, 20)} + assert expected_groups.keys() == grouped.groups.keys() + for key in expected_groups: + expected_group = expected_groups[key] + actual_group = grouped.groups[key] + + # TODO: array_api doesn't allow slice: + assert not isinstance(expected_group, slice) + assert not isinstance(actual_group, slice) + + np.testing.assert_array_equal(expected_group, actual_group) + assert 3 == len(grouped) + + @pytest.mark.parametrize( + "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)] + ) + @pytest.mark.parametrize("shortcut", [True, False]) + @pytest.mark.parametrize("squeeze", [None, True, False]) + def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> None: + expected = self.da + if use_da: + by = expected.coords[by] + + def identity(x): + return x + + grouped = expected.groupby(by, squeeze=squeeze) + actual = grouped.map(identity, shortcut=shortcut) + assert_identical(expected, actual) + + # abc is not a dim coordinate so no warnings expected! + if (by.name if use_da else by) != "abc": + assert len(recwarn) == (1 if squeeze in [None, True] else 0) + + def test_groupby_sum(self) -> None: + array = self.da + grouped = array.groupby("abc") + + expected_sum_all = Dataset( + { + "foo": Variable( + ["abc"], + np.array( + [ + self.x[:, :9].sum(), + self.x[:, 10:].sum(), + self.x[:, 9:10].sum(), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=...)) + assert_allclose(expected_sum_all, grouped.sum(...)) + + expected = DataArray( + [ + array["y"].values[idx].sum() + for idx in [slice(9), slice(10, None), slice(9, 10)] + ], + [["a", "b", "c"]], + ["abc"], + ) + actual = array["y"].groupby("abc").map(np.sum) + assert_allclose(expected, actual) + actual = array["y"].groupby("abc").sum(...) + assert_allclose(expected, actual) + + expected_sum_axis1 = Dataset( + { + "foo": ( + ["x", "abc"], + np.array( + [ + self.x[:, :9].sum(1), + self.x[:, 10:].sum(1), + self.x[:, 9:10].sum(1), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) + assert_allclose(expected_sum_axis1, grouped.sum("y")) + + @pytest.mark.parametrize("method", ["sum", "mean", "median"]) + def test_groupby_reductions(self, method) -> None: + array = self.da + grouped = array.groupby("abc") + + reduction = getattr(np, method) + expected = Dataset( + { + "foo": Variable( + ["x", "abc"], + np.array( + [ + reduction(self.x[:, :9], axis=-1), + reduction(self.x[:, 10:], axis=-1), + reduction(self.x[:, 9:10], axis=-1), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + + with xr.set_options(use_flox=False): + actual_legacy = getattr(grouped, method)(dim="y") + + with xr.set_options(use_flox=True): + actual_npg = getattr(grouped, method)(dim="y") + + assert_allclose(expected, actual_legacy) + assert_allclose(expected, actual_npg) + + def test_groupby_count(self) -> None: + array = DataArray( + [0, 0, np.nan, np.nan, 0, 0], + coords={"cat": ("x", ["a", "b", "b", "c", "c", "c"])}, + dims="x", + ) + actual = array.groupby("cat").count() + expected = DataArray([1, 1, 2], coords=[("cat", ["a", "b", "c"])]) + assert_identical(actual, expected) + + @pytest.mark.parametrize("shortcut", [True, False]) + @pytest.mark.parametrize("keep_attrs", [None, True, False]) + def test_groupby_reduce_keep_attrs( + self, shortcut: bool, keep_attrs: bool | None + ) -> None: + array = self.da + array.attrs["foo"] = "bar" + + actual = array.groupby("abc").reduce( + np.mean, keep_attrs=keep_attrs, shortcut=shortcut + ) + with xr.set_options(use_flox=False): + expected = array.groupby("abc").mean(keep_attrs=keep_attrs) + assert_identical(expected, actual) + + @pytest.mark.parametrize("keep_attrs", [None, True, False]) + def test_groupby_keep_attrs(self, keep_attrs: bool | None) -> None: + array = self.da + array.attrs["foo"] = "bar" + + with xr.set_options(use_flox=False): + expected = array.groupby("abc").mean(keep_attrs=keep_attrs) + with xr.set_options(use_flox=True): + actual = array.groupby("abc").mean(keep_attrs=keep_attrs) + + # values are tested elsewhere, here we just check data + # TODO: add check_attrs kwarg to assert_allclose + actual.data = expected.data + assert_identical(expected, actual) + + def test_groupby_map_center(self) -> None: + def center(x): + return x - np.mean(x) + + array = self.da + grouped = array.groupby("abc") + + expected_ds = array.to_dataset() + exp_data = np.hstack( + [center(self.x[:, :9]), center(self.x[:, 9:10]), center(self.x[:, 10:])] + ) + expected_ds["foo"] = (["x", "y"], exp_data) + expected_centered = expected_ds["foo"] + assert_allclose(expected_centered, grouped.map(center)) + + def test_groupby_map_ndarray(self) -> None: + # regression test for #326 + array = self.da + grouped = array.groupby("abc") + actual = grouped.map(np.asarray) # type: ignore[arg-type] # TODO: Not sure using np.asarray like this makes sense with array api + assert_equal(array, actual) + + def test_groupby_map_changes_metadata(self) -> None: + def change_metadata(x): + x.coords["x"] = x.coords["x"] * 2 + x.attrs["fruit"] = "lemon" + return x + + array = self.da + grouped = array.groupby("abc") + actual = grouped.map(change_metadata) + expected = array.copy() + expected = change_metadata(expected) + assert_equal(expected, actual) + + @pytest.mark.parametrize("squeeze", [True, False]) + def test_groupby_math_squeeze(self, squeeze: bool) -> None: + array = self.da + grouped = array.groupby("x", squeeze=squeeze) + + expected = array + array.coords["x"] + actual = grouped + array.coords["x"] + assert_identical(expected, actual) + + actual = array.coords["x"] + grouped + assert_identical(expected, actual) + + ds = array.coords["x"].to_dataset(name="X") + expected = array + ds + actual = grouped + ds + assert_identical(expected, actual) + + actual = ds + grouped + assert_identical(expected, actual) + + def test_groupby_math(self) -> None: + array = self.da + grouped = array.groupby("abc") + expected_agg = (grouped.mean(...) - np.arange(3)).rename(None) + actual = grouped - DataArray(range(3), [("abc", ["a", "b", "c"])]) + actual_agg = actual.groupby("abc").mean(...) + assert_allclose(expected_agg, actual_agg) + + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + 1 # type: ignore[type-var] + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + grouped # type: ignore[type-var] + with pytest.raises(TypeError, match=r"in-place operations"): + array += grouped # type: ignore[arg-type] + + def test_groupby_math_not_aligned(self) -> None: + array = DataArray( + range(4), {"b": ("x", [0, 0, 1, 1]), "x": [0, 1, 2, 3]}, dims="x" + ) + other = DataArray([10], coords={"b": [0]}, dims="b") + actual = array.groupby("b") + other + expected = DataArray([10, 11, np.nan, np.nan], array.coords) + assert_identical(expected, actual) + + # regression test for #7797 + other = array.groupby("b").sum() + actual = array.sel(x=[0, 1]).groupby("b") - other + expected = DataArray([-1, 0], {"b": ("x", [0, 0]), "x": [0, 1]}, dims="x") + assert_identical(expected, actual) + + other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b") + actual = array.groupby("b") + other + expected = DataArray([10, 11, np.nan, np.nan], array.coords) + expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2) + assert_identical(expected, actual) + + other_ds = Dataset({"a": ("b", [10])}, {"b": [0]}) + actual_ds = array.groupby("b") + other_ds + expected_ds = Dataset({"a": ("x", [10, 11, np.nan, np.nan])}, array.coords) + assert_identical(expected_ds, actual_ds) + + def test_groupby_restore_dim_order(self) -> None: + array = DataArray( + np.random.randn(5, 3), + coords={"a": ("x", range(5)), "b": ("y", range(3))}, + dims=["x", "y"], + ) + for by, expected_dims in [ + ("x", ("x", "y")), + ("y", ("x", "y")), + ("a", ("a", "y")), + ("b", ("x", "b")), + ]: + result = array.groupby(by, squeeze=False).map(lambda x: x.squeeze()) + assert result.dims == expected_dims + + def test_groupby_restore_coord_dims(self) -> None: + array = DataArray( + np.random.randn(5, 3), + coords={ + "a": ("x", range(5)), + "b": ("y", range(3)), + "c": (("x", "y"), np.random.randn(5, 3)), + }, + dims=["x", "y"], + ) + + for by, expected_dims in [ + ("x", ("x", "y")), + ("y", ("x", "y")), + ("a", ("a", "y")), + ("b", ("x", "b")), + ]: + result = array.groupby(by, squeeze=False, restore_coord_dims=True).map( + lambda x: x.squeeze() + )["c"] + assert result.dims == expected_dims + + def test_groupby_first_and_last(self) -> None: + array = DataArray([1, 2, 3, 4, 5], dims="x") + by = DataArray(["a"] * 2 + ["b"] * 3, dims="x", name="ab") + + expected = DataArray([1, 3], [("ab", ["a", "b"])]) + actual = array.groupby(by).first() + assert_identical(expected, actual) + + expected = DataArray([2, 5], [("ab", ["a", "b"])]) + actual = array.groupby(by).last() + assert_identical(expected, actual) + + array = DataArray(np.random.randn(5, 3), dims=["x", "y"]) + expected = DataArray(array[[0, 2]], {"ab": ["a", "b"]}, ["ab", "y"]) + actual = array.groupby(by).first() + assert_identical(expected, actual) + + actual = array.groupby("x").first() + expected = array # should be a no-op + assert_identical(expected, actual) + + def make_groupby_multidim_example_array(self) -> DataArray: + return DataArray( + [[[0, 1], [2, 3]], [[5, 10], [15, 20]]], + coords={ + "lon": (["ny", "nx"], [[30, 40], [40, 50]]), + "lat": (["ny", "nx"], [[10, 10], [20, 20]]), + }, + dims=["time", "ny", "nx"], + ) + + def test_groupby_multidim(self) -> None: + array = self.make_groupby_multidim_example_array() + for dim, expected_sum in [ + ("lon", DataArray([5, 28, 23], coords=[("lon", [30.0, 40.0, 50.0])])), + ("lat", DataArray([16, 40], coords=[("lat", [10.0, 20.0])])), + ]: + actual_sum = array.groupby(dim).sum(...) + assert_identical(expected_sum, actual_sum) + + def test_groupby_multidim_map(self) -> None: + array = self.make_groupby_multidim_example_array() + actual = array.groupby("lon").map(lambda x: x - x.mean()) + expected = DataArray( + [[[-2.5, -6.0], [-5.0, -8.5]], [[2.5, 3.0], [8.0, 8.5]]], + coords=array.coords, + dims=array.dims, + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize("use_flox", [True, False]) + @pytest.mark.parametrize("coords", [np.arange(4), np.arange(4)[::-1], [2, 0, 3, 1]]) + @pytest.mark.parametrize( + "cut_kwargs", + ( + {"labels": None, "include_lowest": True}, + {"labels": None, "include_lowest": False}, + {"labels": ["a", "b"]}, + {"labels": [1.2, 3.5]}, + {"labels": ["b", "a"]}, + ), + ) + def test_groupby_bins( + self, + coords: np.typing.ArrayLike, + use_flox: bool, + cut_kwargs: dict, + ) -> None: + array = DataArray( + np.arange(4), dims="dim_0", coords={"dim_0": coords}, name="a" + ) + # the first value should not be part of any group ("right" binning) + array[0] = 99 + # bins follow conventions for pandas.cut + # http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html + bins = [0, 1.5, 5] + + df = array.to_dataframe() + df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) + + expected_df = df.groupby("dim_0_bins", observed=True).sum() + # TODO: can't convert df with IntervalIndex to Xarray + expected = ( + expected_df.reset_index(drop=True) + .to_xarray() + .assign_coords(index=np.array(expected_df.index)) + .rename({"index": "dim_0_bins"})["a"] + ) + + with xr.set_options(use_flox=use_flox): + actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum() + assert_identical(expected, actual) + + actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map( + lambda x: x.sum() + ) + assert_identical(expected, actual) + + # make sure original array dims are unchanged + assert len(array.dim_0) == 4 + + def test_groupby_bins_ellipsis(self) -> None: + da = xr.DataArray(np.ones((2, 3, 4))) + bins = [-1, 0, 1, 2] + with xr.set_options(use_flox=False): + actual = da.groupby_bins("dim_0", bins).mean(...) + with xr.set_options(use_flox=True): + expected = da.groupby_bins("dim_0", bins).mean(...) + assert_allclose(actual, expected) + + @pytest.mark.parametrize("use_flox", [True, False]) + def test_groupby_bins_gives_correct_subset(self, use_flox: bool) -> None: + # GH7766 + rng = np.random.default_rng(42) + coords = rng.normal(5, 5, 1000) + bins = np.logspace(-4, 1, 10) + labels = [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + ] + # xArray + # Make a mock dataarray + darr = xr.DataArray(coords, coords=[coords], dims=["coords"]) + expected = xr.DataArray( + [np.nan, np.nan, 1, 1, 1, 8, 31, 104, 542], + dims="coords_bins", + coords={"coords_bins": labels}, + ) + gb = darr.groupby_bins("coords", bins, labels=labels) + with xr.set_options(use_flox=use_flox): + actual = gb.count() + assert_identical(actual, expected) + + def test_groupby_bins_empty(self) -> None: + array = DataArray(np.arange(4), [("x", range(4))]) + # one of these bins will be empty + bins = [0, 4, 5] + bin_coords = pd.cut(array["x"], bins).categories + actual = array.groupby_bins("x", bins).sum() + expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords}) + assert_identical(expected, actual) + # make sure original array is unchanged + # (was a problem in earlier versions) + assert len(array.x) == 4 + + def test_groupby_bins_multidim(self) -> None: + array = self.make_groupby_multidim_example_array() + bins = [0, 15, 20] + bin_coords = pd.cut(array["lat"].values.flat, bins).categories + expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords}) + actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) + assert_identical(expected, actual) + # modify the array coordinates to be non-monotonic after unstacking + array["lat"].data = np.array([[10.0, 20.0], [20.0, 10.0]]) + expected = DataArray([28, 28], dims="lat_bins", coords={"lat_bins": bin_coords}) + actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) + assert_identical(expected, actual) + + bins = [-2, -1, 0, 1, 2] + field = DataArray(np.ones((5, 3)), dims=("x", "y")) + by = DataArray( + np.array([[-1.5, -1.5, 0.5, 1.5, 1.5] * 3]).reshape(5, 3), dims=("x", "y") + ) + actual = field.groupby_bins(by, bins=bins).count() + + bincoord = np.array( + [ + pd.Interval(left, right, closed="right") + for left, right in zip(bins[:-1], bins[1:]) + ], + dtype=object, + ) + expected = DataArray( + np.array([6, np.nan, 3, 6]), + dims="group_bins", + coords={"group_bins": bincoord}, + ) + assert_identical(actual, expected) + + def test_groupby_bins_sort(self) -> None: + data = xr.DataArray( + np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} + ) + binned_mean = data.groupby_bins("x", bins=11).mean() + assert binned_mean.to_index().is_monotonic_increasing + + with xr.set_options(use_flox=True): + actual = data.groupby_bins("x", bins=11).count() + with xr.set_options(use_flox=False): + expected = data.groupby_bins("x", bins=11).count() + assert_identical(actual, expected) + + def test_groupby_assign_coords(self) -> None: + array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") + actual = array.groupby("c").assign_coords(d=lambda a: a.mean()) + expected = array.copy() + expected.coords["d"] = ("x", [1.5, 1.5, 3.5, 3.5]) + assert_identical(actual, expected) + + def test_groupby_fillna(self) -> None: + a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") + fill_value = DataArray([0, 1], dims="y") + actual = a.fillna(fill_value) + expected = DataArray( + [[0, 1], [1, 1], [0, 1], [3, 3]], coords={"x": range(4)}, dims=("x", "y") + ) + assert_identical(expected, actual) + + b = DataArray(range(4), coords={"x": range(4)}, dims="x") + expected = b.copy() + for target in [a, expected]: + target.coords["b"] = ("x", [0, 0, 1, 1]) + actual = a.groupby("b").fillna(DataArray([0, 2], dims="b")) + assert_identical(expected, actual) + + +class TestDataArrayResample: + @pytest.mark.parametrize("use_cftime", [True, False]) + def test_resample(self, use_cftime: bool) -> None: + if use_cftime and not has_cftime: + pytest.skip() + times = xr.date_range( + "2000-01-01", freq="6h", periods=10, use_cftime=use_cftime + ) + + def resample_as_pandas(array, *args, **kwargs): + array_ = array.copy(deep=True) + if use_cftime: + array_["time"] = times.to_datetimeindex() + result = DataArray.from_series( + array_.to_series().resample(*args, **kwargs).mean() + ) + if use_cftime: + result = result.convert_calendar( + calendar="standard", use_cftime=use_cftime + ) + return result + + array = DataArray(np.arange(10), [("time", times)]) + + actual = array.resample(time="24h").mean() + expected = resample_as_pandas(array, "24h") + assert_identical(expected, actual) + + actual = array.resample(time="24h").reduce(np.mean) + assert_identical(expected, actual) + + actual = array.resample(time="24h", closed="right").mean() + expected = resample_as_pandas(array, "24h", closed="right") + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"index must be monotonic"): + array[[2, 0, 1]].resample(time="1D") + + @pytest.mark.parametrize("use_cftime", [True, False]) + def test_resample_doctest(self, use_cftime: bool) -> None: + # run the doctest example here so we are not surprised + if use_cftime and not has_cftime: + pytest.skip() + + da = xr.DataArray( + np.array([1, 2, 3, 1, 2, np.nan]), + dims="time", + coords=dict( + time=( + "time", + xr.date_range( + "2001-01-01", freq="ME", periods=6, use_cftime=use_cftime + ), + ), + labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ), + ) + actual = da.resample(time="3ME").count() + expected = DataArray( + [1, 3, 1], + dims="time", + coords={ + "time": xr.date_range( + "2001-01-01", freq="3ME", periods=3, use_cftime=use_cftime + ) + }, + ) + assert_identical(actual, expected) + + def test_da_resample_func_args(self) -> None: + def func(arg1, arg2, arg3=0.0): + return arg1.mean("time") + arg2 + arg3 + + times = pd.date_range("2000", periods=3, freq="D") + da = xr.DataArray([1.0, 1.0, 1.0], coords=[times], dims=["time"]) + expected = xr.DataArray([3.0, 3.0, 3.0], coords=[times], dims=["time"]) + actual = da.resample(time="D").map(func, args=(1.0,), arg3=1.0) + assert_identical(actual, expected) + + def test_resample_first(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + # resample to same frequency + actual = array.resample(time="6h").first() + assert_identical(array, actual) + + actual = array.resample(time="1D").first() + expected = DataArray([0, 4, 8], [("time", times[::4])]) + assert_identical(expected, actual) + + # verify that labels don't use the first value + actual = array.resample(time="24h").first() + expected = DataArray(array.to_series().resample("24h").first()) + assert_identical(expected, actual) + + # missing values + array = array.astype(float) + array[:2] = np.nan + actual = array.resample(time="1D").first() + expected = DataArray([2, 4, 8], [("time", times[::4])]) + assert_identical(expected, actual) + + actual = array.resample(time="1D").first(skipna=False) + expected = DataArray([np.nan, 4, 8], [("time", times[::4])]) + assert_identical(expected, actual) + + # regression test for http://stackoverflow.com/questions/33158558/ + array = Dataset({"time": times})["time"] + actual = array.resample(time="1D").last() + expected_times = pd.to_datetime( + ["2000-01-01T18", "2000-01-02T18", "2000-01-03T06"], unit="ns" + ) + expected = DataArray(expected_times, [("time", times[::4])], name="time") + assert_identical(expected, actual) + + def test_resample_bad_resample_dim(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("__resample_dim__", times)]) + with pytest.raises(ValueError, match=r"Proxy resampling dimension"): + array.resample(**{"__resample_dim__": "1D"}).first() # type: ignore[arg-type] + + @requires_scipy + def test_resample_drop_nondim_coords(self) -> None: + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6h", periods=5) + data = np.tile(np.arange(5), (6, 3, 1)) + xx, yy = np.meshgrid(xs * 5, ys * 2.5) + tt = np.arange(len(times), dtype=int) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) + ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) + tcoord = DataArray(tt, {"time": times}, ("time",)) + ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) + ds = ds.set_coords(["xc", "yc", "tc"]) + + # Select the data now, with the auxiliary coordinates in place + array = ds["data"] + + # Re-sample + actual = array.resample(time="12h", restore_coord_dims=True).mean("time") + assert "tc" not in actual.coords + + # Up-sample - filling + actual = array.resample(time="1h", restore_coord_dims=True).ffill() + assert "tc" not in actual.coords + + # Up-sample - interpolation + actual = array.resample(time="1h", restore_coord_dims=True).interpolate( + "linear" + ) + assert "tc" not in actual.coords + + def test_resample_keep_attrs(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + array = DataArray(np.ones(10), [("time", times)]) + array.attrs["meta"] = "data" + + result = array.resample(time="1D").mean(keep_attrs=True) + expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) + assert_identical(result, expected) + + def test_resample_skipna(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + array = DataArray(np.ones(10), [("time", times)]) + array[1] = np.nan + + result = array.resample(time="1D").mean(skipna=False) + expected = DataArray([np.nan, 1, 1], [("time", times[::4])]) + assert_identical(result, expected) + + def test_upsample(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=5) + array = DataArray(np.arange(5), [("time", times)]) + + # Forward-fill + actual = array.resample(time="3h").ffill() + expected = DataArray(array.to_series().resample("3h").ffill()) + assert_identical(expected, actual) + + # Backward-fill + actual = array.resample(time="3h").bfill() + expected = DataArray(array.to_series().resample("3h").bfill()) + assert_identical(expected, actual) + + # As frequency + actual = array.resample(time="3h").asfreq() + expected = DataArray(array.to_series().resample("3h").asfreq()) + assert_identical(expected, actual) + + # Pad + actual = array.resample(time="3h").pad() + expected = DataArray(array.to_series().resample("3h").ffill()) + assert_identical(expected, actual) + + # Nearest + rs = array.resample(time="3h") + actual = rs.nearest() + new_times = rs.groupers[0].full_index + expected = DataArray(array.reindex(time=new_times, method="nearest")) + assert_identical(expected, actual) + + def test_upsample_nd(self) -> None: + # Same as before, but now we try on multi-dimensional DataArrays. + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6h", periods=5) + data = np.tile(np.arange(5), (6, 3, 1)) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + + # Forward-fill + actual = array.resample(time="3h").ffill() + expected_data = np.repeat(data, 2, axis=-1) + expected_times = times.to_series().resample("3h").asfreq().index + expected_data = expected_data[..., : len(expected_times)] + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + assert_identical(expected, actual) + + # Backward-fill + actual = array.resample(time="3h").ffill() + expected_data = np.repeat(np.flipud(data.T).T, 2, axis=-1) + expected_data = np.flipud(expected_data.T).T + expected_times = times.to_series().resample("3h").asfreq().index + expected_data = expected_data[..., : len(expected_times)] + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + assert_identical(expected, actual) + + # As frequency + actual = array.resample(time="3h").asfreq() + expected_data = np.repeat(data, 2, axis=-1).astype(float)[..., :-1] + expected_data[..., 1::2] = np.nan + expected_times = times.to_series().resample("3h").asfreq().index + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + assert_identical(expected, actual) + + # Pad + actual = array.resample(time="3h").pad() + expected_data = np.repeat(data, 2, axis=-1) + expected_data[..., 1::2] = expected_data[..., ::2] + expected_data = expected_data[..., :-1] + expected_times = times.to_series().resample("3h").asfreq().index + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + assert_identical(expected, actual) + + def test_upsample_tolerance(self) -> None: + # Test tolerance keyword for upsample methods bfill, pad, nearest + times = pd.date_range("2000-01-01", freq="1D", periods=2) + times_upsampled = pd.date_range("2000-01-01", freq="6h", periods=5) + array = DataArray(np.arange(2), [("time", times)]) + + # Forward fill + actual = array.resample(time="6h").ffill(tolerance="12h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. + expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)]) + assert_identical(expected, actual) + + # Backward fill + actual = array.resample(time="6h").bfill(tolerance="12h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. + expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)]) + assert_identical(expected, actual) + + # Nearest + actual = array.resample(time="6h").nearest(tolerance="6h") # type: ignore[arg-type] # TODO: tolerance also allows strings, same issue in .reindex. + expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)]) + assert_identical(expected, actual) + + @requires_scipy + def test_upsample_interpolate(self) -> None: + from scipy.interpolate import interp1d + + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6h", periods=5) + + z = np.arange(5) ** 2 + data = np.tile(z, (6, 3, 1)) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + + expected_times = times.to_series().resample("1h").asfreq().index + # Split the times into equal sub-intervals to simulate the 6 hour + # to 1 hour up-sampling + new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) + kinds: list[InterpOptions] = [ + "linear", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "polynomial", + ] + for kind in kinds: + kwargs = {} + if kind == "polynomial": + kwargs["order"] = 1 + actual = array.resample(time="1h").interpolate(kind, **kwargs) + # using interp1d, polynomial order is to set directly in kind using int + f = interp1d( + np.arange(len(times)), + data, + kind=kwargs["order"] if kind == "polynomial" else kind, + axis=-1, + bounds_error=True, + assume_sorted=True, + ) + expected_data = f(new_times_idx) + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + # Use AllClose because there are some small differences in how + # we upsample timeseries versus the integer indexing as I've + # done here due to floating point arithmetic + assert_allclose(expected, actual, rtol=1e-16) + + @requires_scipy + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_upsample_interpolate_bug_2197(self) -> None: + dates = pd.date_range("2007-02-01", "2007-03-01", freq="D") + da = xr.DataArray(np.arange(len(dates)), [("time", dates)]) + result = da.resample(time="ME").interpolate("linear") + expected_times = np.array( + [np.datetime64("2007-02-28"), np.datetime64("2007-03-31")] + ) + expected = xr.DataArray([27.0, np.nan], [("time", expected_times)]) + assert_equal(result, expected) + + @requires_scipy + def test_upsample_interpolate_regression_1605(self) -> None: + dates = pd.date_range("2016-01-01", "2016-03-31", freq="1D") + expected = xr.DataArray( + np.random.random((len(dates), 2, 3)), + dims=("time", "x", "y"), + coords={"time": dates}, + ) + actual = expected.resample(time="1D").interpolate("linear") + assert_allclose(actual, expected, rtol=1e-16) + + @requires_dask + @requires_scipy + @pytest.mark.parametrize("chunked_time", [True, False]) + def test_upsample_interpolate_dask(self, chunked_time: bool) -> None: + from scipy.interpolate import interp1d + + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6h", periods=5) + + z = np.arange(5) ** 2 + data = np.tile(z, (6, 3, 1)) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + chunks = {"x": 2, "y": 1} + if chunked_time: + chunks["time"] = 3 + + expected_times = times.to_series().resample("1h").asfreq().index + # Split the times into equal sub-intervals to simulate the 6 hour + # to 1 hour up-sampling + new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) + kinds: list[InterpOptions] = [ + "linear", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "polynomial", + ] + for kind in kinds: + kwargs = {} + if kind == "polynomial": + kwargs["order"] = 1 + actual = array.chunk(chunks).resample(time="1h").interpolate(kind, **kwargs) + actual = actual.compute() + # using interp1d, polynomial order is to set directly in kind using int + f = interp1d( + np.arange(len(times)), + data, + kind=kwargs["order"] if kind == "polynomial" else kind, + axis=-1, + bounds_error=True, + assume_sorted=True, + ) + expected_data = f(new_times_idx) + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + # Use AllClose because there are some small differences in how + # we upsample timeseries versus the integer indexing as I've + # done here due to floating point arithmetic + assert_allclose(expected, actual, rtol=1e-16) + + def test_resample_base(self) -> None: + times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + base = 11 + + with pytest.warns(FutureWarning, match="the `base` parameter to resample"): + actual = array.resample(time="24h", base=base).mean() + expected = DataArray( + array.to_series().resample("24h", offset=f"{base}h").mean() + ) + assert_identical(expected, actual) + + def test_resample_offset(self) -> None: + times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + offset = pd.Timedelta("11h") + actual = array.resample(time="24h", offset=offset).mean() + expected = DataArray(array.to_series().resample("24h", offset=offset).mean()) + assert_identical(expected, actual) + + def test_resample_origin(self) -> None: + times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + origin = "start" + actual = array.resample(time="24h", origin=origin).mean() + expected = DataArray(array.to_series().resample("24h", origin=origin).mean()) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "loffset", + [ + "-12h", + datetime.timedelta(hours=-12), + pd.Timedelta(hours=-12), + pd.DateOffset(hours=-12), + ], + ) + def test_resample_loffset(self, loffset) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + with pytest.warns(FutureWarning, match="`loffset` parameter"): + actual = array.resample(time="24h", loffset=loffset).mean() + series = array.to_series().resample("24h").mean() + if not isinstance(loffset, pd.DateOffset): + loffset = pd.Timedelta(loffset) + series.index = series.index + loffset + expected = DataArray(series) + assert_identical(actual, expected) + + def test_resample_invalid_loffset(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + with pytest.warns( + FutureWarning, match="Following pandas, the `loffset` parameter" + ): + with pytest.raises(ValueError, match="`loffset` must be"): + array.resample(time="24h", loffset=1).mean() # type: ignore + + +class TestDatasetResample: + def test_resample_and_first(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + + actual = ds.resample(time="1D").first(keep_attrs=True) + expected = ds.isel(time=[0, 4, 8]) + assert_identical(expected, actual) + + # upsampling + expected_time = pd.date_range("2000-01-01", freq="3h", periods=19) + expected = ds.reindex(time=expected_time) + actual = ds.resample(time="3h") + for how in ["mean", "sum", "first", "last"]: + method = getattr(actual, how) + result = method() + assert_equal(expected, result) + for method in [np.mean]: + result = actual.reduce(method) + assert_equal(expected, result) + + def test_resample_min_count(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + # inject nan + ds["foo"] = xr.where(ds["foo"] > 2.0, np.nan, ds["foo"]) + + actual = ds.resample(time="1D").sum(min_count=1) + expected = xr.concat( + [ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum("time", min_count=1) + for i in range(3) + ], + dim=actual["time"], + ) + assert_allclose(expected, actual) + + def test_resample_by_mean_with_keep_attrs(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").mean(keep_attrs=True) + actual = resampled_ds["bar"].attrs + expected = ds["bar"].attrs + assert expected == actual + + actual = resampled_ds.attrs + expected = ds.attrs + assert expected == actual + + def test_resample_loffset(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + def test_resample_by_mean_discarding_attrs(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").mean(keep_attrs=False) + + assert resampled_ds["bar"].attrs == {} + assert resampled_ds.attrs == {} + + def test_resample_by_last_discarding_attrs(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").last(keep_attrs=False) + + assert resampled_ds["bar"].attrs == {} + assert resampled_ds.attrs == {} + + @requires_scipy + def test_resample_drop_nondim_coords(self) -> None: + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6h", periods=5) + data = np.tile(np.arange(5), (6, 3, 1)) + xx, yy = np.meshgrid(xs * 5, ys * 2.5) + tt = np.arange(len(times), dtype=int) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) + ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) + tcoord = DataArray(tt, {"time": times}, ("time",)) + ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) + ds = ds.set_coords(["xc", "yc", "tc"]) + + # Re-sample + actual = ds.resample(time="12h").mean("time") + assert "tc" not in actual.coords + + # Up-sample - filling + actual = ds.resample(time="1h").ffill() + assert "tc" not in actual.coords + + # Up-sample - interpolation + actual = ds.resample(time="1h").interpolate("linear") + assert "tc" not in actual.coords + + def test_resample_old_api(self) -> None: + times = pd.date_range("2000-01-01", freq="6h", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): + ds.resample("1D", "time") # type: ignore[arg-type] + + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): + ds.resample("1D", dim="time", how="mean") # type: ignore[arg-type] + + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): + ds.resample("1D", dim="time") # type: ignore[arg-type] + + def test_resample_ds_da_are_the_same(self) -> None: + time = pd.date_range("2000-01-01", freq="6h", periods=365 * 4) + ds = xr.Dataset( + { + "foo": (("time", "x"), np.random.randn(365 * 4, 5)), + "time": time, + "x": np.arange(5), + } + ) + assert_allclose( + ds.resample(time="ME").mean()["foo"], ds.foo.resample(time="ME").mean() + ) + + def test_ds_resample_apply_func_args(self) -> None: + def func(arg1, arg2, arg3=0.0): + return arg1.mean("time") + arg2 + arg3 + + times = pd.date_range("2000", freq="D", periods=3) + ds = xr.Dataset({"foo": ("time", [1.0, 1.0, 1.0]), "time": times}) + expected = xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) + actual = ds.resample(time="D").map(func, args=(1.0,), arg3=1.0) + assert_identical(expected, actual) + + +def test_groupby_cumsum() -> None: + ds = xr.Dataset( + {"foo": (("x",), [7, 3, 1, 1, 1, 1, 1])}, + coords={"x": [0, 1, 2, 3, 4, 5, 6], "group_id": ("x", [0, 0, 1, 1, 2, 2, 2])}, + ) + actual = ds.groupby("group_id").cumsum(dim="x") + expected = xr.Dataset( + { + "foo": (("x",), [7, 10, 1, 2, 1, 2, 3]), + }, + coords={ + "x": [0, 1, 2, 3, 4, 5, 6], + "group_id": ds.group_id, + }, + ) + # TODO: Remove drop_vars when GH6528 is fixed + # when Dataset.cumsum propagates indexes, and the group variable? + assert_identical(expected.drop_vars(["x", "group_id"]), actual) + + actual = ds.foo.groupby("group_id").cumsum(dim="x") + expected.coords["group_id"] = ds.group_id + expected.coords["x"] = np.arange(7) + assert_identical(expected.foo, actual) + + +def test_groupby_cumprod() -> None: + ds = xr.Dataset( + {"foo": (("x",), [7, 3, 0, 1, 1, 2, 1])}, + coords={"x": [0, 1, 2, 3, 4, 5, 6], "group_id": ("x", [0, 0, 1, 1, 2, 2, 2])}, + ) + actual = ds.groupby("group_id").cumprod(dim="x") + expected = xr.Dataset( + { + "foo": (("x",), [7, 21, 0, 0, 1, 2, 2]), + }, + coords={ + "x": [0, 1, 2, 3, 4, 5, 6], + "group_id": ds.group_id, + }, + ) + # TODO: Remove drop_vars when GH6528 is fixed + # when Dataset.cumsum propagates indexes, and the group variable? + assert_identical(expected.drop_vars(["x", "group_id"]), actual) + + actual = ds.foo.groupby("group_id").cumprod(dim="x") + expected.coords["group_id"] = ds.group_id + expected.coords["x"] = np.arange(7) + assert_identical(expected.foo, actual) + + +@pytest.mark.parametrize( + "method, expected_array", + [ + ("cumsum", [1.0, 2.0, 5.0, 6.0, 2.0, 2.0]), + ("cumprod", [1.0, 2.0, 6.0, 6.0, 2.0, 2.0]), + ], +) +def test_resample_cumsum(method: str, expected_array: list[float]) -> None: + ds = xr.Dataset( + {"foo": ("time", [1, 2, 3, 1, 2, np.nan])}, + coords={ + "time": xr.date_range("01-01-2001", freq="ME", periods=6, use_cftime=False), + }, + ) + actual = getattr(ds.resample(time="3ME"), method)(dim="time") + expected = xr.Dataset( + {"foo": (("time",), expected_array)}, + coords={ + "time": xr.date_range("01-01-2001", freq="ME", periods=6, use_cftime=False), + }, + ) + # TODO: Remove drop_vars when GH6528 is fixed + # when Dataset.cumsum propagates indexes, and the group variable? + assert_identical(expected.drop_vars(["time"]), actual) + + actual = getattr(ds.foo.resample(time="3ME"), method)(dim="time") + expected.coords["time"] = ds.time + assert_identical(expected.drop_vars(["time"]).foo, actual) + + +def test_groupby_binary_op_regression() -> None: + # regression test for #7797 + # monthly timeseries that should return "zero anomalies" everywhere + time = xr.date_range("2023-01-01", "2023-12-31", freq="MS") + data = np.linspace(-1, 1, 12) + x = xr.DataArray(data, coords={"time": time}) + clim = xr.DataArray(data, coords={"month": np.arange(1, 13, 1)}) + + # seems to give the correct result if we use the full x, but not with a slice + x_slice = x.sel(time=["2023-04-01"]) + + # two typical ways of computing anomalies + anom_gb = x_slice.groupby("time.month") - clim + + assert_identical(xr.zeros_like(anom_gb), anom_gb) + + +def test_groupby_multiindex_level() -> None: + # GH6836 + midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + mda = xr.DataArray(np.random.rand(6, 3), [("x", midx), ("y", range(3))]) + groups = mda.groupby("one").groups + assert groups == {"a": [0, 1], "b": [2, 3], "c": [4, 5]} + + +@requires_flox +@pytest.mark.parametrize("func", ["sum", "prod"]) +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.parametrize("min_count", [None, 1]) +def test_min_count_vs_flox(func: str, min_count: int | None, skipna: bool) -> None: + da = DataArray( + data=np.array([np.nan, 1, 1, np.nan, 1, 1]), + dims="x", + coords={"labels": ("x", np.array([1, 2, 3, 1, 2, 3]))}, + ) + + gb = da.groupby("labels") + method = operator.methodcaller(func, min_count=min_count, skipna=skipna) + with xr.set_options(use_flox=True): + actual = method(gb) + with xr.set_options(use_flox=False): + expected = method(gb) + assert_identical(actual, expected) + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_min_count_error(use_flox: bool) -> None: + if use_flox and not has_flox: + pytest.skip() + da = DataArray( + data=np.array([np.nan, 1, 1, np.nan, 1, 1]), + dims="x", + coords={"labels": ("x", np.array([1, 2, 3, 1, 2, 3]))}, + ) + with xr.set_options(use_flox=use_flox): + with pytest.raises(TypeError): + da.groupby("labels").mean(min_count=1) + + +@requires_dask +def test_groupby_math_auto_chunk() -> None: + da = xr.DataArray( + [[1, 2, 3], [1, 2, 3], [1, 2, 3]], + dims=("y", "x"), + coords={"label": ("x", [2, 2, 1])}, + ) + sub = xr.DataArray( + InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} + ) + actual = da.chunk(x=1, y=2).groupby("label") - sub + assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)} + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_groupby_dim_no_dim_equal(use_flox: bool) -> None: + # https://github.com/pydata/xarray/issues/8263 + da = DataArray( + data=[1, 2, 3, 4], dims="lat", coords={"lat": np.linspace(0, 1.01, 4)} + ) + with xr.set_options(use_flox=use_flox): + actual1 = da.drop_vars("lat").groupby("lat", squeeze=False).sum() + actual2 = da.groupby("lat", squeeze=False).sum() + assert_identical(actual1, actual2.drop_vars("lat")) + + +@requires_flox +def test_default_flox_method() -> None: + import flox.xarray + + da = xr.DataArray([1, 2, 3], dims="x", coords={"label": ("x", [2, 2, 1])}) + + result = xr.DataArray([3, 3], dims="label", coords={"label": [1, 2]}) + with mock.patch("flox.xarray.xarray_reduce", return_value=result) as mocked_reduce: + da.groupby("label").sum() + + kwargs = mocked_reduce.call_args.kwargs + if Version(flox.__version__) < Version("0.9.0"): + assert kwargs["method"] == "cohorts" + else: + assert "method" not in kwargs diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_hashable.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_hashable.py new file mode 100644 index 0000000..9f92c60 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_hashable.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Union + +import pytest + +from xarray import DataArray, Dataset, Variable + +if TYPE_CHECKING: + from xarray.core.types import TypeAlias + + DimT: TypeAlias = Union[int, tuple, "DEnum", "CustomHashable"] + + +class DEnum(Enum): + dim = "dim" + + +class CustomHashable: + def __init__(self, a: int) -> None: + self.a = a + + def __hash__(self) -> int: + return self.a + + +parametrize_dim = pytest.mark.parametrize( + "dim", + [ + pytest.param(5, id="int"), + pytest.param(("a", "b"), id="tuple"), + pytest.param(DEnum.dim, id="enum"), + pytest.param(CustomHashable(3), id="HashableObject"), + ], +) + + +@parametrize_dim +def test_hashable_dims(dim: DimT) -> None: + v = Variable([dim], [1, 2, 3]) + da = DataArray([1, 2, 3], dims=[dim]) + Dataset({"a": ([dim], [1, 2, 3])}) + + # alternative constructors + DataArray(v) + Dataset({"a": v}) + Dataset({"a": da}) + + +@parametrize_dim +def test_dataset_variable_hashable_names(dim: DimT) -> None: + Dataset({dim: ("x", [1, 2, 3])}) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_indexes.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_indexes.py new file mode 100644 index 0000000..5ebdfd5 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_indexes.py @@ -0,0 +1,728 @@ +from __future__ import annotations + +import copy +from datetime import datetime +from typing import Any + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.coding.cftimeindex import CFTimeIndex +from xarray.core.indexes import ( + Hashable, + Index, + Indexes, + PandasIndex, + PandasMultiIndex, + _asarray_tuplesafe, + safe_cast_to_index, +) +from xarray.core.variable import IndexVariable, Variable +from xarray.tests import assert_array_equal, assert_identical, requires_cftime +from xarray.tests.test_coding_times import _all_cftime_date_types + + +def test_asarray_tuplesafe() -> None: + res = _asarray_tuplesafe(("a", 1)) + assert isinstance(res, np.ndarray) + assert res.ndim == 0 + assert res.item() == ("a", 1) + + res = _asarray_tuplesafe([(0,), (1,)]) + assert res.shape == (2,) + assert res[0] == (0,) + assert res[1] == (1,) + + +class CustomIndex(Index): + def __init__(self, dims) -> None: + self.dims = dims + + +class TestIndex: + @pytest.fixture + def index(self) -> CustomIndex: + return CustomIndex({"x": 2}) + + def test_from_variables(self) -> None: + with pytest.raises(NotImplementedError): + Index.from_variables({}, options={}) + + def test_concat(self) -> None: + with pytest.raises(NotImplementedError): + Index.concat([], "x") + + def test_stack(self) -> None: + with pytest.raises(NotImplementedError): + Index.stack({}, "x") + + def test_unstack(self, index) -> None: + with pytest.raises(NotImplementedError): + index.unstack() + + def test_create_variables(self, index) -> None: + assert index.create_variables() == {} + assert index.create_variables({"x": "var"}) == {"x": "var"} + + def test_to_pandas_index(self, index) -> None: + with pytest.raises(TypeError): + index.to_pandas_index() + + def test_isel(self, index) -> None: + assert index.isel({}) is None + + def test_sel(self, index) -> None: + with pytest.raises(NotImplementedError): + index.sel({}) + + def test_join(self, index) -> None: + with pytest.raises(NotImplementedError): + index.join(CustomIndex({"y": 2})) + + def test_reindex_like(self, index) -> None: + with pytest.raises(NotImplementedError): + index.reindex_like(CustomIndex({"y": 2})) + + def test_equals(self, index) -> None: + with pytest.raises(NotImplementedError): + index.equals(CustomIndex({"y": 2})) + + def test_roll(self, index) -> None: + assert index.roll({}) is None + + def test_rename(self, index) -> None: + assert index.rename({}, {}) is index + + @pytest.mark.parametrize("deep", [True, False]) + def test_copy(self, index, deep) -> None: + copied = index.copy(deep=deep) + assert isinstance(copied, CustomIndex) + assert copied is not index + + copied.dims["x"] = 3 + if deep: + assert copied.dims != index.dims + assert copied.dims != copy.deepcopy(index).dims + else: + assert copied.dims is index.dims + assert copied.dims is copy.copy(index).dims + + def test_getitem(self, index) -> None: + with pytest.raises(NotImplementedError): + index[:] + + +class TestPandasIndex: + def test_constructor(self) -> None: + pd_idx = pd.Index([1, 2, 3]) + index = PandasIndex(pd_idx, "x") + + assert index.index.equals(pd_idx) + # makes a shallow copy + assert index.index is not pd_idx + assert index.dim == "x" + + # test no name set for pd.Index + pd_idx.name = None + index = PandasIndex(pd_idx, "x") + assert index.index.name == "x" + + def test_from_variables(self) -> None: + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) + var = xr.Variable( + "x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64} + ) + + index = PandasIndex.from_variables({"x": var}, options={}) + assert index.dim == "x" + assert index.index.equals(pd.Index(data)) + assert index.coord_dtype == data.dtype + + var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) + with pytest.raises(ValueError, match=r".*only accepts one variable.*"): + PandasIndex.from_variables({"x": var, "foo": var2}, options={}) + + with pytest.raises( + ValueError, match=r".*cannot set a PandasIndex.*scalar variable.*" + ): + PandasIndex.from_variables({"foo": xr.Variable((), 1)}, options={}) + + with pytest.raises( + ValueError, match=r".*only accepts a 1-dimensional variable.*" + ): + PandasIndex.from_variables({"foo": var2}, options={}) + + def test_from_variables_index_adapter(self) -> None: + # test index type is preserved when variable wraps a pd.Index + data = pd.Series(["foo", "bar"], dtype="category") + pd_idx = pd.Index(data) + var = xr.Variable("x", pd_idx) + + index = PandasIndex.from_variables({"x": var}, options={}) + assert isinstance(index.index, pd.CategoricalIndex) + + def test_concat_periods(self): + periods = pd.period_range("2000-01-01", periods=10) + indexes = [PandasIndex(periods[:5], "t"), PandasIndex(periods[5:], "t")] + expected = PandasIndex(periods, "t") + actual = PandasIndex.concat(indexes, dim="t") + assert actual.equals(expected) + assert isinstance(actual.index, pd.PeriodIndex) + + positions = [list(range(5)), list(range(5, 10))] + actual = PandasIndex.concat(indexes, dim="t", positions=positions) + assert actual.equals(expected) + assert isinstance(actual.index, pd.PeriodIndex) + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_concat_str_dtype(self, dtype) -> None: + a = PandasIndex(np.array(["a"], dtype=dtype), "x", coord_dtype=dtype) + b = PandasIndex(np.array(["b"], dtype=dtype), "x", coord_dtype=dtype) + expected = PandasIndex( + np.array(["a", "b"], dtype=dtype), "x", coord_dtype=dtype + ) + + actual = PandasIndex.concat([a, b], "x") + assert actual.equals(expected) + assert np.issubdtype(actual.coord_dtype, dtype) + + def test_concat_empty(self) -> None: + idx = PandasIndex.concat([], "x") + assert idx.coord_dtype is np.dtype("O") + + def test_concat_dim_error(self) -> None: + indexes = [PandasIndex([0, 1], "x"), PandasIndex([2, 3], "y")] + + with pytest.raises(ValueError, match=r"Cannot concatenate.*dimensions.*"): + PandasIndex.concat(indexes, "x") + + def test_create_variables(self) -> None: + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) + pd_idx = pd.Index(data, name="foo") + index = PandasIndex(pd_idx, "x", coord_dtype=data.dtype) + index_vars = { + "foo": IndexVariable( + "x", data, attrs={"unit": "m"}, encoding={"fill_value": 0.0} + ) + } + + actual = index.create_variables(index_vars) + assert_identical(actual["foo"], index_vars["foo"]) + assert actual["foo"].dtype == index_vars["foo"].dtype + assert actual["foo"].dtype == index.coord_dtype + + def test_to_pandas_index(self) -> None: + pd_idx = pd.Index([1, 2, 3], name="foo") + index = PandasIndex(pd_idx, "x") + assert index.to_pandas_index() is index.index + + def test_sel(self) -> None: + # TODO: add tests that aren't just for edge cases + index = PandasIndex(pd.Index([1, 2, 3]), "x") + with pytest.raises(KeyError, match=r"not all values found"): + index.sel({"x": [0]}) + with pytest.raises(KeyError): + index.sel({"x": 0}) + with pytest.raises(ValueError, match=r"does not have a MultiIndex"): + index.sel({"x": {"one": 0}}) + + def test_sel_boolean(self) -> None: + # index should be ignored and indexer dtype should not be coerced + # see https://github.com/pydata/xarray/issues/5727 + index = PandasIndex(pd.Index([0.0, 2.0, 1.0, 3.0]), "x") + actual = index.sel({"x": [False, True, False, True]}) + expected_dim_indexers = {"x": [False, True, False, True]} + np.testing.assert_array_equal( + actual.dim_indexers["x"], expected_dim_indexers["x"] + ) + + def test_sel_datetime(self) -> None: + index = PandasIndex( + pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" + ) + actual = index.sel({"x": "2001-01-01"}) + expected_dim_indexers = {"x": 1} + assert actual.dim_indexers == expected_dim_indexers + + actual = index.sel({"x": index.to_pandas_index().to_numpy()[1]}) + assert actual.dim_indexers == expected_dim_indexers + + def test_sel_unsorted_datetime_index_raises(self) -> None: + index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") + with pytest.raises(KeyError): + # pandas will try to convert this into an array indexer. We should + # raise instead, so we can be sure the result of indexing with a + # slice is always a view. + index.sel({"x": slice("2001", "2002")}) + + def test_equals(self) -> None: + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([1, 2, 3], "x") + assert index1.equals(index2) is True + + def test_join(self) -> None: + index1 = PandasIndex(["a", "aa", "aaa"], "x", coord_dtype=" None: + index1 = PandasIndex([0, 1, 2], "x") + index2 = PandasIndex([1, 2, 3, 4], "x") + + expected = {"x": [1, 2, -1, -1]} + actual = index1.reindex_like(index2) + assert actual.keys() == expected.keys() + np.testing.assert_array_equal(actual["x"], expected["x"]) + + index3 = PandasIndex([1, 1, 2], "x") + with pytest.raises(ValueError, match=r".*index has duplicate values"): + index3.reindex_like(index2) + + def test_rename(self) -> None: + index = PandasIndex(pd.Index([1, 2, 3], name="a"), "x", coord_dtype=np.int32) + + # shortcut + new_index = index.rename({}, {}) + assert new_index is index + + new_index = index.rename({"a": "b"}, {}) + assert new_index.index.name == "b" + assert new_index.dim == "x" + assert new_index.coord_dtype == np.int32 + + new_index = index.rename({}, {"x": "y"}) + assert new_index.index.name == "a" + assert new_index.dim == "y" + assert new_index.coord_dtype == np.int32 + + def test_copy(self) -> None: + expected = PandasIndex([1, 2, 3], "x", coord_dtype=np.int32) + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype + + def test_getitem(self) -> None: + pd_idx = pd.Index([1, 2, 3]) + expected = PandasIndex(pd_idx, "x", coord_dtype=np.int32) + actual = expected[1:] + + assert actual.index.equals(pd_idx[1:]) + assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype + + +class TestPandasMultiIndex: + def test_constructor(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + + index = PandasMultiIndex(pd_idx, "x") + + assert index.dim == "x" + assert index.index.equals(pd_idx) + assert index.index.names == ("foo", "bar") + assert index.index.name == "x" + assert index.level_coords_dtype == { + "foo": foo_data.dtype, + "bar": bar_data.dtype, + } + + with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): + PandasMultiIndex(pd_idx, "foo") + + # default level names + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data]) + index = PandasMultiIndex(pd_idx, "x") + assert list(index.index.names) == ["x_level_0", "x_level_1"] + + def test_from_variables(self) -> None: + v_level1 = xr.Variable( + "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + ) + v_level2 = xr.Variable( + "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"} + ) + + index = PandasMultiIndex.from_variables( + {"level1": v_level1, "level2": v_level2}, options={} + ) + + expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data]) + assert index.dim == "x" + assert index.index.equals(expected_idx) + assert index.index.name == "x" + assert list(index.index.names) == ["level1", "level2"] + + var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) + with pytest.raises( + ValueError, match=r".*only accepts 1-dimensional variables.*" + ): + PandasMultiIndex.from_variables({"var": var}, options={}) + + v_level3 = xr.Variable("y", [4, 5, 6]) + with pytest.raises( + ValueError, match=r"unmatched dimensions for multi-index variables.*" + ): + PandasMultiIndex.from_variables( + {"level1": v_level1, "level3": v_level3}, options={} + ) + + def test_concat(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [[0, 1, 2], ["a", "b"]], names=("foo", "bar") + ) + level_coords_dtype = {"foo": np.int32, "bar": "=U1"} + + midx1 = PandasMultiIndex( + pd_midx[:2], "x", level_coords_dtype=level_coords_dtype + ) + midx2 = PandasMultiIndex( + pd_midx[2:], "x", level_coords_dtype=level_coords_dtype + ) + expected = PandasMultiIndex(pd_midx, "x", level_coords_dtype=level_coords_dtype) + + actual = PandasMultiIndex.concat([midx1, midx2], "x") + assert actual.equals(expected) + assert actual.level_coords_dtype == expected.level_coords_dtype + + def test_stack(self) -> None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 3, 2])), + } + + index = PandasMultiIndex.stack(prod_vars, "z") + + assert index.dim == "z" + # TODO: change to tuple when pandas 3 is minimum + assert list(index.index.names) == ["x", "y"] + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + ) + + with pytest.raises( + ValueError, match=r"conflicting dimensions for multi-index product.*" + ): + PandasMultiIndex.stack( + {"x": xr.Variable("x", ["a", "b"]), "x2": xr.Variable("x", [1, 2])}, + "z", + ) + + def test_stack_non_unique(self) -> None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 1, 2])), + } + + index = PandasMultiIndex.stack(prod_vars, "z") + + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + ) + np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + + def test_unstack(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2, 3]], names=["one", "two"] + ) + index = PandasMultiIndex(pd_midx, "x") + + new_indexes, new_pd_idx = index.unstack() + assert list(new_indexes) == ["one", "two"] + assert new_indexes["one"].equals(PandasIndex(["a", "b"], "one")) + assert new_indexes["two"].equals(PandasIndex([1, 2, 3], "two")) + assert new_pd_idx.equals(pd_midx) + + def test_unstack_requires_unique(self) -> None: + pd_midx = pd.MultiIndex.from_product([["a", "a"], [1, 2]], names=["one", "two"]) + index = PandasMultiIndex(pd_midx, "x") + + with pytest.raises( + ValueError, match="Cannot unstack MultiIndex containing duplicates" + ): + index.unstack() + + def test_create_variables(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + index_vars = { + "x": IndexVariable("x", pd_idx), + "foo": IndexVariable("x", foo_data, attrs={"unit": "m"}), + "bar": IndexVariable("x", bar_data, encoding={"fill_value": 0}), + } + + index = PandasMultiIndex(pd_idx, "x") + actual = index.create_variables(index_vars) + + for k, expected in index_vars.items(): + assert_identical(actual[k], expected) + assert actual[k].dtype == expected.dtype + if k != "x": + assert actual[k].dtype == index.level_coords_dtype[k] + + def test_sel(self) -> None: + index = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" + ) + + # test tuples inside slice are considered as scalar indexer values + actual = index.sel({"x": slice(("a", 1), ("b", 2))}) + expected_dim_indexers = {"x": slice(0, 4)} + assert actual.dim_indexers == expected_dim_indexers + + with pytest.raises(KeyError, match=r"not all values found"): + index.sel({"x": [0]}) + with pytest.raises(KeyError): + index.sel({"x": 0}) + with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): + index.sel({"one": 0, "x": "a"}) + with pytest.raises( + ValueError, + match=r"multi-index level names \('three',\) not found in indexes", + ): + index.sel({"x": {"three": 0}}) + with pytest.raises(IndexError): + index.sel({"x": (slice(None), 1, "no_level")}) + + def test_join(self): + midx = pd.MultiIndex.from_product([["a", "aa"], [1, 2]], names=("one", "two")) + level_coords_dtype = {"one": "=U2", "two": "i"} + index1 = PandasMultiIndex(midx, "x", level_coords_dtype=level_coords_dtype) + index2 = PandasMultiIndex(midx[0:2], "x", level_coords_dtype=level_coords_dtype) + + actual = index1.join(index2) + assert actual.equals(index2) + assert actual.level_coords_dtype == level_coords_dtype + + actual = index1.join(index2, how="outer") + assert actual.equals(index1) + assert actual.level_coords_dtype == level_coords_dtype + + def test_rename(self) -> None: + level_coords_dtype = {"one": " None: + level_coords_dtype = {"one": "U<1", "two": np.int32} + expected = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), + "x", + level_coords_dtype=level_coords_dtype, + ) + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + assert actual.level_coords_dtype == expected.level_coords_dtype + + +class TestIndexes: + @pytest.fixture + def indexes_and_vars(self) -> tuple[list[PandasIndex], dict[Hashable, Variable]]: + x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x") + y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y") + z_pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["one", "two"] + ) + z_midx = PandasMultiIndex(z_pd_midx, "z") + + indexes = [x_idx, y_idx, z_midx] + + variables = {} + for idx in indexes: + variables.update(idx.create_variables()) + + return indexes, variables + + @pytest.fixture(params=["pd_index", "xr_index"]) + def unique_indexes( + self, request, indexes_and_vars + ) -> list[PandasIndex] | list[pd.Index]: + xr_indexes, _ = indexes_and_vars + + if request.param == "pd_index": + return [idx.index for idx in xr_indexes] + else: + return xr_indexes + + @pytest.fixture + def indexes( + self, unique_indexes, indexes_and_vars + ) -> Indexes[Index] | Indexes[pd.Index]: + x_idx, y_idx, z_midx = unique_indexes + indexes: dict[Any, Index] = { + "x": x_idx, + "y": y_idx, + "z": z_midx, + "one": z_midx, + "two": z_midx, + } + + _, variables = indexes_and_vars + + if isinstance(x_idx, Index): + index_type = Index + else: + index_type = pd.Index + + return Indexes(indexes, variables, index_type=index_type) + + def test_interface(self, unique_indexes, indexes) -> None: + x_idx = unique_indexes[0] + assert list(indexes) == ["x", "y", "z", "one", "two"] + assert len(indexes) == 5 + assert "x" in indexes + assert indexes["x"] is x_idx + + def test_variables(self, indexes) -> None: + assert tuple(indexes.variables) == ("x", "y", "z", "one", "two") + + def test_dims(self, indexes) -> None: + assert indexes.dims == {"x": 3, "y": 3, "z": 4} + + def test_get_unique(self, unique_indexes, indexes) -> None: + assert indexes.get_unique() == unique_indexes + + def test_is_multi(self, indexes) -> None: + assert indexes.is_multi("one") is True + assert indexes.is_multi("x") is False + + def test_get_all_coords(self, indexes) -> None: + expected = { + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], + } + assert indexes.get_all_coords("one") == expected + + with pytest.raises(ValueError, match="errors must be.*"): + indexes.get_all_coords("x", errors="invalid") + + with pytest.raises(ValueError, match="no index found.*"): + indexes.get_all_coords("no_coord") + + assert indexes.get_all_coords("no_coord", errors="ignore") == {} + + def test_get_all_dims(self, indexes) -> None: + expected = {"z": 4} + assert indexes.get_all_dims("one") == expected + + def test_group_by_index(self, unique_indexes, indexes): + expected = [ + (unique_indexes[0], {"x": indexes.variables["x"]}), + (unique_indexes[1], {"y": indexes.variables["y"]}), + ( + unique_indexes[2], + { + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], + }, + ), + ] + + assert indexes.group_by_index() == expected + + def test_to_pandas_indexes(self, indexes) -> None: + pd_indexes = indexes.to_pandas_indexes() + assert isinstance(pd_indexes, Indexes) + assert all([isinstance(idx, pd.Index) for idx in pd_indexes.values()]) + assert indexes.variables == pd_indexes.variables + + def test_copy_indexes(self, indexes) -> None: + copied, index_vars = indexes.copy_indexes() + + assert copied.keys() == indexes.keys() + for new, original in zip(copied.values(), indexes.values()): + assert new.equals(original) + # check unique index objects preserved + assert copied["z"] is copied["one"] is copied["two"] + + assert index_vars.keys() == indexes.variables.keys() + for new, original in zip(index_vars.values(), indexes.variables.values()): + assert_identical(new, original) + + +def test_safe_cast_to_index(): + dates = pd.date_range("2000-01-01", periods=10) + x = np.arange(5) + td = x * np.timedelta64(1, "D") + for expected, array in [ + (dates, dates.values), + (pd.Index(x, dtype=object), x.astype(object)), + (pd.Index(td), td), + (pd.Index(td, dtype=object), td.astype(object)), + ]: + actual = safe_cast_to_index(array) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +@requires_cftime +def test_safe_cast_to_index_cftimeindex(): + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + dates = [date_type(1, 1, day) for day in range(1, 20)] + expected = CFTimeIndex(dates) + actual = safe_cast_to_index(np.array(dates)) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + assert isinstance(actual, type(expected)) + + +# Test that datetime.datetime objects are never used in a CFTimeIndex +@requires_cftime +def test_safe_cast_to_index_datetime_datetime(): + dates = [datetime(1, 1, day) for day in range(1, 20)] + + expected = pd.Index(dates) + actual = safe_cast_to_index(np.array(dates)) + assert_array_equal(expected, actual) + assert isinstance(actual, pd.Index) + + +@pytest.mark.parametrize("dtype", ["int32", "float32"]) +def test_restore_dtype_on_multiindexes(dtype: str) -> None: + foo = xr.Dataset(coords={"bar": ("bar", np.array([0, 1], dtype=dtype))}) + foo = foo.stack(baz=("bar",)) + assert str(foo["bar"].values.dtype) == dtype diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_indexing.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_indexing.py new file mode 100644 index 0000000..f019d3c --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_indexing.py @@ -0,0 +1,975 @@ +from __future__ import annotations + +import itertools +from typing import Any + +import numpy as np +import pandas as pd +import pytest + +from xarray import DataArray, Dataset, Variable +from xarray.core import indexing, nputils +from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.core.types import T_Xarray +from xarray.tests import ( + IndexerMaker, + ReturnItem, + assert_array_equal, + assert_identical, + raise_if_dask_computes, + requires_dask, +) + +B = IndexerMaker(indexing.BasicIndexer) + + +class TestIndexCallable: + def test_getitem(self): + def getter(key): + return key * 2 + + indexer = indexing.IndexCallable(getter) + assert indexer[3] == 6 + assert indexer[0] == 0 + assert indexer[-1] == -2 + + def test_setitem(self): + def getter(key): + return key * 2 + + def setter(key, value): + raise NotImplementedError("Setter not implemented") + + indexer = indexing.IndexCallable(getter, setter) + with pytest.raises(NotImplementedError): + indexer[3] = 6 + + +class TestIndexers: + def set_to_zero(self, x, i): + x = x.copy() + x[i] = 0 + return x + + def test_expanded_indexer(self) -> None: + x = np.random.randn(10, 11, 12, 13, 14) + y = np.arange(5) + arr = ReturnItem() + for i in [ + arr[:], + arr[...], + arr[0, :, 10], + arr[..., 10], + arr[:5, ..., 0], + arr[..., 0, :], + arr[y], + arr[y, y], + arr[..., y, y], + arr[..., 0, 1, 2, 3, 4], + ]: + j = indexing.expanded_indexer(i, x.ndim) + assert_array_equal(x[i], x[j]) + assert_array_equal(self.set_to_zero(x, i), self.set_to_zero(x, j)) + with pytest.raises(IndexError, match=r"too many indices"): + indexing.expanded_indexer(arr[1, 2, 3], 2) + + def test_stacked_multiindex_min_max(self) -> None: + data = np.random.randn(3, 23, 4) + da = DataArray( + data, + name="value", + dims=["replicate", "rsample", "exp"], + coords=dict( + replicate=[0, 1, 2], exp=["a", "b", "c", "d"], rsample=list(range(23)) + ), + ) + da2 = da.stack(sample=("replicate", "rsample")) + s = da2.sample + assert_array_equal(da2.loc["a", s.max()], data[2, 22, 0]) + assert_array_equal(da2.loc["b", s.min()], data[0, 0, 1]) + + def test_group_indexers_by_index(self) -> None: + mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + data = DataArray( + np.zeros((4, 2, 2)), coords={"x": mindex, "y": [1, 2]}, dims=("x", "y", "z") + ) + data.coords["y2"] = ("y", [2.0, 3.0]) + + grouped_indexers = indexing.group_indexers_by_index( + data, {"z": 0, "one": "a", "two": 1, "y": 0}, {} + ) + + for idx, indexers in grouped_indexers: + if idx is None: + assert indexers == {"z": 0} + elif idx.equals(data.xindexes["x"]): + assert indexers == {"one": "a", "two": 1} + elif idx.equals(data.xindexes["y"]): + assert indexers == {"y": 0} + assert len(grouped_indexers) == 3 + + with pytest.raises(KeyError, match=r"no index found for coordinate 'y2'"): + indexing.group_indexers_by_index(data, {"y2": 2.0}, {}) + with pytest.raises( + KeyError, match=r"'w' is not a valid dimension or coordinate" + ): + indexing.group_indexers_by_index(data, {"w": "a"}, {}) + with pytest.raises(ValueError, match=r"cannot supply.*"): + indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) + + def test_map_index_queries(self) -> None: + def create_sel_results( + x_indexer, + x_index, + other_vars, + drop_coords, + drop_indexes, + rename_dims, + ): + dim_indexers = {"x": x_indexer} + index_vars = x_index.create_variables() + indexes = {k: x_index for k in index_vars} + variables = {} + variables.update(index_vars) + variables.update(other_vars) + + return indexing.IndexSelResult( + dim_indexers=dim_indexers, + indexes=indexes, + variables=variables, + drop_coords=drop_coords, + drop_indexes=drop_indexes, + rename_dims=rename_dims, + ) + + def test_indexer( + data: T_Xarray, + x: Any, + expected: indexing.IndexSelResult, + ) -> None: + results = indexing.map_index_queries(data, {"x": x}) + + assert results.dim_indexers.keys() == expected.dim_indexers.keys() + assert_array_equal(results.dim_indexers["x"], expected.dim_indexers["x"]) + + assert results.indexes.keys() == expected.indexes.keys() + for k in results.indexes: + assert results.indexes[k].equals(expected.indexes[k]) + + assert results.variables.keys() == expected.variables.keys() + for k in results.variables: + assert_array_equal(results.variables[k], expected.variables[k]) + + assert set(results.drop_coords) == set(expected.drop_coords) + assert set(results.drop_indexes) == set(expected.drop_indexes) + assert results.rename_dims == expected.rename_dims + + data = Dataset({"x": ("x", [1, 2, 3])}) + mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + mdata = DataArray(range(8), [("x", mindex)]) + + test_indexer(data, 1, indexing.IndexSelResult({"x": 0})) + test_indexer(data, np.int32(1), indexing.IndexSelResult({"x": 0})) + test_indexer(data, Variable([], 1), indexing.IndexSelResult({"x": 0})) + test_indexer(mdata, ("a", 1, -1), indexing.IndexSelResult({"x": 0})) + + expected = create_sel_results( + [True, True, False, False, False, False, False, False], + PandasIndex(pd.Index([-1, -2]), "three"), + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], + {"x": "three"}, + ) + test_indexer(mdata, ("a", 1), expected) + + expected = create_sel_results( + slice(0, 4, None), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, + ) + test_indexer(mdata, "a", expected) + + expected = create_sel_results( + [True, True, True, True, False, False, False, False], + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, + ) + test_indexer(mdata, ("a",), expected) + + test_indexer( + mdata, [("a", 1, -1), ("b", 2, -2)], indexing.IndexSelResult({"x": [0, 7]}) + ) + test_indexer( + mdata, slice("a", "b"), indexing.IndexSelResult({"x": slice(0, 8, None)}) + ) + test_indexer( + mdata, + slice(("a", 1), ("b", 1)), + indexing.IndexSelResult({"x": slice(0, 6, None)}), + ) + test_indexer( + mdata, + {"one": "a", "two": 1, "three": -1}, + indexing.IndexSelResult({"x": 0}), + ) + + expected = create_sel_results( + [True, True, False, False, False, False, False, False], + PandasIndex(pd.Index([-1, -2]), "three"), + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], + {"x": "three"}, + ) + test_indexer(mdata, {"one": "a", "two": 1}, expected) + + expected = create_sel_results( + [True, False, True, False, False, False, False, False], + PandasIndex(pd.Index([1, 2]), "two"), + {"one": Variable((), "a"), "three": Variable((), -1)}, + ["x"], + ["one", "three"], + {"x": "two"}, + ) + test_indexer(mdata, {"one": "a", "three": -1}, expected) + + expected = create_sel_results( + [True, True, True, True, False, False, False, False], + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, + ) + test_indexer(mdata, {"one": "a"}, expected) + + def test_read_only_view(self) -> None: + arr = DataArray( + np.random.rand(3, 3), + coords={"x": np.arange(3), "y": np.arange(3)}, + dims=("x", "y"), + ) # Create a 2D DataArray + arr = arr.expand_dims({"z": 3}, -1) # New dimension 'z' + arr["z"] = np.arange(3) # New coords to dimension 'z' + with pytest.raises(ValueError, match="Do you want to .copy()"): + arr.loc[0, 0, 0] = 999 + + +class TestLazyArray: + def test_slice_slice(self) -> None: + arr = ReturnItem() + for size in [100, 99]: + # We test even/odd size cases + x = np.arange(size) + slices = [ + arr[:3], + arr[:4], + arr[2:4], + arr[:1], + arr[:-1], + arr[5:-1], + arr[-5:-1], + arr[::-1], + arr[5::-1], + arr[:3:-1], + arr[:30:-1], + arr[10:4:], + arr[::4], + arr[4:4:4], + arr[:4:-4], + arr[::-2], + ] + for i in slices: + for j in slices: + expected = x[i][j] + new_slice = indexing.slice_slice(i, j, size=size) + actual = x[new_slice] + assert_array_equal(expected, actual) + + def test_lazily_indexed_array(self) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + v = Variable(["i", "j", "k"], original) + lazy = indexing.LazilyIndexedArray(x) + v_lazy = Variable(["i", "j", "k"], lazy) + arr = ReturnItem() + # test orthogonally applied indexers + indexers = [arr[:], 0, -2, arr[:3], [0, 1, 2, 3], [0], np.arange(10) < 5] + for i in indexers: + for j in indexers: + for k in indexers: + if isinstance(j, np.ndarray) and j.dtype.kind == "b": + j = np.arange(20) < 5 + if isinstance(k, np.ndarray) and k.dtype.kind == "b": + k = np.arange(30) < 5 + expected = np.asarray(v[i, j, k]) + for actual in [ + v_lazy[i, j, k], + v_lazy[:, j, k][i], + v_lazy[:, :, k][:, j][i], + ]: + assert expected.shape == actual.shape + assert_array_equal(expected, actual) + assert isinstance(actual._data, indexing.LazilyIndexedArray) + assert isinstance(v_lazy._data, indexing.LazilyIndexedArray) + + # make sure actual.key is appropriate type + if all( + isinstance(k, (int, slice)) for k in v_lazy._data.key.tuple + ): + assert isinstance(v_lazy._data.key, indexing.BasicIndexer) + else: + assert isinstance(v_lazy._data.key, indexing.OuterIndexer) + + # test sequentially applied indexers + indexers = [ + (3, 2), + (arr[:], 0), + (arr[:2], -1), + (arr[:4], [0]), + ([4, 5], 0), + ([0, 1, 2], [0, 1]), + ([0, 3, 5], arr[:2]), + ] + for i, j in indexers: + + expected_b = v[i][j] + actual = v_lazy[i][j] + assert expected_b.shape == actual.shape + assert_array_equal(expected_b, actual) + + # test transpose + if actual.ndim > 1: + order = np.random.choice(actual.ndim, actual.ndim) + order = np.array(actual.dims) + transposed = actual.transpose(*order) + assert_array_equal(expected_b.transpose(*order), transposed) + assert isinstance( + actual._data, + ( + indexing.LazilyVectorizedIndexedArray, + indexing.LazilyIndexedArray, + ), + ) + + assert isinstance(actual._data, indexing.LazilyIndexedArray) + assert isinstance(actual._data.array, indexing.NumpyIndexingAdapter) + + def test_vectorized_lazily_indexed_array(self) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + v_eager = Variable(["i", "j", "k"], x) + lazy = indexing.LazilyIndexedArray(x) + v_lazy = Variable(["i", "j", "k"], lazy) + arr = ReturnItem() + + def check_indexing(v_eager, v_lazy, indexers): + for indexer in indexers: + actual = v_lazy[indexer] + expected = v_eager[indexer] + assert expected.shape == actual.shape + assert isinstance( + actual._data, + ( + indexing.LazilyVectorizedIndexedArray, + indexing.LazilyIndexedArray, + ), + ) + assert_array_equal(expected, actual) + v_eager = expected + v_lazy = actual + + # test orthogonal indexing + indexers = [(arr[:], 0, 1), (Variable("i", [0, 1]),)] + check_indexing(v_eager, v_lazy, indexers) + + # vectorized indexing + indexers = [ + (Variable("i", [0, 1]), Variable("i", [0, 1]), slice(None)), + (slice(1, 3, 2), 0), + ] + check_indexing(v_eager, v_lazy, indexers) + + indexers = [ + (slice(None, None, 2), 0, slice(None, 10)), + (Variable("i", [3, 2, 4, 3]), Variable("i", [3, 2, 1, 0])), + (Variable(["i", "j"], [[0, 1], [1, 2]]),), + ] + check_indexing(v_eager, v_lazy, indexers) + + indexers = [ + (Variable("i", [3, 2, 4, 3]), Variable("i", [3, 2, 1, 0])), + (Variable(["i", "j"], [[0, 1], [1, 2]]),), + ] + check_indexing(v_eager, v_lazy, indexers) + + def test_lazily_indexed_array_vindex_setitem(self) -> None: + + lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30)) + + # vectorized indexing + indexer = indexing.VectorizedIndexer( + (np.array([0, 1]), np.array([0, 1]), slice(None, None, None)) + ) + with pytest.raises( + NotImplementedError, + match=r"Lazy item assignment with the vectorized indexer is not yet", + ): + lazy.vindex[indexer] = 0 + + @pytest.mark.parametrize( + "indexer_class, key, value", + [ + (indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10), + (indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10), + ], + ) + def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + lazy = indexing.LazilyIndexedArray(x) + + if indexer_class is indexing.BasicIndexer: + indexer = indexer_class(key) + lazy[indexer] = value + elif indexer_class is indexing.OuterIndexer: + indexer = indexer_class(key) + lazy.oindex[indexer] = value + + assert_array_equal(original[key], value) + + +class TestCopyOnWriteArray: + def test_setitem(self) -> None: + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + wrapped[B[:]] = 0 + assert_array_equal(original, np.arange(10)) + assert_array_equal(wrapped, np.zeros(10)) + + def test_sub_array(self) -> None: + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + child = wrapped[B[:5]] + assert isinstance(child, indexing.CopyOnWriteArray) + child[B[:]] = 0 + assert_array_equal(original, np.arange(10)) + assert_array_equal(wrapped, np.arange(10)) + assert_array_equal(child, np.zeros(5)) + + def test_index_scalar(self) -> None: + # regression test for GH1374 + x = indexing.CopyOnWriteArray(np.array(["foo", "bar"])) + assert np.array(x[B[0]][B[()]]) == "foo" + + +class TestMemoryCachedArray: + def test_wrapper(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + assert_array_equal(wrapped, np.arange(10)) + assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter) + + def test_sub_array(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[B[:5]] + assert isinstance(child, indexing.MemoryCachedArray) + assert_array_equal(child, np.arange(5)) + assert isinstance(child.array, indexing.NumpyIndexingAdapter) + assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + + def test_setitem(self) -> None: + original = np.arange(10) + wrapped = indexing.MemoryCachedArray(original) + wrapped[B[:]] = 0 + assert_array_equal(original, np.zeros(10)) + + def test_index_scalar(self) -> None: + # regression test for GH1374 + x = indexing.MemoryCachedArray(np.array(["foo", "bar"])) + assert np.array(x[B[0]][B[()]]) == "foo" + + +def test_base_explicit_indexer() -> None: + with pytest.raises(TypeError): + indexing.ExplicitIndexer(()) + + class Subclass(indexing.ExplicitIndexer): + pass + + value = Subclass((1, 2, 3)) + assert value.tuple == (1, 2, 3) + assert repr(value) == "Subclass((1, 2, 3))" + + +@pytest.mark.parametrize( + "indexer_cls", + [indexing.BasicIndexer, indexing.OuterIndexer, indexing.VectorizedIndexer], +) +def test_invalid_for_all(indexer_cls) -> None: + with pytest.raises(TypeError): + indexer_cls(None) + with pytest.raises(TypeError): + indexer_cls(([],)) + with pytest.raises(TypeError): + indexer_cls((None,)) + with pytest.raises(TypeError): + indexer_cls(("foo",)) + with pytest.raises(TypeError): + indexer_cls((1.0,)) + with pytest.raises(TypeError): + indexer_cls((slice("foo"),)) + with pytest.raises(TypeError): + indexer_cls((np.array(["foo"]),)) + + +def check_integer(indexer_cls): + value = indexer_cls((1, np.uint64(2))).tuple + assert all(isinstance(v, int) for v in value) + assert value == (1, 2) + + +def check_slice(indexer_cls): + (value,) = indexer_cls((slice(1, None, np.int64(2)),)).tuple + assert value == slice(1, None, 2) + assert isinstance(value.step, int) + + +def check_array1d(indexer_cls): + (value,) = indexer_cls((np.arange(3, dtype=np.int32),)).tuple + assert value.dtype == np.int64 + np.testing.assert_array_equal(value, [0, 1, 2]) + + +def check_array2d(indexer_cls): + array = np.array([[1, 2], [3, 4]], dtype=np.int64) + (value,) = indexer_cls((array,)).tuple + assert value.dtype == np.int64 + np.testing.assert_array_equal(value, array) + + +def test_basic_indexer() -> None: + check_integer(indexing.BasicIndexer) + check_slice(indexing.BasicIndexer) + with pytest.raises(TypeError): + check_array1d(indexing.BasicIndexer) + with pytest.raises(TypeError): + check_array2d(indexing.BasicIndexer) + + +def test_outer_indexer() -> None: + check_integer(indexing.OuterIndexer) + check_slice(indexing.OuterIndexer) + check_array1d(indexing.OuterIndexer) + with pytest.raises(TypeError): + check_array2d(indexing.OuterIndexer) + + +def test_vectorized_indexer() -> None: + with pytest.raises(TypeError): + check_integer(indexing.VectorizedIndexer) + check_slice(indexing.VectorizedIndexer) + check_array1d(indexing.VectorizedIndexer) + check_array2d(indexing.VectorizedIndexer) + with pytest.raises(ValueError, match=r"numbers of dimensions"): + indexing.VectorizedIndexer( + (np.array(1, dtype=np.int64), np.arange(5, dtype=np.int64)) + ) + + +class Test_vectorized_indexer: + @pytest.fixture(autouse=True) + def setup(self): + self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13)) + self.indexers = [ + np.array([[0, 3, 2]]), + np.array([[0, 3, 3], [4, 6, 7]]), + slice(2, -2, 2), + slice(2, -2, 3), + slice(None), + ] + + def test_arrayize_vectorized_indexer(self) -> None: + for i, j, k in itertools.product(self.indexers, repeat=3): + vindex = indexing.VectorizedIndexer((i, j, k)) + vindex_array = indexing._arrayize_vectorized_indexer( + vindex, self.data.shape + ) + np.testing.assert_array_equal( + self.data.vindex[vindex], self.data.vindex[vindex_array] + ) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((slice(None),)), shape=(5,) + ) + np.testing.assert_array_equal(actual.tuple, [np.arange(5)]) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((np.arange(5),) * 3), shape=(8, 10, 12) + ) + expected = np.stack([np.arange(5)] * 3) + np.testing.assert_array_equal(np.stack(actual.tuple), expected) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((np.arange(5), slice(None))), shape=(8, 10) + ) + a, b = actual.tuple + np.testing.assert_array_equal(a, np.arange(5)[:, np.newaxis]) + np.testing.assert_array_equal(b, np.arange(10)[np.newaxis, :]) + + actual = indexing._arrayize_vectorized_indexer( + indexing.VectorizedIndexer((slice(None), np.arange(5))), shape=(8, 10) + ) + a, b = actual.tuple + np.testing.assert_array_equal(a, np.arange(8)[np.newaxis, :]) + np.testing.assert_array_equal(b, np.arange(5)[:, np.newaxis]) + + +def get_indexers(shape, mode): + if mode == "vectorized": + indexed_shape = (3, 4) + indexer = tuple(np.random.randint(0, s, size=indexed_shape) for s in shape) + return indexing.VectorizedIndexer(indexer) + + elif mode == "outer": + indexer = tuple(np.random.randint(0, s, s + 2) for s in shape) + return indexing.OuterIndexer(indexer) + + elif mode == "outer_scalar": + indexer = (np.random.randint(0, 3, 4), 0, slice(None, None, 2)) + return indexing.OuterIndexer(indexer[: len(shape)]) + + elif mode == "outer_scalar2": + indexer = (np.random.randint(0, 3, 4), -2, slice(None, None, 2)) + return indexing.OuterIndexer(indexer[: len(shape)]) + + elif mode == "outer1vec": + indexer = [slice(2, -3) for s in shape] + indexer[1] = np.random.randint(0, shape[1], shape[1] + 2) + return indexing.OuterIndexer(tuple(indexer)) + + elif mode == "basic": # basic indexer + indexer = [slice(2, -3) for s in shape] + indexer[0] = 3 + return indexing.BasicIndexer(tuple(indexer)) + + elif mode == "basic1": # basic indexer + return indexing.BasicIndexer((3,)) + + elif mode == "basic2": # basic indexer + indexer = [0, 2, 4] + return indexing.BasicIndexer(tuple(indexer[: len(shape)])) + + elif mode == "basic3": # basic indexer + indexer = [slice(None) for s in shape] + indexer[0] = slice(-2, 2, -2) + indexer[1] = slice(1, -1, 2) + return indexing.BasicIndexer(tuple(indexer[: len(shape)])) + + +@pytest.mark.parametrize("size", [100, 99]) +@pytest.mark.parametrize( + "sl", [slice(1, -1, 1), slice(None, -1, 2), slice(-1, 1, -1), slice(-1, 1, -2)] +) +def test_decompose_slice(size, sl) -> None: + x = np.arange(size) + slice1, slice2 = indexing._decompose_slice(sl, size) + expected = x[sl] + actual = x[slice1][slice2] + assert_array_equal(expected, actual) + + +@pytest.mark.parametrize("shape", [(10, 5, 8), (10, 3)]) +@pytest.mark.parametrize( + "indexer_mode", + [ + "vectorized", + "outer", + "outer_scalar", + "outer_scalar2", + "outer1vec", + "basic", + "basic1", + "basic2", + "basic3", + ], +) +@pytest.mark.parametrize( + "indexing_support", + [ + indexing.IndexingSupport.BASIC, + indexing.IndexingSupport.OUTER, + indexing.IndexingSupport.OUTER_1VECTOR, + indexing.IndexingSupport.VECTORIZED, + ], +) +def test_decompose_indexers(shape, indexer_mode, indexing_support) -> None: + data = np.random.randn(*shape) + indexer = get_indexers(shape, indexer_mode) + + backend_ind, np_ind = indexing.decompose_indexer(indexer, shape, indexing_support) + indexing_adapter = indexing.NumpyIndexingAdapter(data) + + # Dispatch to appropriate indexing method + if indexer_mode.startswith("vectorized"): + expected = indexing_adapter.vindex[indexer] + + elif indexer_mode.startswith("outer"): + expected = indexing_adapter.oindex[indexer] + + else: + expected = indexing_adapter[indexer] # Basic indexing + + if isinstance(backend_ind, indexing.VectorizedIndexer): + array = indexing_adapter.vindex[backend_ind] + elif isinstance(backend_ind, indexing.OuterIndexer): + array = indexing_adapter.oindex[backend_ind] + else: + array = indexing_adapter[backend_ind] + + if len(np_ind.tuple) > 0: + array_indexing_adapter = indexing.NumpyIndexingAdapter(array) + if isinstance(np_ind, indexing.VectorizedIndexer): + array = array_indexing_adapter.vindex[np_ind] + elif isinstance(np_ind, indexing.OuterIndexer): + array = array_indexing_adapter.oindex[np_ind] + else: + array = array_indexing_adapter[np_ind] + np.testing.assert_array_equal(expected, array) + + if not all(isinstance(k, indexing.integer_types) for k in np_ind.tuple): + combined_ind = indexing._combine_indexers(backend_ind, shape, np_ind) + assert isinstance(combined_ind, indexing.VectorizedIndexer) + array = indexing_adapter.vindex[combined_ind] + np.testing.assert_array_equal(expected, array) + + +def test_implicit_indexing_adapter() -> None: + array = np.arange(10, dtype=np.int64) + implicit = indexing.ImplicitToExplicitIndexingAdapter( + indexing.NumpyIndexingAdapter(array), indexing.BasicIndexer + ) + np.testing.assert_array_equal(array, np.asarray(implicit)) + np.testing.assert_array_equal(array, implicit[:]) + + +def test_implicit_indexing_adapter_copy_on_write() -> None: + array = np.arange(10, dtype=np.int64) + implicit = indexing.ImplicitToExplicitIndexingAdapter( + indexing.CopyOnWriteArray(array) + ) + assert isinstance(implicit[:], indexing.ImplicitToExplicitIndexingAdapter) + + +def test_outer_indexer_consistency_with_broadcast_indexes_vectorized() -> None: + def nonzero(x): + if isinstance(x, np.ndarray) and x.dtype.kind == "b": + x = x.nonzero()[0] + return x + + original = np.random.rand(10, 20, 30) + v = Variable(["i", "j", "k"], original) + arr = ReturnItem() + # test orthogonally applied indexers + indexers = [ + arr[:], + 0, + -2, + arr[:3], + np.array([0, 1, 2, 3]), + np.array([0]), + np.arange(10) < 5, + ] + for i, j, k in itertools.product(indexers, repeat=3): + if isinstance(j, np.ndarray) and j.dtype.kind == "b": # match size + j = np.arange(20) < 4 + if isinstance(k, np.ndarray) and k.dtype.kind == "b": + k = np.arange(30) < 8 + + _, expected, new_order = v._broadcast_indexes_vectorized((i, j, k)) + expected_data = nputils.NumpyVIndexAdapter(v.data)[expected.tuple] + if new_order: + old_order = range(len(new_order)) + expected_data = np.moveaxis(expected_data, old_order, new_order) + + outer_index = indexing.OuterIndexer((nonzero(i), nonzero(j), nonzero(k))) + actual = indexing._outer_to_numpy_indexer(outer_index, v.shape) + actual_data = v.data[actual] + np.testing.assert_array_equal(actual_data, expected_data) + + +def test_create_mask_outer_indexer() -> None: + indexer = indexing.OuterIndexer((np.array([0, -1, 2]),)) + expected = np.array([False, True, False]) + actual = indexing.create_mask(indexer, (5,)) + np.testing.assert_array_equal(expected, actual) + + indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]))) + expected = np.array(2 * [[False, True, False]]) + actual = indexing.create_mask(indexer, (5, 5, 5)) + np.testing.assert_array_equal(expected, actual) + + +def test_create_mask_vectorized_indexer() -> None: + indexer = indexing.VectorizedIndexer((np.array([0, -1, 2]), np.array([0, 1, -1]))) + expected = np.array([False, True, True]) + actual = indexing.create_mask(indexer, (5,)) + np.testing.assert_array_equal(expected, actual) + + indexer = indexing.VectorizedIndexer( + (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1])) + ) + expected = np.array([[False, True, True]] * 2).T + actual = indexing.create_mask(indexer, (5, 2)) + np.testing.assert_array_equal(expected, actual) + + +def test_create_mask_basic_indexer() -> None: + indexer = indexing.BasicIndexer((-1,)) + actual = indexing.create_mask(indexer, (3,)) + np.testing.assert_array_equal(True, actual) + + indexer = indexing.BasicIndexer((0,)) + actual = indexing.create_mask(indexer, (3,)) + np.testing.assert_array_equal(False, actual) + + +def test_create_mask_dask() -> None: + da = pytest.importorskip("dask.array") + + indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]))) + expected = np.array(2 * [[False, True, False]]) + actual = indexing.create_mask( + indexer, (5, 5, 5), da.empty((2, 3), chunks=((1, 1), (2, 1))) + ) + assert actual.chunks == ((1, 1), (2, 1)) + np.testing.assert_array_equal(expected, actual) + + indexer_vec = indexing.VectorizedIndexer( + (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1])) + ) + expected = np.array([[False, True, True]] * 2).T + actual = indexing.create_mask( + indexer_vec, (5, 2), da.empty((3, 2), chunks=((3,), (2,))) + ) + assert isinstance(actual, da.Array) + np.testing.assert_array_equal(expected, actual) + + with pytest.raises(ValueError): + indexing.create_mask(indexer_vec, (5, 2), da.empty((5,), chunks=(1,))) + + +def test_create_mask_error() -> None: + with pytest.raises(TypeError, match=r"unexpected key type"): + indexing.create_mask((1, 2), (3, 4)) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "indices, expected", + [ + (np.arange(5), np.arange(5)), + (np.array([0, -1, -1]), np.array([0, 0, 0])), + (np.array([-1, 1, -1]), np.array([1, 1, 1])), + (np.array([-1, -1, 2]), np.array([2, 2, 2])), + (np.array([-1]), np.array([0])), + (np.array([0, -1, 1, -1, -1]), np.array([0, 0, 1, 1, 1])), + (np.array([0, -1, -1, -1, 1]), np.array([0, 0, 0, 0, 1])), + ], +) +def test_posify_mask_subindexer(indices, expected) -> None: + actual = indexing._posify_mask_subindexer(indices) + np.testing.assert_array_equal(expected, actual) + + +def test_indexing_1d_object_array() -> None: + items = (np.arange(3), np.arange(6)) + arr = DataArray(np.array(items, dtype=object)) + + actual = arr[0] + + expected_data = np.empty((), dtype=object) + expected_data[()] = items[0] + expected = DataArray(expected_data) + + assert [actual.data.item()] == [expected.data.item()] + + +@requires_dask +def test_indexing_dask_array(): + import dask.array + + da = DataArray( + np.ones(10 * 3 * 3).reshape((10, 3, 3)), + dims=("time", "x", "y"), + ).chunk(dict(time=-1, x=1, y=1)) + with raise_if_dask_computes(): + actual = da.isel(time=dask.array.from_array([9], chunks=(1,))) + expected = da.isel(time=[9]) + assert_identical(actual, expected) + + +@requires_dask +def test_indexing_dask_array_scalar(): + # GH4276 + import dask.array + + a = dask.array.from_array(np.linspace(0.0, 1.0)) + da = DataArray(a, dims="x") + x_selector = da.argmax(dim=...) + with raise_if_dask_computes(): + actual = da.isel(x_selector) + expected = da.isel(x=-1) + assert_identical(actual, expected) + + +@requires_dask +def test_vectorized_indexing_dask_array(): + # https://github.com/pydata/xarray/issues/2511#issuecomment-563330352 + darr = DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",)) + indexer = DataArray( + data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int), + coords={"y": range(4), "x": range(2)}, + dims=("y", "x"), + ) + with pytest.raises(ValueError, match="Vectorized indexing with Dask arrays"): + darr[indexer.chunk({"y": 2})] + + +@requires_dask +def test_advanced_indexing_dask_array(): + # GH4663 + import dask.array as da + + ds = Dataset( + dict( + a=("x", da.from_array(np.random.randint(0, 100, 100))), + b=(("x", "y"), da.random.random((100, 10))), + ) + ) + expected = ds.b.sel(x=ds.a.compute()) + with raise_if_dask_computes(): + actual = ds.b.sel(x=ds.a) + assert_identical(expected, actual) + + with raise_if_dask_computes(): + actual = ds.b.sel(x=ds.a.data) + assert_identical(expected, actual) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_interp.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_interp.py new file mode 100644 index 0000000..7151c66 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_interp.py @@ -0,0 +1,946 @@ +from __future__ import annotations + +from itertools import combinations, permutations +from typing import cast + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.coding.cftimeindex import _parse_array_of_cftime_strings +from xarray.core.types import InterpOptions +from xarray.tests import ( + assert_allclose, + assert_equal, + assert_identical, + has_dask, + has_scipy, + requires_cftime, + requires_dask, + requires_scipy, +) +from xarray.tests.test_dataset import create_test_data + +try: + import scipy +except ImportError: + pass + + +def get_example_data(case: int) -> xr.DataArray: + if case == 0: + # 2D + x = np.linspace(0, 1, 100) + y = np.linspace(0, 0.1, 30) + return xr.DataArray( + np.sin(x[:, np.newaxis]) * np.cos(y), + dims=["x", "y"], + coords={"x": x, "y": y, "x2": ("x", x**2)}, + ) + elif case == 1: + # 2D chunked single dim + return get_example_data(0).chunk({"y": 3}) + elif case == 2: + # 2D chunked both dims + return get_example_data(0).chunk({"x": 25, "y": 3}) + elif case == 3: + # 3D + x = np.linspace(0, 1, 100) + y = np.linspace(0, 0.1, 30) + z = np.linspace(0.1, 0.2, 10) + return xr.DataArray( + np.sin(x[:, np.newaxis, np.newaxis]) * np.cos(y[:, np.newaxis]) * z, + dims=["x", "y", "z"], + coords={"x": x, "y": y, "x2": ("x", x**2), "z": z}, + ) + elif case == 4: + # 3D chunked single dim + return get_example_data(3).chunk({"z": 5}) + else: + raise ValueError("case must be 1-4") + + +def test_keywargs(): + if not has_scipy: + pytest.skip("scipy is not installed.") + + da = get_example_data(0) + assert_equal(da.interp(x=[0.5, 0.8]), da.interp({"x": [0.5, 0.8]})) + + +@pytest.mark.parametrize("method", ["linear", "cubic"]) +@pytest.mark.parametrize("dim", ["x", "y"]) +@pytest.mark.parametrize( + "case", [pytest.param(0, id="no_chunk"), pytest.param(1, id="chunk_y")] +) +def test_interpolate_1d(method: InterpOptions, dim: str, case: int) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask and case in [1]: + pytest.skip("dask is not installed in the environment.") + + da = get_example_data(case) + xdest = np.linspace(0.0, 0.9, 80) + actual = da.interp(method=method, coords={dim: xdest}) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + bounds_error=False, + fill_value=np.nan, + kind=method, + )(new_x) + + if dim == "x": + coords = {"x": xdest, "y": da["y"], "x2": ("x", func(da["x2"], xdest))} + else: # y + coords = {"x": da["x"], "y": xdest, "x2": da["x2"]} + + expected = xr.DataArray(func(da, xdest), dims=["x", "y"], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("method", ["cubic", "zero"]) +def test_interpolate_1d_methods(method: InterpOptions) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + da = get_example_data(0) + dim = "x" + xdest = np.linspace(0.0, 0.9, 80) + + actual = da.interp(method=method, coords={dim: xdest}) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + bounds_error=False, + fill_value=np.nan, + kind=method, + )(new_x) + + coords = {"x": xdest, "y": da["y"], "x2": ("x", func(da["x2"], xdest))} + expected = xr.DataArray(func(da, xdest), dims=["x", "y"], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("use_dask", [False, True]) +def test_interpolate_vectorize(use_dask: bool) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask and use_dask: + pytest.skip("dask is not installed in the environment.") + + # scipy interpolation for the reference + def func(obj, dim, new_x): + shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)] + for s in new_x.shape[::-1]: + shape.insert(obj.get_axis_num(dim), s) + + return scipy.interpolate.interp1d( + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + bounds_error=False, + fill_value=np.nan, + )(new_x).reshape(shape) + + da = get_example_data(0) + if use_dask: + da = da.chunk({"y": 5}) + + # xdest is 1d but has different dimension + xdest = xr.DataArray( + np.linspace(0.1, 0.9, 30), + dims="z", + coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))}, + ) + + actual = da.interp(x=xdest, method="linear") + + expected = xr.DataArray( + func(da, "x", xdest), + dims=["z", "y"], + coords={ + "z": xdest["z"], + "z2": xdest["z2"], + "y": da["y"], + "x": ("z", xdest.values), + "x2": ("z", func(da["x2"], "x", xdest)), + }, + ) + assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True)) + + # xdest is 2d + xdest = xr.DataArray( + np.linspace(0.1, 0.9, 30).reshape(6, 5), + dims=["z", "w"], + coords={ + "z": np.random.randn(6), + "w": np.random.randn(5), + "z2": ("z", np.random.randn(6)), + }, + ) + + actual = da.interp(x=xdest, method="linear") + + expected = xr.DataArray( + func(da, "x", xdest), + dims=["z", "w", "y"], + coords={ + "z": xdest["z"], + "w": xdest["w"], + "z2": xdest["z2"], + "y": da["y"], + "x": (("z", "w"), xdest.data), + "x2": (("z", "w"), func(da["x2"], "x", xdest)), + }, + ) + assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True)) + + +@pytest.mark.parametrize( + "case", [pytest.param(3, id="no_chunk"), pytest.param(4, id="chunked")] +) +def test_interpolate_nd(case: int) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask and case == 4: + pytest.skip("dask is not installed in the environment.") + + da = get_example_data(case) + + # grid -> grid + xdestnp = np.linspace(0.1, 1.0, 11) + ydestnp = np.linspace(0.0, 0.2, 10) + actual = da.interp(x=xdestnp, y=ydestnp, method="linear") + + # linear interpolation is separateable + expected = da.interp(x=xdestnp, method="linear") + expected = expected.interp(y=ydestnp, method="linear") + assert_allclose(actual.transpose("x", "y", "z"), expected.transpose("x", "y", "z")) + + # grid -> 1d-sample + xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims="y") + ydest = xr.DataArray(np.linspace(0.0, 0.2, 11), dims="y") + actual = da.interp(x=xdest, y=ydest, method="linear") + + # linear interpolation is separateable + expected_data = scipy.interpolate.RegularGridInterpolator( + (da["x"], da["y"]), + da.transpose("x", "y", "z").values, + method="linear", + bounds_error=False, + fill_value=np.nan, + )(np.stack([xdest, ydest], axis=-1)) + expected = xr.DataArray( + expected_data, + dims=["y", "z"], + coords={ + "z": da["z"], + "y": ydest, + "x": ("y", xdest.values), + "x2": da["x2"].interp(x=xdest), + }, + ) + assert_allclose(actual.transpose("y", "z"), expected) + + # reversed order + actual = da.interp(y=ydest, x=xdest, method="linear") + assert_allclose(actual.transpose("y", "z"), expected) + + +@requires_scipy +def test_interpolate_nd_nd() -> None: + """Interpolate nd array with an nd indexer sharing coordinates.""" + # Create original array + a = [0, 2] + x = [0, 1, 2] + da = xr.DataArray( + np.arange(6).reshape(2, 3), dims=("a", "x"), coords={"a": a, "x": x} + ) + + # Create indexer into `a` with dimensions (y, x) + y = [10] + c = {"x": x, "y": y} + ia = xr.DataArray([[1, 2, 2]], dims=("y", "x"), coords=c) + out = da.interp(a=ia) + expected = xr.DataArray([[1.5, 4, 5]], dims=("y", "x"), coords=c) + xr.testing.assert_allclose(out.drop_vars("a"), expected) + + # If the *shared* indexing coordinates do not match, interp should fail. + with pytest.raises(ValueError): + c = {"x": [1], "y": y} + ia = xr.DataArray([[1]], dims=("y", "x"), coords=c) + da.interp(a=ia) + + with pytest.raises(ValueError): + c = {"x": [5, 6, 7], "y": y} + ia = xr.DataArray([[1]], dims=("y", "x"), coords=c) + da.interp(a=ia) + + +@requires_scipy +def test_interpolate_nd_with_nan() -> None: + """Interpolate an array with an nd indexer and `NaN` values.""" + + # Create indexer into `a` with dimensions (y, x) + x = [0, 1, 2] + y = [10, 20] + c = {"x": x, "y": y} + a = np.arange(6, dtype=float).reshape(2, 3) + a[0, 1] = np.nan + ia = xr.DataArray(a, dims=("y", "x"), coords=c) + + da = xr.DataArray([1, 2, 2], dims=("a"), coords={"a": [0, 2, 4]}) + out = da.interp(a=ia) + expected = xr.DataArray( + [[1.0, np.nan, 2.0], [2.0, 2.0, np.nan]], dims=("y", "x"), coords=c + ) + xr.testing.assert_allclose(out.drop_vars("a"), expected) + + db = 2 * da + ds = xr.Dataset({"da": da, "db": db}) + out2 = ds.interp(a=ia) + expected_ds = xr.Dataset({"da": expected, "db": 2 * expected}) + xr.testing.assert_allclose(out2.drop_vars("a"), expected_ds) + + +@pytest.mark.parametrize("method", ["linear"]) +@pytest.mark.parametrize( + "case", [pytest.param(0, id="no_chunk"), pytest.param(1, id="chunk_y")] +) +def test_interpolate_scalar(method: InterpOptions, case: int) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask and case in [1]: + pytest.skip("dask is not installed in the environment.") + + da = get_example_data(case) + xdest = 0.4 + + actual = da.interp(x=xdest, method=method) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da["x"], + obj.data, + axis=obj.get_axis_num("x"), + bounds_error=False, + fill_value=np.nan, + )(new_x) + + coords = {"x": xdest, "y": da["y"], "x2": func(da["x2"], xdest)} + expected = xr.DataArray(func(da, xdest), dims=["y"], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("method", ["linear"]) +@pytest.mark.parametrize( + "case", [pytest.param(3, id="no_chunk"), pytest.param(4, id="chunked")] +) +def test_interpolate_nd_scalar(method: InterpOptions, case: int) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + if not has_dask and case in [4]: + pytest.skip("dask is not installed in the environment.") + + da = get_example_data(case) + xdest = 0.4 + ydest = 0.05 + + actual = da.interp(x=xdest, y=ydest, method=method) + # scipy interpolation for the reference + expected_data = scipy.interpolate.RegularGridInterpolator( + (da["x"], da["y"]), + da.transpose("x", "y", "z").values, + method="linear", + bounds_error=False, + fill_value=np.nan, + )(np.stack([xdest, ydest], axis=-1)) + + coords = {"x": xdest, "y": ydest, "x2": da["x2"].interp(x=xdest), "z": da["z"]} + expected = xr.DataArray(expected_data[0], dims=["z"], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("use_dask", [True, False]) +def test_nans(use_dask: bool) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)}) + + if not has_dask and use_dask: + pytest.skip("dask is not installed in the environment.") + da = da.chunk() + + actual = da.interp(x=[0.5, 1.5]) + # not all values are nan + assert actual.count() > 0 + + +@pytest.mark.parametrize("use_dask", [True, False]) +def test_errors(use_dask: bool) -> None: + if not has_scipy: + pytest.skip("scipy is not installed.") + + # akima and spline are unavailable + da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)}) + if not has_dask and use_dask: + pytest.skip("dask is not installed in the environment.") + da = da.chunk() + + for method in ["akima", "spline"]: + with pytest.raises(ValueError): + da.interp(x=[0.5, 1.5], method=method) # type: ignore + + # not sorted + if use_dask: + da = get_example_data(3) + else: + da = get_example_data(0) + + result = da.interp(x=[-1, 1, 3], kwargs={"fill_value": 0.0}) + assert not np.isnan(result.values).any() + result = da.interp(x=[-1, 1, 3]) + assert np.isnan(result.values).any() + + # invalid method + with pytest.raises(ValueError): + da.interp(x=[2, 0], method="boo") # type: ignore + with pytest.raises(ValueError): + da.interp(y=[2, 0], method="boo") # type: ignore + + # object-type DataArray cannot be interpolated + da = xr.DataArray(["a", "b", "c"], dims="x", coords={"x": [0, 1, 2]}) + with pytest.raises(TypeError): + da.interp(x=0) + + +@requires_scipy +def test_dtype() -> None: + data_vars = dict( + a=("time", np.array([1, 1.25, 2])), + b=("time", np.array([True, True, False], dtype=bool)), + c=("time", np.array(["start", "start", "end"], dtype=str)), + ) + time = np.array([0, 0.25, 1], dtype=float) + expected = xr.Dataset(data_vars, coords=dict(time=time)) + actual = xr.Dataset( + {k: (dim, arr[[0, -1]]) for k, (dim, arr) in data_vars.items()}, + coords=dict(time=time[[0, -1]]), + ) + actual = actual.interp(time=time, method="linear") + assert_identical(expected, actual) + + +@requires_scipy +def test_sorted() -> None: + # unsorted non-uniform gridded data + x = np.random.randn(100) + y = np.random.randn(30) + z = np.linspace(0.1, 0.2, 10) * 3.0 + da = xr.DataArray( + np.cos(x[:, np.newaxis, np.newaxis]) * np.cos(y[:, np.newaxis]) * z, + dims=["x", "y", "z"], + coords={"x": x, "y": y, "x2": ("x", x**2), "z": z}, + ) + + x_new = np.linspace(0, 1, 30) + y_new = np.linspace(0, 1, 20) + + da_sorted = da.sortby("x") + assert_allclose(da.interp(x=x_new), da_sorted.interp(x=x_new, assume_sorted=True)) + da_sorted = da.sortby(["x", "y"]) + assert_allclose( + da.interp(x=x_new, y=y_new), + da_sorted.interp(x=x_new, y=y_new, assume_sorted=True), + ) + + with pytest.raises(ValueError): + da.interp(x=[0, 1, 2], assume_sorted=True) + + +@requires_scipy +def test_dimension_wo_coords() -> None: + da = xr.DataArray( + np.arange(12).reshape(3, 4), dims=["x", "y"], coords={"y": [0, 1, 2, 3]} + ) + da_w_coord = da.copy() + da_w_coord["x"] = np.arange(3) + + assert_equal(da.interp(x=[0.1, 0.2, 0.3]), da_w_coord.interp(x=[0.1, 0.2, 0.3])) + assert_equal( + da.interp(x=[0.1, 0.2, 0.3], y=[0.5]), + da_w_coord.interp(x=[0.1, 0.2, 0.3], y=[0.5]), + ) + + +@requires_scipy +def test_dataset() -> None: + ds = create_test_data() + ds.attrs["foo"] = "var" + ds["var1"].attrs["buz"] = "var2" + new_dim2 = xr.DataArray([0.11, 0.21, 0.31], dims="z") + interpolated = ds.interp(dim2=new_dim2) + + assert_allclose(interpolated["var1"], ds["var1"].interp(dim2=new_dim2)) + assert interpolated["var3"].equals(ds["var3"]) + + # make sure modifying interpolated does not affect the original dataset + interpolated["var1"][:, 1] = 1.0 + interpolated["var2"][:, 1] = 1.0 + interpolated["var3"][:, 1] = 1.0 + + assert not interpolated["var1"].equals(ds["var1"]) + assert not interpolated["var2"].equals(ds["var2"]) + assert not interpolated["var3"].equals(ds["var3"]) + # attrs should be kept + assert interpolated.attrs["foo"] == "var" + assert interpolated["var1"].attrs["buz"] == "var2" + + +@pytest.mark.parametrize("case", [pytest.param(0, id="2D"), pytest.param(3, id="3D")]) +def test_interpolate_dimorder(case: int) -> None: + """Make sure the resultant dimension order is consistent with .sel()""" + if not has_scipy: + pytest.skip("scipy is not installed.") + + da = get_example_data(case) + + new_x = xr.DataArray([0, 1, 2], dims="x") + assert da.interp(x=new_x).dims == da.sel(x=new_x, method="nearest").dims + + new_y = xr.DataArray([0, 1, 2], dims="y") + actual = da.interp(x=new_x, y=new_y).dims + expected = da.sel(x=new_x, y=new_y, method="nearest").dims + assert actual == expected + # reversed order + actual = da.interp(y=new_y, x=new_x).dims + expected = da.sel(y=new_y, x=new_x, method="nearest").dims + assert actual == expected + + new_x = xr.DataArray([0, 1, 2], dims="a") + assert da.interp(x=new_x).dims == da.sel(x=new_x, method="nearest").dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method="nearest").dims + new_y = xr.DataArray([0, 1, 2], dims="a") + actual = da.interp(x=new_x, y=new_y).dims + expected = da.sel(x=new_x, y=new_y, method="nearest").dims + assert actual == expected + + new_x = xr.DataArray([[0], [1], [2]], dims=["a", "b"]) + assert da.interp(x=new_x).dims == da.sel(x=new_x, method="nearest").dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method="nearest").dims + + if case == 3: + new_x = xr.DataArray([[0], [1], [2]], dims=["a", "b"]) + new_z = xr.DataArray([[0], [1], [2]], dims=["a", "b"]) + actual = da.interp(x=new_x, z=new_z).dims + expected = da.sel(x=new_x, z=new_z, method="nearest").dims + assert actual == expected + + actual = da.interp(z=new_z, x=new_x).dims + expected = da.sel(z=new_z, x=new_x, method="nearest").dims + assert actual == expected + + actual = da.interp(x=0.5, z=new_z).dims + expected = da.sel(x=0.5, z=new_z, method="nearest").dims + assert actual == expected + + +@requires_scipy +def test_interp_like() -> None: + ds = create_test_data() + ds.attrs["foo"] = "var" + ds["var1"].attrs["buz"] = "var2" + + other = xr.DataArray(np.random.randn(3), dims=["dim2"], coords={"dim2": [0, 1, 2]}) + interpolated = ds.interp_like(other) + + assert_allclose(interpolated["var1"], ds["var1"].interp(dim2=other["dim2"])) + assert_allclose(interpolated["var1"], ds["var1"].interp_like(other)) + assert interpolated["var3"].equals(ds["var3"]) + + # attrs should be kept + assert interpolated.attrs["foo"] == "var" + assert interpolated["var1"].attrs["buz"] == "var2" + + other = xr.DataArray( + np.random.randn(3), dims=["dim3"], coords={"dim3": ["a", "b", "c"]} + ) + + actual = ds.interp_like(other) + expected = ds.reindex_like(other) + assert_allclose(actual, expected) + + +@requires_scipy +@pytest.mark.parametrize( + "x_new, expected", + [ + (pd.date_range("2000-01-02", periods=3), [1, 2, 3]), + ( + np.array( + [np.datetime64("2000-01-01T12:00"), np.datetime64("2000-01-02T12:00")] + ), + [0.5, 1.5], + ), + (["2000-01-01T12:00", "2000-01-02T12:00"], [0.5, 1.5]), + (["2000-01-01T12:00", "2000-01-02T12:00", "NaT"], [0.5, 1.5, np.nan]), + (["2000-01-01T12:00"], 0.5), + pytest.param("2000-01-01T12:00", 0.5, marks=pytest.mark.xfail), + ], +) +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") +def test_datetime(x_new, expected) -> None: + da = xr.DataArray( + np.arange(24), + dims="time", + coords={"time": pd.date_range("2000-01-01", periods=24)}, + ) + + actual = da.interp(time=x_new) + expected_da = xr.DataArray( + np.atleast_1d(expected), + dims=["time"], + coords={"time": (np.atleast_1d(x_new).astype("datetime64[ns]"))}, + ) + + assert_allclose(actual, expected_da) + + +@requires_scipy +def test_datetime_single_string() -> None: + da = xr.DataArray( + np.arange(24), + dims="time", + coords={"time": pd.date_range("2000-01-01", periods=24)}, + ) + actual = da.interp(time="2000-01-01T12:00") + expected = xr.DataArray(0.5) + + assert_allclose(actual.drop_vars("time"), expected) + + +@requires_cftime +@requires_scipy +def test_cftime() -> None: + times = xr.cftime_range("2000", periods=24, freq="D") + da = xr.DataArray(np.arange(24), coords=[times], dims="time") + + times_new = xr.cftime_range("2000-01-01T12:00:00", periods=3, freq="D") + actual = da.interp(time=times_new) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=["time"]) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_type_error() -> None: + times = xr.cftime_range("2000", periods=24, freq="D") + da = xr.DataArray(np.arange(24), coords=[times], dims="time") + + times_new = xr.cftime_range( + "2000-01-01T12:00:00", periods=3, freq="D", calendar="noleap" + ) + with pytest.raises(TypeError): + da.interp(time=times_new) + + +@requires_cftime +@requires_scipy +def test_cftime_list_of_strings() -> None: + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range( + "2000", periods=24, freq="D", calendar="proleptic_gregorian" + ) + da = xr.DataArray(np.arange(24), coords=[times], dims="time") + + times_new = ["2000-01-01T12:00", "2000-01-02T12:00", "2000-01-03T12:00"] + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian + ) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], dims=["time"]) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_single_string() -> None: + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range( + "2000", periods=24, freq="D", calendar="proleptic_gregorian" + ) + da = xr.DataArray(np.arange(24), coords=[times], dims="time") + + times_new = "2000-01-01T12:00" + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian + ) + expected = xr.DataArray(0.5, coords={"time": times_new_array}) + + assert_allclose(actual, expected) + + +@requires_scipy +def test_datetime_to_non_datetime_error() -> None: + da = xr.DataArray( + np.arange(24), + dims="time", + coords={"time": pd.date_range("2000-01-01", periods=24)}, + ) + with pytest.raises(TypeError): + da.interp(time=0.5) + + +@requires_cftime +@requires_scipy +def test_cftime_to_non_cftime_error() -> None: + times = xr.cftime_range("2000", periods=24, freq="D") + da = xr.DataArray(np.arange(24), coords=[times], dims="time") + + with pytest.raises(TypeError): + da.interp(time=0.5) + + +@requires_scipy +def test_datetime_interp_noerror() -> None: + # GH:2667 + a = xr.DataArray( + np.arange(21).reshape(3, 7), + dims=["x", "time"], + coords={ + "x": [1, 2, 3], + "time": pd.date_range("01-01-2001", periods=7, freq="D"), + }, + ) + xi = xr.DataArray( + np.linspace(1, 3, 50), + dims=["time"], + coords={"time": pd.date_range("01-01-2001", periods=50, freq="h")}, + ) + a.interp(x=xi, time=xi.time) # should not raise an error + + +@requires_cftime +@requires_scipy +def test_3641() -> None: + times = xr.cftime_range("0001", periods=3, freq="500YE") + da = xr.DataArray(range(3), dims=["time"], coords=[times]) + da.interp(time=["0002-05-01"]) + + +@requires_scipy +@pytest.mark.parametrize("method", ["nearest", "linear"]) +def test_decompose(method: InterpOptions) -> None: + da = xr.DataArray( + np.arange(6).reshape(3, 2), + dims=["x", "y"], + coords={"x": [0, 1, 2], "y": [-0.1, -0.3]}, + ) + x_new = xr.DataArray([0.5, 1.5, 2.5], dims=["x1"]) + y_new = xr.DataArray([-0.15, -0.25], dims=["y1"]) + x_broadcast, y_broadcast = xr.broadcast(x_new, y_new) + assert x_broadcast.ndim == 2 + + actual = da.interp(x=x_new, y=y_new, method=method).drop_vars(("x", "y")) + expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop_vars( + ("x", "y") + ) + assert_allclose(actual, expected) + + +@requires_scipy +@requires_dask +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) +@pytest.mark.parametrize("chunked", [True, False]) +@pytest.mark.parametrize( + "data_ndim,interp_ndim,nscalar", + [ + (data_ndim, interp_ndim, nscalar) + for data_ndim in range(1, 4) + for interp_ndim in range(1, data_ndim + 1) + for nscalar in range(0, interp_ndim + 1) + ], +) +def test_interpolate_chunk_1d( + method: InterpOptions, data_ndim, interp_ndim, nscalar, chunked: bool +) -> None: + """Interpolate nd array with multiple independent indexers + + It should do a series of 1d interpolation + """ + + # 3d non chunked data + x = np.linspace(0, 1, 5) + y = np.linspace(2, 4, 7) + z = np.linspace(-0.5, 0.5, 11) + da = xr.DataArray( + data=np.sin(x[:, np.newaxis, np.newaxis]) + * np.cos(y[:, np.newaxis]) + * np.exp(z), + coords=[("x", x), ("y", y), ("z", z)], + ) + kwargs = {"fill_value": "extrapolate"} + + # choose the data dimensions + for data_dims in permutations(da.dims, data_ndim): + # select only data_ndim dim + da = da.isel( # take the middle line + {dim: len(da.coords[dim]) // 2 for dim in da.dims if dim not in data_dims} + ) + + # chunk data + da = da.chunk(chunks={dim: i + 1 for i, dim in enumerate(da.dims)}) + + # choose the interpolation dimensions + for interp_dims in permutations(da.dims, interp_ndim): + # choose the scalar interpolation dimensions + for scalar_dims in combinations(interp_dims, nscalar): + dest = {} + for dim in interp_dims: + if dim in scalar_dims: + # take the middle point + dest[dim] = 0.5 * (da.coords[dim][0] + da.coords[dim][-1]) + else: + # pick some points, including outside the domain + before = 2 * da.coords[dim][0] - da.coords[dim][1] + after = 2 * da.coords[dim][-1] - da.coords[dim][-2] + + dest[dim] = cast( + xr.DataArray, + np.linspace( + before.item(), after.item(), len(da.coords[dim]) * 13 + ), + ) + if chunked: + dest[dim] = xr.DataArray(data=dest[dim], dims=[dim]) + dest[dim] = dest[dim].chunk(2) + actual = da.interp(method=method, **dest, kwargs=kwargs) + expected = da.compute().interp(method=method, **dest, kwargs=kwargs) + + assert_identical(actual, expected) + + # all the combinations are usually not necessary + break + break + break + + +@requires_scipy +@requires_dask +@pytest.mark.parametrize("method", ["linear", "nearest"]) +@pytest.mark.filterwarnings("ignore:Increasing number of chunks") +def test_interpolate_chunk_advanced(method: InterpOptions) -> None: + """Interpolate nd array with an nd indexer sharing coordinates.""" + # Create original array + x = np.linspace(-1, 1, 5) + y = np.linspace(-1, 1, 7) + z = np.linspace(-1, 1, 11) + t = np.linspace(0, 1, 13) + q = np.linspace(0, 1, 17) + da = xr.DataArray( + data=np.sin(x[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]) + * np.cos(y[:, np.newaxis, np.newaxis, np.newaxis]) + * np.exp(z[:, np.newaxis, np.newaxis]) + * t[:, np.newaxis] + + q, + dims=("x", "y", "z", "t", "q"), + coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "dummy_attr"}, + ) + + # Create indexer into `da` with shared coordinate ("full-twist" Möbius strip) + theta = np.linspace(0, 2 * np.pi, 5) + w = np.linspace(-0.25, 0.25, 7) + r = xr.DataArray( + data=1 + w[:, np.newaxis] * np.cos(theta), + coords=[("w", w), ("theta", theta)], + ) + + xda = r * np.cos(theta) + yda = r * np.sin(theta) + zda = xr.DataArray( + data=w[:, np.newaxis] * np.sin(theta), + coords=[("w", w), ("theta", theta)], + ) + + kwargs = {"fill_value": None} + expected = da.interp(t=0.5, x=xda, y=yda, z=zda, kwargs=kwargs, method=method) + + da = da.chunk(2) + xda = xda.chunk(1) + zda = zda.chunk(3) + actual = da.interp(t=0.5, x=xda, y=yda, z=zda, kwargs=kwargs, method=method) + assert_identical(actual, expected) + + +@requires_scipy +def test_interp1d_bounds_error() -> None: + """Ensure exception on bounds error is raised if requested""" + da = xr.DataArray( + np.sin(0.3 * np.arange(4)), + [("time", np.arange(4))], + ) + + with pytest.raises(ValueError): + da.interp(time=3.5, kwargs=dict(bounds_error=True)) + + # default is to fill with nans, so this should pass + da.interp(time=3.5) + + +@requires_scipy +@pytest.mark.parametrize( + "x, expect_same_attrs", + [ + (2.5, True), + (np.array([2.5, 5]), True), + (("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False), + ], +) +def test_coord_attrs(x, expect_same_attrs: bool) -> None: + base_attrs = dict(foo="bar") + ds = xr.Dataset( + data_vars=dict(a=2 * np.arange(5)), + coords={"x": ("x", np.arange(5), base_attrs)}, + ) + + has_same_attrs = ds.interp(x=x).x.attrs == base_attrs + assert expect_same_attrs == has_same_attrs + + +@requires_scipy +def test_interp1d_complex_out_of_bounds() -> None: + """Ensure complex nans are used by default""" + da = xr.DataArray( + np.exp(0.3j * np.arange(4)), + [("time", np.arange(4))], + ) + + expected = da.interp(time=3.5, kwargs=dict(fill_value=np.nan + np.nan * 1j)) + actual = da.interp(time=3.5) + assert_identical(actual, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_merge.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_merge.py new file mode 100644 index 0000000..52935e9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_merge.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import xarray as xr +from xarray.core import dtypes, merge +from xarray.core.merge import MergeError +from xarray.testing import assert_equal, assert_identical +from xarray.tests.test_dataset import create_test_data + + +class TestMergeInternals: + def test_broadcast_dimension_size(self): + actual = merge.broadcast_dimension_size( + [xr.Variable("x", [1]), xr.Variable("y", [2, 1])] + ) + assert actual == {"x": 1, "y": 2} + + actual = merge.broadcast_dimension_size( + [xr.Variable(("x", "y"), [[1, 2]]), xr.Variable("y", [2, 1])] + ) + assert actual == {"x": 1, "y": 2} + + with pytest.raises(ValueError): + merge.broadcast_dimension_size( + [xr.Variable(("x", "y"), [[1, 2]]), xr.Variable("y", [2])] + ) + + +class TestMergeFunction: + def test_merge_arrays(self): + data = create_test_data(add_attrs=False) + + actual = xr.merge([data.var1, data.var2]) + expected = data[["var1", "var2"]] + assert_identical(actual, expected) + + def test_merge_datasets(self): + data = create_test_data(add_attrs=False, use_extension_array=True) + + actual = xr.merge([data[["var1"]], data[["var2"]]]) + expected = data[["var1", "var2"]] + assert_identical(actual, expected) + + actual = xr.merge([data, data]) + assert_identical(actual, data) + + def test_merge_dataarray_unnamed(self): + data = xr.DataArray([1, 2], dims="x") + with pytest.raises(ValueError, match=r"without providing an explicit name"): + xr.merge([data]) + + def test_merge_arrays_attrs_default(self): + var1_attrs = {"a": 1, "b": 2} + var2_attrs = {"a": 1, "c": 3} + expected_attrs = {"a": 1, "b": 2} + + data = create_test_data(add_attrs=False) + expected = data[["var1", "var2"]].copy() + expected.var1.attrs = var1_attrs + expected.var2.attrs = var2_attrs + expected.attrs = expected_attrs + + data.var1.attrs = var1_attrs + data.var2.attrs = var2_attrs + actual = xr.merge([data.var1, data.var2]) + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": np.array([2]), "c": np.array([3])}, + {"b": 1, "c": np.array([3]), "d": 4}, + {"a": 1, "c": np.array([3]), "d": 4}, + False, + ), + ( + lambda attrs, context: attrs[1], + {"a": 1, "b": 2, "c": 3}, + {"a": 4, "b": 3, "c": 1}, + {"a": 4, "b": 3, "c": 1}, + False, + ), + ], + ) + def test_merge_arrays_attrs( + self, combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception + ): + data1 = xr.Dataset(attrs=var1_attrs) + data2 = xr.Dataset(attrs=var2_attrs) + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + else: + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + expected = xr.Dataset(attrs=expected_attrs) + + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ( + lambda attrs, context: attrs[1], + {"a": 1, "b": 2, "c": 3}, + {"a": 4, "b": 3, "c": 1}, + {"a": 4, "b": 3, "c": 1}, + False, + ), + ], + ) + def test_merge_arrays_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = xr.Dataset( + {"var1": ("dim1", [], attrs1)}, coords={"dim1": ("dim1", [], attrs1)} + ) + data2 = xr.Dataset( + {"var1": ("dim1", [], attrs2)}, coords={"dim1": ("dim1", [], attrs2)} + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + else: + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + expected = xr.Dataset( + {"var1": ("dim1", [], expected_attrs)}, + coords={"dim1": ("dim1", [], expected_attrs)}, + ) + + assert_identical(actual, expected) + + def test_merge_attrs_override_copy(self): + ds1 = xr.Dataset(attrs={"x": 0}) + ds2 = xr.Dataset(attrs={"x": 1}) + ds3 = xr.merge([ds1, ds2], combine_attrs="override") + ds3.attrs["x"] = 2 + assert ds1.x == 0 + + def test_merge_attrs_drop_conflicts(self): + ds1 = xr.Dataset(attrs={"a": 0, "b": 0, "c": 0}) + ds2 = xr.Dataset(attrs={"b": 0, "c": 1, "d": 0}) + ds3 = xr.Dataset(attrs={"a": 0, "b": 1, "c": 0, "e": 0}) + + actual = xr.merge([ds1, ds2, ds3], combine_attrs="drop_conflicts") + expected = xr.Dataset(attrs={"a": 0, "d": 0, "e": 0}) + assert_identical(actual, expected) + + def test_merge_attrs_no_conflicts_compat_minimal(self): + """make sure compat="minimal" does not silence errors""" + ds1 = xr.Dataset({"a": ("x", [], {"a": 0})}) + ds2 = xr.Dataset({"a": ("x", [], {"a": 1})}) + + with pytest.raises(xr.MergeError, match="combine_attrs"): + xr.merge([ds1, ds2], combine_attrs="no_conflicts", compat="minimal") + + def test_merge_dicts_simple(self): + actual = xr.merge([{"foo": 0}, {"bar": "one"}, {"baz": 3.5}]) + expected = xr.Dataset({"foo": 0, "bar": "one", "baz": 3.5}) + assert_identical(actual, expected) + + def test_merge_dicts_dims(self): + actual = xr.merge([{"y": ("x", [13])}, {"x": [12]}]) + expected = xr.Dataset({"x": [12], "y": ("x", [13])}) + assert_identical(actual, expected) + + def test_merge_coordinates(self): + coords1 = xr.Coordinates({"x": ("x", [0, 1, 2])}) + coords2 = xr.Coordinates({"y": ("y", [3, 4, 5])}) + expected = xr.Dataset(coords={"x": [0, 1, 2], "y": [3, 4, 5]}) + actual = xr.merge([coords1, coords2]) + assert_identical(actual, expected) + + def test_merge_error(self): + ds = xr.Dataset({"x": 0}) + with pytest.raises(xr.MergeError): + xr.merge([ds, ds + 1]) + + def test_merge_alignment_error(self): + ds = xr.Dataset(coords={"x": [1, 2]}) + other = xr.Dataset(coords={"x": [2, 3]}) + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): + xr.merge([ds, other], join="exact") + + def test_merge_wrong_input_error(self): + with pytest.raises(TypeError, match=r"objects must be an iterable"): + xr.merge([1]) + ds = xr.Dataset(coords={"x": [1, 2]}) + with pytest.raises(TypeError, match=r"objects must be an iterable"): + xr.merge({"a": ds}) + with pytest.raises(TypeError, match=r"objects must be an iterable"): + xr.merge([ds, 1]) + + def test_merge_no_conflicts_single_var(self): + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"a": ("x", [2, 3]), "x": [1, 2]}) + expected = xr.Dataset({"a": ("x", [1, 2, 3]), "x": [0, 1, 2]}) + assert expected.identical(xr.merge([ds1, ds2], compat="no_conflicts")) + assert expected.identical(xr.merge([ds2, ds1], compat="no_conflicts")) + assert ds1.identical(xr.merge([ds1, ds2], compat="no_conflicts", join="left")) + assert ds2.identical(xr.merge([ds1, ds2], compat="no_conflicts", join="right")) + expected = xr.Dataset({"a": ("x", [2]), "x": [1]}) + assert expected.identical( + xr.merge([ds1, ds2], compat="no_conflicts", join="inner") + ) + + with pytest.raises(xr.MergeError): + ds3 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) + xr.merge([ds1, ds3], compat="no_conflicts") + + with pytest.raises(xr.MergeError): + ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) + xr.merge([ds1, ds3], compat="no_conflicts") + + def test_merge_no_conflicts_multi_var(self): + data = create_test_data(add_attrs=False) + data1 = data.copy(deep=True) + data2 = data.copy(deep=True) + + expected = data[["var1", "var2"]] + actual = xr.merge([data1.var1, data2.var2], compat="no_conflicts") + assert_identical(expected, actual) + + data1["var1"][:, :5] = np.nan + data2["var1"][:, 5:] = np.nan + data1["var2"][:4, :] = np.nan + data2["var2"][4:, :] = np.nan + del data2["var3"] + + actual = xr.merge([data1, data2], compat="no_conflicts") + assert_equal(data, actual) + + def test_merge_no_conflicts_preserve_attrs(self): + data = xr.Dataset({"x": ([], 0, {"foo": "bar"})}) + actual = xr.merge([data, data], combine_attrs="no_conflicts") + assert_identical(data, actual) + + def test_merge_no_conflicts_broadcast(self): + datasets = [xr.Dataset({"x": ("y", [0])}), xr.Dataset({"x": np.nan})] + actual = xr.merge(datasets) + expected = xr.Dataset({"x": ("y", [0])}) + assert_identical(expected, actual) + + datasets = [xr.Dataset({"x": ("y", [np.nan])}), xr.Dataset({"x": 0})] + actual = xr.merge(datasets) + assert_identical(expected, actual) + + +class TestMergeMethod: + def test_merge(self): + data = create_test_data() + ds1 = data[["var1"]] + ds2 = data[["var3"]] + expected = data[["var1", "var3"]] + actual = ds1.merge(ds2) + assert_identical(expected, actual) + + actual = ds2.merge(ds1) + assert_identical(expected, actual) + + actual = data.merge(data) + assert_identical(data, actual) + actual = data.reset_coords(drop=True).merge(data) + assert_identical(data, actual) + actual = data.merge(data.reset_coords(drop=True)) + assert_identical(data, actual) + + with pytest.raises(ValueError): + ds1.merge(ds2.rename({"var3": "var1"})) + with pytest.raises(ValueError, match=r"should be coordinates or not"): + data.reset_coords().merge(data) + with pytest.raises(ValueError, match=r"should be coordinates or not"): + data.merge(data.reset_coords()) + + def test_merge_broadcast_equals(self): + ds1 = xr.Dataset({"x": 0}) + ds2 = xr.Dataset({"x": ("y", [0, 0])}) + actual = ds1.merge(ds2) + assert_identical(ds2, actual) + + actual = ds2.merge(ds1) + assert_identical(ds2, actual) + + actual = ds1.copy() + actual.update(ds2) + assert_identical(ds2, actual) + + ds1 = xr.Dataset({"x": np.nan}) + ds2 = xr.Dataset({"x": ("y", [np.nan, np.nan])}) + actual = ds1.merge(ds2) + assert_identical(ds2, actual) + + def test_merge_compat(self): + ds1 = xr.Dataset({"x": 0}) + ds2 = xr.Dataset({"x": 1}) + for compat in ["broadcast_equals", "equals", "identical", "no_conflicts"]: + with pytest.raises(xr.MergeError): + ds1.merge(ds2, compat=compat) + + ds2 = xr.Dataset({"x": [0, 0]}) + for compat in ["equals", "identical"]: + with pytest.raises(ValueError, match=r"should be coordinates or not"): + ds1.merge(ds2, compat=compat) + + ds2 = xr.Dataset({"x": ((), 0, {"foo": "bar"})}) + with pytest.raises(xr.MergeError): + ds1.merge(ds2, compat="identical") + + with pytest.raises(ValueError, match=r"compat=.* invalid"): + ds1.merge(ds2, compat="foobar") + + assert ds1.identical(ds1.merge(ds2, compat="override")) + + def test_merge_compat_minimal(self) -> None: + # https://github.com/pydata/xarray/issues/7405 + # https://github.com/pydata/xarray/issues/7588 + ds1 = xr.Dataset(coords={"foo": [1, 2, 3], "bar": 4}) + ds2 = xr.Dataset(coords={"foo": [1, 2, 3], "bar": 5}) + + actual = xr.merge([ds1, ds2], compat="minimal") + expected = xr.Dataset(coords={"foo": [1, 2, 3]}) + assert_identical(actual, expected) + + def test_merge_auto_align(self): + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]}) + expected = xr.Dataset( + {"a": ("x", [1, 2, np.nan]), "b": ("x", [np.nan, 3, 4])}, {"x": [0, 1, 2]} + ) + assert expected.identical(ds1.merge(ds2)) + assert expected.identical(ds2.merge(ds1)) + + expected = expected.isel(x=slice(2)) + assert expected.identical(ds1.merge(ds2, join="left")) + assert expected.identical(ds2.merge(ds1, join="right")) + + expected = expected.isel(x=slice(1, 2)) + assert expected.identical(ds1.merge(ds2, join="inner")) + assert expected.identical(ds2.merge(ds1, join="inner")) + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}]) + def test_merge_fill_value(self, fill_value): + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]}) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_a = fill_value_b = np.nan + elif isinstance(fill_value, dict): + fill_value_a = fill_value["a"] + fill_value_b = fill_value["b"] + else: + fill_value_a = fill_value_b = fill_value + + expected = xr.Dataset( + {"a": ("x", [1, 2, fill_value_a]), "b": ("x", [fill_value_b, 3, 4])}, + {"x": [0, 1, 2]}, + ) + assert expected.identical(ds1.merge(ds2, fill_value=fill_value)) + assert expected.identical(ds2.merge(ds1, fill_value=fill_value)) + assert expected.identical(xr.merge([ds1, ds2], fill_value=fill_value)) + + def test_merge_no_conflicts(self): + ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) + ds2 = xr.Dataset({"a": ("x", [2, 3]), "x": [1, 2]}) + expected = xr.Dataset({"a": ("x", [1, 2, 3]), "x": [0, 1, 2]}) + + assert expected.identical(ds1.merge(ds2, compat="no_conflicts")) + assert expected.identical(ds2.merge(ds1, compat="no_conflicts")) + + assert ds1.identical(ds1.merge(ds2, compat="no_conflicts", join="left")) + + assert ds2.identical(ds1.merge(ds2, compat="no_conflicts", join="right")) + + expected2 = xr.Dataset({"a": ("x", [2]), "x": [1]}) + assert expected2.identical(ds1.merge(ds2, compat="no_conflicts", join="inner")) + + with pytest.raises(xr.MergeError): + ds3 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) + ds1.merge(ds3, compat="no_conflicts") + + with pytest.raises(xr.MergeError): + ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) + ds1.merge(ds3, compat="no_conflicts") + + def test_merge_dataarray(self): + ds = xr.Dataset({"a": 0}) + da = xr.DataArray(data=1, name="b") + + assert_identical(ds.merge(da), xr.merge([ds, da])) + + @pytest.mark.parametrize( + ["combine_attrs", "attrs1", "attrs2", "expected_attrs", "expect_error"], + # don't need to test thoroughly + ( + ("drop", {"a": 0, "b": 1, "c": 2}, {"a": 1, "b": 2, "c": 3}, {}, False), + ( + "drop_conflicts", + {"a": 0, "b": 1, "c": 2}, + {"b": 2, "c": 2, "d": 3}, + {"a": 0, "c": 2, "d": 3}, + False, + ), + ("override", {"a": 0, "b": 1}, {"a": 1, "b": 2}, {"a": 0, "b": 1}, False), + ("no_conflicts", {"a": 0, "b": 1}, {"a": 0, "b": 2}, None, True), + ("identical", {"a": 0, "b": 1}, {"a": 0, "b": 2}, None, True), + ), + ) + def test_merge_combine_attrs( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_error + ): + ds1 = xr.Dataset(attrs=attrs1) + ds2 = xr.Dataset(attrs=attrs2) + + if expect_error: + with pytest.raises(xr.MergeError): + ds1.merge(ds2, combine_attrs=combine_attrs) + else: + actual = ds1.merge(ds2, combine_attrs=combine_attrs) + expected = xr.Dataset(attrs=expected_attrs) + assert_identical(actual, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_missing.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_missing.py new file mode 100644 index 0000000..3adcc13 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_missing.py @@ -0,0 +1,770 @@ +from __future__ import annotations + +import itertools + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.core.missing import ( + NumpyInterpolator, + ScipyInterpolator, + SplineInterpolator, + _get_nan_block_lengths, + get_clean_interp_index, +) +from xarray.namedarray.pycompat import array_type +from xarray.tests import ( + _CFTIME_CALENDARS, + assert_allclose, + assert_array_equal, + assert_equal, + raise_if_dask_computes, + requires_bottleneck, + requires_cftime, + requires_dask, + requires_numbagg, + requires_numbagg_or_bottleneck, + requires_scipy, +) + +dask_array_type = array_type("dask") + + +@pytest.fixture +def da(): + return xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") + + +@pytest.fixture +def cf_da(): + def _cf_da(calendar, freq="1D"): + times = xr.cftime_range( + start="1970-01-01", freq=freq, periods=10, calendar=calendar + ) + values = np.arange(10) + return xr.DataArray(values, dims=("time",), coords={"time": times}) + + return _cf_da + + +@pytest.fixture +def ds(): + ds = xr.Dataset() + ds["var1"] = xr.DataArray( + [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" + ) + ds["var2"] = xr.DataArray( + [10, np.nan, 11, 12, np.nan, 13, 14, 15, np.nan, 16, 17], dims="x" + ) + return ds + + +def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False): + rs = np.random.RandomState(seed) + vals = rs.normal(size=shape) + if frac_nan == 1: + vals[:] = np.nan + elif frac_nan == 0: + pass + else: + n_missing = int(vals.size * frac_nan) + + ys = np.arange(shape[0]) + xs = np.arange(shape[1]) + if n_missing: + np.random.shuffle(ys) + ys = ys[:n_missing] + + np.random.shuffle(xs) + xs = xs[:n_missing] + + vals[ys, xs] = np.nan + + if non_uniform: + # construct a datetime index that has irregular spacing + deltas = pd.to_timedelta(rs.normal(size=shape[0], scale=10), unit="d") + coords = {"time": (pd.Timestamp("2000-01-01") + deltas).sort_values()} + else: + coords = {"time": pd.date_range("2000-01-01", freq="D", periods=shape[0])} + da = xr.DataArray(vals, dims=("time", "x"), coords=coords) + df = da.to_pandas() + + return da, df + + +@pytest.mark.parametrize("fill_value", [None, np.nan, 47.11]) +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) +@requires_scipy +def test_interpolate_pd_compat(method, fill_value) -> None: + shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] + frac_nans = [0, 0.5, 1] + + for shape, frac_nan in itertools.product(shapes, frac_nans): + da, df = make_interpolate_example_data(shape, frac_nan) + + for dim in ["time", "x"]: + actual = da.interpolate_na(method=method, dim=dim, fill_value=fill_value) + # need limit_direction="both" here, to let pandas fill + # in both directions instead of default forward direction only + expected = df.interpolate( + method=method, + axis=da.get_axis_num(dim), + limit_direction="both", + fill_value=fill_value, + ) + + if method == "linear": + # Note, Pandas does not take left/right fill_value into account + # for the numpy linear methods. + # see https://github.com/pandas-dev/pandas/issues/55144 + # This aligns the pandas output with the xarray output + fixed = expected.values.copy() + fixed[pd.isnull(actual.values)] = np.nan + fixed[actual.values == fill_value] = fill_value + else: + fixed = expected.values + + np.testing.assert_allclose(actual.values, fixed) + + +@requires_scipy +@pytest.mark.parametrize("method", ["barycentric", "krogh", "pchip", "spline", "akima"]) +def test_scipy_methods_function(method) -> None: + # Note: Pandas does some wacky things with these methods and the full + # integration tests won't work. + da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) + actual = da.interpolate_na(method=method, dim="time") + assert (da.count("time") <= actual.count("time")).all() + + +@requires_scipy +def test_interpolate_pd_compat_non_uniform_index(): + shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] + frac_nans = [0, 0.5, 1] + methods = ["time", "index", "values"] + + for shape, frac_nan, method in itertools.product(shapes, frac_nans, methods): + da, df = make_interpolate_example_data(shape, frac_nan, non_uniform=True) + for dim in ["time", "x"]: + if method == "time" and dim != "time": + continue + actual = da.interpolate_na( + method="linear", dim=dim, use_coordinate=True, fill_value=np.nan + ) + expected = df.interpolate( + method=method, + axis=da.get_axis_num(dim), + ) + + # Note, Pandas does some odd things with the left/right fill_value + # for the linear methods. This next line inforces the xarray + # fill_value convention on the pandas output. Therefore, this test + # only checks that interpolated values are the same (not nans) + expected_values = expected.values.copy() + expected_values[pd.isnull(actual.values)] = np.nan + + np.testing.assert_allclose(actual.values, expected_values) + + +@requires_scipy +def test_interpolate_pd_compat_polynomial(): + shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] + frac_nans = [0, 0.5, 1] + orders = [1, 2, 3] + + for shape, frac_nan, order in itertools.product(shapes, frac_nans, orders): + da, df = make_interpolate_example_data(shape, frac_nan) + + for dim in ["time", "x"]: + actual = da.interpolate_na( + method="polynomial", order=order, dim=dim, use_coordinate=False + ) + expected = df.interpolate( + method="polynomial", order=order, axis=da.get_axis_num(dim) + ) + np.testing.assert_allclose(actual.values, expected.values) + + +@requires_scipy +def test_interpolate_unsorted_index_raises(): + vals = np.array([1, 2, 3], dtype=np.float64) + expected = xr.DataArray(vals, dims="x", coords={"x": [2, 1, 3]}) + with pytest.raises(ValueError, match=r"Index 'x' must be monotonically increasing"): + expected.interpolate_na(dim="x", method="index") + + +def test_interpolate_no_dim_raises(): + da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims="x") + with pytest.raises(NotImplementedError, match=r"dim is a required argument"): + da.interpolate_na(method="linear") + + +def test_interpolate_invalid_interpolator_raises(): + da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims="x") + with pytest.raises(ValueError, match=r"not a valid"): + da.interpolate_na(dim="x", method="foo") + + +def test_interpolate_duplicate_values_raises(): + data = np.random.randn(2, 3) + da = xr.DataArray(data, coords=[("x", ["a", "a"]), ("y", [0, 1, 2])]) + with pytest.raises(ValueError, match=r"Index 'x' has duplicate values"): + da.interpolate_na(dim="x", method="foo") + + +def test_interpolate_multiindex_raises(): + data = np.random.randn(2, 3) + data[1, 1] = np.nan + da = xr.DataArray(data, coords=[("x", ["a", "b"]), ("y", [0, 1, 2])]) + das = da.stack(z=("x", "y")) + with pytest.raises(TypeError, match=r"Index 'z' must be castable to float64"): + das.interpolate_na(dim="z") + + +def test_interpolate_2d_coord_raises(): + coords = { + "x": xr.Variable(("a", "b"), np.arange(6).reshape(2, 3)), + "y": xr.Variable(("a", "b"), np.arange(6).reshape(2, 3)) * 2, + } + + data = np.random.randn(2, 3) + data[1, 1] = np.nan + da = xr.DataArray(data, dims=("a", "b"), coords=coords) + with pytest.raises(ValueError, match=r"interpolation must be 1D"): + da.interpolate_na(dim="a", use_coordinate="x") + + +@requires_scipy +def test_interpolate_kwargs(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + expected = xr.DataArray(np.array([4, 5, 6], dtype=np.float64), dims="x") + actual = da.interpolate_na(dim="x", fill_value="extrapolate") + assert_equal(actual, expected) + + expected = xr.DataArray(np.array([4, 5, -999], dtype=np.float64), dims="x") + actual = da.interpolate_na(dim="x", fill_value=-999) + assert_equal(actual, expected) + + +def test_interpolate_keep_attrs(): + vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) + mvals = vals.copy() + mvals[2] = np.nan + missing = xr.DataArray(mvals, dims="x") + missing.attrs = {"test": "value"} + + actual = missing.interpolate_na(dim="x", keep_attrs=True) + assert actual.attrs == {"test": "value"} + + +def test_interpolate(): + vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) + expected = xr.DataArray(vals, dims="x") + mvals = vals.copy() + mvals[2] = np.nan + missing = xr.DataArray(mvals, dims="x") + + actual = missing.interpolate_na(dim="x") + + assert_equal(actual, expected) + + +@requires_scipy +@pytest.mark.parametrize( + "method,vals", + [ + pytest.param(method, vals, id=f"{desc}:{method}") + for method in [ + "linear", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "polynomial", + ] + for (desc, vals) in [ + ("no nans", np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)), + ("one nan", np.array([1, np.nan, np.nan], dtype=np.float64)), + ("all nans", np.full(6, np.nan, dtype=np.float64)), + ] + ], +) +def test_interp1d_fastrack(method, vals): + expected = xr.DataArray(vals, dims="x") + actual = expected.interpolate_na(dim="x", method=method) + + assert_equal(actual, expected) + + +@requires_bottleneck +def test_interpolate_limits(): + da = xr.DataArray( + np.array([1, 2, np.nan, np.nan, np.nan, 6], dtype=np.float64), dims="x" + ) + + actual = da.interpolate_na(dim="x", limit=None) + assert actual.isnull().sum() == 0 + + actual = da.interpolate_na(dim="x", limit=2) + expected = xr.DataArray( + np.array([1, 2, 3, 4, np.nan, 6], dtype=np.float64), dims="x" + ) + + assert_equal(actual, expected) + + +@requires_scipy +def test_interpolate_methods(): + for method in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: + kwargs = {} + da = xr.DataArray( + np.array([0, 1, 2, np.nan, np.nan, np.nan, 6, 7, 8], dtype=np.float64), + dims="x", + ) + actual = da.interpolate_na("x", method=method, **kwargs) + assert actual.isnull().sum() == 0 + + actual = da.interpolate_na("x", method=method, limit=2, **kwargs) + assert actual.isnull().sum() == 1 + + +@requires_scipy +def test_interpolators(): + for method, interpolator in [ + ("linear", NumpyInterpolator), + ("linear", ScipyInterpolator), + ("spline", SplineInterpolator), + ]: + xi = np.array([-1, 0, 1, 2, 5], dtype=np.float64) + yi = np.array([-10, 0, 10, 20, 50], dtype=np.float64) + x = np.array([3, 4], dtype=np.float64) + + f = interpolator(xi, yi, method=method) + out = f(x) + assert pd.isnull(out).sum() == 0 + + +def test_interpolate_use_coordinate(): + xc = xr.Variable("x", [100, 200, 300, 400, 500, 600]) + da = xr.DataArray( + np.array([1, 2, np.nan, np.nan, np.nan, 6], dtype=np.float64), + dims="x", + coords={"xc": xc}, + ) + + # use_coordinate == False is same as using the default index + actual = da.interpolate_na(dim="x", use_coordinate=False) + expected = da.interpolate_na(dim="x") + assert_equal(actual, expected) + + # possible to specify non index coordinate + actual = da.interpolate_na(dim="x", use_coordinate="xc") + expected = da.interpolate_na(dim="x") + assert_equal(actual, expected) + + # possible to specify index coordinate by name + actual = da.interpolate_na(dim="x", use_coordinate="x") + expected = da.interpolate_na(dim="x") + assert_equal(actual, expected) + + +@requires_dask +def test_interpolate_dask(): + da, _ = make_interpolate_example_data((40, 40), 0.5) + da = da.chunk({"x": 5}) + actual = da.interpolate_na("time") + expected = da.load().interpolate_na("time") + assert isinstance(actual.data, dask_array_type) + assert_equal(actual.compute(), expected) + + # with limit + da = da.chunk({"x": 5}) + actual = da.interpolate_na("time", limit=3) + expected = da.load().interpolate_na("time", limit=3) + assert isinstance(actual.data, dask_array_type) + assert_equal(actual, expected) + + +@requires_dask +def test_interpolate_dask_raises_for_invalid_chunk_dim(): + da, _ = make_interpolate_example_data((40, 40), 0.5) + da = da.chunk({"time": 5}) + # this checks for ValueError in dask.array.apply_gufunc + with pytest.raises(ValueError, match=r"consists of multiple chunks"): + da.interpolate_na("time") + + +@requires_dask +@requires_scipy +@pytest.mark.parametrize("dtype, method", [(int, "linear"), (int, "nearest")]) +def test_interpolate_dask_expected_dtype(dtype, method): + da = xr.DataArray( + data=np.array([0, 1], dtype=dtype), + dims=["time"], + coords=dict(time=np.array([0, 1])), + ).chunk(dict(time=2)) + da = da.interp(time=np.array([0, 0.5, 1, 2]), method=method) + + assert da.dtype == da.compute().dtype + + +@requires_numbagg_or_bottleneck +def test_ffill(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims="x") + actual = da.ffill("x") + assert_equal(actual, expected) + + +def test_ffill_use_bottleneck_numbagg(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + with xr.set_options(use_bottleneck=False, use_numbagg=False): + with pytest.raises(RuntimeError): + da.ffill("x") + + +@requires_dask +def test_ffill_use_bottleneck_dask(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk({"x": 1}) + with xr.set_options(use_bottleneck=False, use_numbagg=False): + with pytest.raises(RuntimeError): + da.ffill("x") + + +@requires_numbagg +@requires_dask +def test_ffill_use_numbagg_dask(): + with xr.set_options(use_bottleneck=False): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk(x=-1) + # Succeeds with a single chunk: + _ = da.ffill("x").compute() + + +def test_bfill_use_bottleneck(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + with xr.set_options(use_bottleneck=False, use_numbagg=False): + with pytest.raises(RuntimeError): + da.bfill("x") + + +@requires_dask +def test_bfill_use_bottleneck_dask(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk({"x": 1}) + with xr.set_options(use_bottleneck=False, use_numbagg=False): + with pytest.raises(RuntimeError): + da.bfill("x") + + +@requires_bottleneck +@requires_dask +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +def test_ffill_bfill_dask(method): + da, _ = make_interpolate_example_data((40, 40), 0.5) + da = da.chunk({"x": 5}) + + dask_method = getattr(da, method) + numpy_method = getattr(da.compute(), method) + # unchunked axis + with raise_if_dask_computes(): + actual = dask_method("time") + expected = numpy_method("time") + assert_equal(actual, expected) + + # chunked axis + with raise_if_dask_computes(): + actual = dask_method("x") + expected = numpy_method("x") + assert_equal(actual, expected) + + # with limit + with raise_if_dask_computes(): + actual = dask_method("time", limit=3) + expected = numpy_method("time", limit=3) + assert_equal(actual, expected) + + # limit < axis size + with raise_if_dask_computes(): + actual = dask_method("x", limit=2) + expected = numpy_method("x", limit=2) + assert_equal(actual, expected) + + # limit > axis size + with raise_if_dask_computes(): + actual = dask_method("x", limit=41) + expected = numpy_method("x", limit=41) + assert_equal(actual, expected) + + +@requires_bottleneck +def test_ffill_bfill_nonans(): + vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64) + expected = xr.DataArray(vals, dims="x") + + actual = expected.ffill(dim="x") + assert_equal(actual, expected) + + actual = expected.bfill(dim="x") + assert_equal(actual, expected) + + +@requires_bottleneck +def test_ffill_bfill_allnans(): + vals = np.full(6, np.nan, dtype=np.float64) + expected = xr.DataArray(vals, dims="x") + + actual = expected.ffill(dim="x") + assert_equal(actual, expected) + + actual = expected.bfill(dim="x") + assert_equal(actual, expected) + + +@requires_bottleneck +def test_ffill_functions(da): + result = da.ffill("time") + assert result.isnull().sum() == 0 + + +@requires_bottleneck +def test_ffill_limit(): + da = xr.DataArray( + [0, np.nan, np.nan, np.nan, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" + ) + result = da.ffill("time") + expected = xr.DataArray([0, 0, 0, 0, 0, 3, 4, 5, 5, 6, 7], dims="time") + assert_array_equal(result, expected) + + result = da.ffill("time", limit=1) + expected = xr.DataArray( + [0, 0, np.nan, np.nan, np.nan, 3, 4, 5, 5, 6, 7], dims="time" + ) + assert_array_equal(result, expected) + + +def test_interpolate_dataset(ds): + actual = ds.interpolate_na(dim="time") + # no missing values in var1 + assert actual["var1"].count("time") == actual.sizes["time"] + + # var2 should be the same as it was + assert_array_equal(actual["var2"], ds["var2"]) + + +@requires_bottleneck +def test_ffill_dataset(ds): + ds.ffill(dim="time") + + +@requires_bottleneck +def test_bfill_dataset(ds): + ds.ffill(dim="time") + + +@requires_bottleneck +@pytest.mark.parametrize( + "y, lengths_expected", + [ + [np.arange(9), [[1, 0, 7, 7, 7, 7, 7, 7, 0], [3, 3, 3, 0, 3, 3, 0, 2, 2]]], + [ + np.arange(9) * 3, + [[3, 0, 21, 21, 21, 21, 21, 21, 0], [9, 9, 9, 0, 9, 9, 0, 6, 6]], + ], + [ + [0, 2, 5, 6, 7, 8, 10, 12, 14], + [[2, 0, 12, 12, 12, 12, 12, 12, 0], [6, 6, 6, 0, 4, 4, 0, 4, 4]], + ], + ], +) +def test_interpolate_na_nan_block_lengths(y, lengths_expected): + arr = [ + [np.nan, 1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 4], + [np.nan, np.nan, np.nan, 1, np.nan, np.nan, 4, np.nan, np.nan], + ] + da = xr.DataArray(arr, dims=["x", "y"], coords={"x": [0, 1], "y": y}) + index = get_clean_interp_index(da, dim="y", use_coordinate=True) + actual = _get_nan_block_lengths(da, dim="y", index=index) + expected = da.copy(data=lengths_expected) + assert_equal(actual, expected) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_get_clean_interp_index_cf_calendar(cf_da, calendar): + """The index for CFTimeIndex is in units of days. This means that if two series using a 360 and 365 days + calendar each have a trend of .01C/year, the linear regression coefficients will be different because they + have different number of days. + + Another option would be to have an index in units of years, but this would likely create other difficulties. + """ + i = get_clean_interp_index(cf_da(calendar), dim="time") + np.testing.assert_array_equal(i, np.arange(10) * 1e9 * 86400) + + +@requires_cftime +@pytest.mark.parametrize( + ("calendar", "freq"), zip(["gregorian", "proleptic_gregorian"], ["1D", "1ME", "1Y"]) +) +def test_get_clean_interp_index_dt(cf_da, calendar, freq): + """In the gregorian case, the index should be proportional to normal datetimes.""" + g = cf_da(calendar, freq=freq) + g["stime"] = xr.Variable(data=g.time.to_index().to_datetimeindex(), dims=("time",)) + + gi = get_clean_interp_index(g, "time") + si = get_clean_interp_index(g, "time", use_coordinate="stime") + np.testing.assert_array_equal(gi, si) + + +@requires_cftime +def test_get_clean_interp_index_potential_overflow(): + da = xr.DataArray( + [0, 1, 2], + dims=("time",), + coords={"time": xr.cftime_range("0000-01-01", periods=3, calendar="360_day")}, + ) + get_clean_interp_index(da, "time") + + +@pytest.mark.parametrize("index", ([0, 2, 1], [0, 1, 1])) +def test_get_clean_interp_index_strict(index): + da = xr.DataArray([0, 1, 2], dims=("x",), coords={"x": index}) + + with pytest.raises(ValueError): + get_clean_interp_index(da, "x") + + clean = get_clean_interp_index(da, "x", strict=False) + np.testing.assert_array_equal(index, clean) + assert clean.dtype == np.float64 + + +@pytest.fixture +def da_time(): + return xr.DataArray( + [np.nan, 1, 2, np.nan, np.nan, 5, np.nan, np.nan, np.nan, np.nan, 10], + dims=["t"], + ) + + +def test_interpolate_na_max_gap_errors(da_time): + with pytest.raises( + NotImplementedError, match=r"max_gap not implemented for unlabeled coordinates" + ): + da_time.interpolate_na("t", max_gap=1) + + with pytest.raises(ValueError, match=r"max_gap must be a scalar."): + da_time.interpolate_na("t", max_gap=(1,)) + + da_time["t"] = pd.date_range("2001-01-01", freq="h", periods=11) + with pytest.raises(TypeError, match=r"Expected value of type str"): + da_time.interpolate_na("t", max_gap=1) + + with pytest.raises(TypeError, match=r"Expected integer or floating point"): + da_time.interpolate_na("t", max_gap="1h", use_coordinate=False) + + with pytest.raises(ValueError, match=r"Could not convert 'huh' to timedelta64"): + da_time.interpolate_na("t", max_gap="huh") + + +@requires_bottleneck +@pytest.mark.parametrize( + "time_range_func", + [pd.date_range, pytest.param(xr.cftime_range, marks=requires_cftime)], +) +@pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.to_dataset(name="a")]) +@pytest.mark.parametrize( + "max_gap", ["3h", np.timedelta64(3, "h"), pd.to_timedelta("3h")] +) +def test_interpolate_na_max_gap_time_specifier( + da_time, max_gap, transform, time_range_func +): + da_time["t"] = time_range_func("2001-01-01", freq="h", periods=11) + expected = transform( + da_time.copy(data=[np.nan, 1, 2, 3, 4, 5, np.nan, np.nan, np.nan, np.nan, 10]) + ) + actual = transform(da_time).interpolate_na("t", max_gap=max_gap) + assert_allclose(actual, expected) + + +@requires_bottleneck +@pytest.mark.parametrize( + "coords", + [ + pytest.param(None, marks=pytest.mark.xfail()), + {"x": np.arange(4), "y": np.arange(12)}, + ], +) +def test_interpolate_na_2d(coords): + n = np.nan + da = xr.DataArray( + [ + [1, 2, 3, 4, n, 6, n, n, n, 10, 11, n], + [n, n, 3, n, n, 6, n, n, n, 10, n, n], + [n, n, 3, n, n, 6, n, n, n, 10, n, n], + [n, 2, 3, 4, n, 6, n, n, n, 10, 11, n], + ], + dims=["x", "y"], + coords=coords, + ) + + actual = da.interpolate_na("y", max_gap=2) + expected_y = da.copy( + data=[ + [1, 2, 3, 4, 5, 6, n, n, n, 10, 11, n], + [n, n, 3, n, n, 6, n, n, n, 10, n, n], + [n, n, 3, n, n, 6, n, n, n, 10, n, n], + [n, 2, 3, 4, 5, 6, n, n, n, 10, 11, n], + ] + ) + assert_equal(actual, expected_y) + + actual = da.interpolate_na("y", max_gap=1, fill_value="extrapolate") + expected_y_extra = da.copy( + data=[ + [1, 2, 3, 4, n, 6, n, n, n, 10, 11, 12], + [n, n, 3, n, n, 6, n, n, n, 10, n, n], + [n, n, 3, n, n, 6, n, n, n, 10, n, n], + [1, 2, 3, 4, n, 6, n, n, n, 10, 11, 12], + ] + ) + assert_equal(actual, expected_y_extra) + + actual = da.interpolate_na("x", max_gap=3) + expected_x = xr.DataArray( + [ + [1, 2, 3, 4, n, 6, n, n, n, 10, 11, n], + [n, 2, 3, 4, n, 6, n, n, n, 10, 11, n], + [n, 2, 3, 4, n, 6, n, n, n, 10, 11, n], + [n, 2, 3, 4, n, 6, n, n, n, 10, 11, n], + ], + dims=["x", "y"], + coords=coords, + ) + assert_equal(actual, expected_x) + + +@requires_scipy +def test_interpolators_complex_out_of_bounds(): + """Ensure complex nans are used for complex data""" + + xi = np.array([-1, 0, 1, 2, 5], dtype=np.float64) + yi = np.exp(1j * xi) + x = np.array([-2, 1, 6], dtype=np.float64) + + expected = np.array( + [np.nan + np.nan * 1j, np.exp(1j), np.nan + np.nan * 1j], dtype=yi.dtype + ) + + for method, interpolator in [ + ("linear", NumpyInterpolator), + ("linear", ScipyInterpolator), + ]: + f = interpolator(xi, yi, method=method) + actual = f(x) + assert_array_equal(actual, expected) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_namedarray.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_namedarray.py new file mode 100644 index 0000000..3d35844 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_namedarray.py @@ -0,0 +1,563 @@ +from __future__ import annotations + +import copy +from abc import abstractmethod +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Generic, cast, overload + +import numpy as np +import pytest + +from xarray.core.indexing import ExplicitlyIndexed +from xarray.namedarray._typing import ( + _arrayfunction_or_api, + _default, + _DType_co, + _ShapeType_co, +) +from xarray.namedarray.core import NamedArray, from_array + +if TYPE_CHECKING: + from types import ModuleType + + from numpy.typing import ArrayLike, DTypeLike, NDArray + + from xarray.namedarray._typing import ( + Default, + _AttrsLike, + _Dim, + _DimsLike, + _DType, + _IndexKeyLike, + _IntOrUnknown, + _Shape, + _ShapeLike, + duckarray, + ) + + +class CustomArrayBase(Generic[_ShapeType_co, _DType_co]): + def __init__(self, array: duckarray[Any, _DType_co]) -> None: + self.array: duckarray[Any, _DType_co] = array + + @property + def dtype(self) -> _DType_co: + return self.array.dtype + + @property + def shape(self) -> _Shape: + return self.array.shape + + +class CustomArray( + CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] +): + def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]: + return np.array(self.array) + + +class CustomArrayIndexable( + CustomArrayBase[_ShapeType_co, _DType_co], + ExplicitlyIndexed, + Generic[_ShapeType_co, _DType_co], +): + def __getitem__( + self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], / + ) -> CustomArrayIndexable[Any, _DType_co]: + if isinstance(key, CustomArrayIndexable): + if isinstance(key.array, type(self.array)): + # TODO: key.array is duckarray here, can it be narrowed down further? + # an _arrayapi cannot be used on a _arrayfunction for example. + return type(self)(array=self.array[key.array]) # type: ignore[index] + else: + raise TypeError("key must have the same array type as self") + else: + return type(self)(array=self.array[key]) + + def __array_namespace__(self) -> ModuleType: + return np + + +def check_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]: + # Mypy checks a is valid: + b: duckarray[Any, _DType] = a + + # Runtime check if valid: + if isinstance(b, _arrayfunction_or_api): + return b + else: + raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi") + + +class NamedArraySubclassobjects: + @pytest.fixture + def target(self, data: np.ndarray[Any, Any]) -> Any: + """Fixture that needs to be overridden""" + raise NotImplementedError + + @abstractmethod + def cls(self, *args: Any, **kwargs: Any) -> Any: + """Method that needs to be overridden""" + raise NotImplementedError + + @pytest.fixture + def data(self) -> np.ndarray[Any, np.dtype[Any]]: + return 0.5 * np.arange(10).reshape(2, 5) + + @pytest.fixture + def random_inputs(self) -> np.ndarray[Any, np.dtype[np.float32]]: + return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + + def test_properties(self, target: Any, data: Any) -> None: + assert target.dims == ("x", "y") + assert np.array_equal(target.data, data) + assert target.dtype == float + assert target.shape == (2, 5) + assert target.ndim == 2 + assert target.sizes == {"x": 2, "y": 5} + assert target.size == 10 + assert target.nbytes == 80 + assert len(target) == 2 + + def test_attrs(self, target: Any) -> None: + assert target.attrs == {} + attrs = {"foo": "bar"} + target.attrs = attrs + assert target.attrs == attrs + assert isinstance(target.attrs, dict) + target.attrs["foo"] = "baz" + assert target.attrs["foo"] == "baz" + + @pytest.mark.parametrize( + "expected", [np.array([1, 2], dtype=np.dtype(np.int8)), [1, 2]] + ) + def test_init(self, expected: Any) -> None: + actual = self.cls(("x",), expected) + assert np.array_equal(np.asarray(actual.data), expected) + + actual = self.cls(("x",), expected) + assert np.array_equal(np.asarray(actual.data), expected) + + def test_data(self, random_inputs: Any) -> None: + expected = self.cls(["x", "y", "z"], random_inputs) + assert np.array_equal(np.asarray(expected.data), random_inputs) + with pytest.raises(ValueError): + expected.data = np.random.random((3, 4)).astype(np.float64) + d2 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + expected.data = d2 + assert np.array_equal(np.asarray(expected.data), d2) + + +class TestNamedArray(NamedArraySubclassobjects): + def cls(self, *args: Any, **kwargs: Any) -> NamedArray[Any, Any]: + return NamedArray(*args, **kwargs) + + @pytest.fixture + def target(self, data: np.ndarray[Any, Any]) -> NamedArray[Any, Any]: + return NamedArray(["x", "y"], data) + + @pytest.mark.parametrize( + "expected", + [ + np.array([1, 2], dtype=np.dtype(np.int8)), + pytest.param( + [1, 2], + marks=pytest.mark.xfail( + reason="NamedArray only supports array-like objects" + ), + ), + ], + ) + def test_init(self, expected: Any) -> None: + super().test_init(expected) + + @pytest.mark.parametrize( + "dims, data, expected, raise_error", + [ + (("x",), [1, 2, 3], np.array([1, 2, 3]), False), + ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False), + ((), 2, np.array(2), False), + # Fail: + ( + ("x",), + NamedArray("time", np.array([1, 2, 3])), + np.array([1, 2, 3]), + True, + ), + ], + ) + def test_from_array( + self, + dims: _DimsLike, + data: ArrayLike, + expected: np.ndarray[Any, Any], + raise_error: bool, + ) -> None: + actual: NamedArray[Any, Any] + if raise_error: + with pytest.raises(TypeError, match="already a Named array"): + actual = from_array(dims, data) + + # Named arrays are not allowed: + from_array(actual) # type: ignore[call-overload] + else: + actual = from_array(dims, data) + + assert np.array_equal(np.asarray(actual.data), expected) + + def test_from_array_with_masked_array(self) -> None: + masked_array: np.ndarray[Any, np.dtype[np.generic]] + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] + with pytest.raises(NotImplementedError): + from_array(("x",), masked_array) + + def test_from_array_with_0d_object(self) -> None: + data = np.empty((), dtype=object) + data[()] = (10, 12, 12) + narr = from_array((), data) + np.array_equal(np.asarray(narr.data), data) + + # TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api + # and remove this test. + def test_from_array_with_explicitly_indexed( + self, random_inputs: np.ndarray[Any, Any] + ) -> None: + array: CustomArray[Any, Any] + array = CustomArray(random_inputs) + output: NamedArray[Any, Any] + output = from_array(("x", "y", "z"), array) + assert isinstance(output.data, np.ndarray) + + array2: CustomArrayIndexable[Any, Any] + array2 = CustomArrayIndexable(random_inputs) + output2: NamedArray[Any, Any] + output2 = from_array(("x", "y", "z"), array2) + assert isinstance(output2.data, CustomArrayIndexable) + + def test_real_and_imag(self) -> None: + expected_real: np.ndarray[Any, np.dtype[np.float64]] + expected_real = np.arange(3, dtype=np.float64) + + expected_imag: np.ndarray[Any, np.dtype[np.float64]] + expected_imag = -np.arange(3, dtype=np.float64) + + arr: np.ndarray[Any, np.dtype[np.complex128]] + arr = expected_real + 1j * expected_imag + + named_array: NamedArray[Any, np.dtype[np.complex128]] + named_array = NamedArray(["x"], arr) + + actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data + assert np.array_equal(np.asarray(actual_real), expected_real) + assert actual_real.dtype == expected_real.dtype + + actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data + assert np.array_equal(np.asarray(actual_imag), expected_imag) + assert actual_imag.dtype == expected_imag.dtype + + # Additional tests as per your original class-based code + @pytest.mark.parametrize( + "data, dtype", + [ + ("foo", np.dtype("U3")), + (b"foo", np.dtype("S3")), + ], + ) + def test_from_array_0d_string(self, data: Any, dtype: DTypeLike) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], data) + assert named_array.data == data + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == dtype + + def test_from_array_0d_object(self) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], (10, 12, 12)) + expected_data = np.empty((), dtype=object) + expected_data[()] = (10, 12, 12) + assert np.array_equal(np.asarray(named_array.data), expected_data) + + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == np.dtype("O") + + def test_from_array_0d_datetime(self) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], np.datetime64("2000-01-01")) + assert named_array.dtype == np.dtype("datetime64[D]") + + @pytest.mark.parametrize( + "timedelta, expected_dtype", + [ + (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), + (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), + (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), + (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), + (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), + (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), + (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), + (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), + (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), + ], + ) + def test_from_array_0d_timedelta( + self, timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] + ) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], timedelta) + assert named_array.dtype == expected_dtype + assert named_array.data == timedelta + + @pytest.mark.parametrize( + "dims, data_shape, new_dims, raises", + [ + (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), + (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), + (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), + ([], [], (), False), + ([], [], ("x",), True), + ], + ) + def test_dims_setter( + self, dims: Any, data_shape: Any, new_dims: Any, raises: bool + ) -> None: + named_array: NamedArray[Any, Any] + named_array = NamedArray(dims, np.asarray(np.random.random(data_shape))) + assert named_array.dims == tuple(dims) + if raises: + with pytest.raises(ValueError): + named_array.dims = new_dims + else: + named_array.dims = new_dims + assert named_array.dims == tuple(new_dims) + + def test_duck_array_class(self) -> None: + numpy_a: NDArray[np.int64] + numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64)) + check_duck_array_typevar(numpy_a) + + masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]] + masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call] + check_duck_array_typevar(masked_a) + + custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] + custom_a = CustomArrayIndexable(numpy_a) + check_duck_array_typevar(custom_a) + + def test_duck_array_class_array_api(self) -> None: + # Test numpy's array api: + nxp = pytest.importorskip("array_api_strict", minversion="1.0") + + # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment: + arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]] + arrayapi_a = nxp.asarray([2.1, 4], dtype=nxp.int64) + check_duck_array_typevar(arrayapi_a) + + def test_new_namedarray(self) -> None: + dtype_float = np.dtype(np.float32) + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert narr_float.dtype == dtype_float + + dtype_int = np.dtype(np.int8) + narr_int: NamedArray[Any, np.dtype[np.int8]] + narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert narr_int.dtype == dtype_int + + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert var_float.dtype == dtype_float + + var_int: Variable[Any, np.dtype[np.int8]] + var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert var_int.dtype == dtype_int + + def test_replace_namedarray(self) -> None: + dtype_float = np.dtype(np.float32) + np_val: np.ndarray[Any, np.dtype[np.float32]] + np_val = np.array([1.5, 3.2], dtype=dtype_float) + np_val2: np.ndarray[Any, np.dtype[np.float32]] + np_val2 = 2 * np_val + + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np_val) + assert narr_float.dtype == dtype_float + + narr_float2: NamedArray[Any, np.dtype[np.float32]] + narr_float2 = NamedArray(("x",), np_val2) + assert narr_float2.dtype == dtype_float + + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np_val) + assert var_float.dtype == dtype_float + + var_float2: Variable[Any, np.dtype[np.float32]] + var_float2 = var_float._replace(("x",), np_val2) + assert var_float2.dtype == dtype_float + + @pytest.mark.parametrize( + "dim,expected_ndim,expected_shape,expected_dims", + [ + (None, 3, (1, 2, 5), (None, "x", "y")), + (_default, 3, (1, 2, 5), ("dim_2", "x", "y")), + ("z", 3, (1, 2, 5), ("z", "x", "y")), + ], + ) + def test_expand_dims( + self, + target: NamedArray[Any, np.dtype[np.float32]], + dim: _Dim | Default, + expected_ndim: int, + expected_shape: _ShapeLike, + expected_dims: _DimsLike, + ) -> None: + result = target.expand_dims(dim=dim) + assert result.ndim == expected_ndim + assert result.shape == expected_shape + assert result.dims == expected_dims + + @pytest.mark.parametrize( + "dims, expected_sizes", + [ + ((), {"y": 5, "x": 2}), + (["y", "x"], {"y": 5, "x": 2}), + (["y", ...], {"y": 5, "x": 2}), + ], + ) + def test_permute_dims( + self, + target: NamedArray[Any, np.dtype[np.float32]], + dims: _DimsLike, + expected_sizes: dict[_Dim, _IntOrUnknown], + ) -> None: + actual = target.permute_dims(*dims) + assert actual.sizes == expected_sizes + + def test_permute_dims_errors( + self, + target: NamedArray[Any, np.dtype[np.float32]], + ) -> None: + with pytest.raises(ValueError, match=r"'y'.*permuted list"): + dims = ["y"] + target.permute_dims(*dims) + + @pytest.mark.parametrize( + "broadcast_dims,expected_ndim", + [ + ({"x": 2, "y": 5}, 2), + ({"x": 2, "y": 5, "z": 2}, 3), + ({"w": 1, "x": 2, "y": 5}, 3), + ], + ) + def test_broadcast_to( + self, + target: NamedArray[Any, np.dtype[np.float32]], + broadcast_dims: Mapping[_Dim, int], + expected_ndim: int, + ) -> None: + expand_dims = set(broadcast_dims.keys()) - set(target.dims) + # loop over expand_dims and call .expand_dims(dim=dim) in a loop + for dim in expand_dims: + target = target.expand_dims(dim=dim) + result = target.broadcast_to(broadcast_dims) + assert result.ndim == expected_ndim + assert result.sizes == broadcast_dims + + def test_broadcast_to_errors( + self, target: NamedArray[Any, np.dtype[np.float32]] + ) -> None: + with pytest.raises( + ValueError, + match=r"operands could not be broadcast together with remapped shapes", + ): + target.broadcast_to({"x": 2, "y": 2}) + + with pytest.raises(ValueError, match=r"Cannot add new dimensions"): + target.broadcast_to({"x": 2, "y": 2, "z": 2}) + + def test_warn_on_repeated_dimension_names(self) -> None: + with pytest.warns(UserWarning, match="Duplicate dimension names"): + NamedArray(("x", "x"), np.arange(4).reshape(2, 2)) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_nputils.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_nputils.py new file mode 100644 index 0000000..dbe8ee3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_nputils.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import numpy as np +from numpy.testing import assert_array_equal + +from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous + + +def test_is_contiguous() -> None: + assert _is_contiguous([1]) + assert _is_contiguous([1, 2, 3]) + assert not _is_contiguous([1, 3]) + + +def test_vindex() -> None: + x = np.arange(3 * 4 * 5).reshape((3, 4, 5)) + vindex = NumpyVIndexAdapter(x) + + # getitem + assert_array_equal(vindex[0], x[0]) + assert_array_equal(vindex[[1, 2], [1, 2]], x[[1, 2], [1, 2]]) + assert vindex[[0, 1], [0, 1], :].shape == (2, 5) + assert vindex[[0, 1], :, [0, 1]].shape == (2, 4) + assert vindex[:, [0, 1], [0, 1]].shape == (2, 3) + + # setitem + vindex[:] = 0 + assert_array_equal(x, np.zeros_like(x)) + # assignment should not raise + vindex[[0, 1], [0, 1], :] = vindex[[0, 1], [0, 1], :] + vindex[[0, 1], :, [0, 1]] = vindex[[0, 1], :, [0, 1]] + vindex[:, [0, 1], [0, 1]] = vindex[:, [0, 1], [0, 1]] diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_options.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_options.py new file mode 100644 index 0000000..8ad1cbe --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_options.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import pytest + +import xarray +from xarray import concat, merge +from xarray.backends.file_manager import FILE_CACHE +from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.tests.test_dataset import create_test_data + + +def test_invalid_option_raises() -> None: + with pytest.raises(ValueError): + xarray.set_options(not_a_valid_options=True) + + +def test_display_width() -> None: + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join() -> None: + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join="invalid") + with xarray.set_options(arithmetic_join="exact"): + assert OPTIONS["arithmetic_join"] == "exact" + + +def test_enable_cftimeindex() -> None: + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with pytest.warns(FutureWarning, match="no-op"): + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS["enable_cftimeindex"] + + +def test_file_cache_maxsize() -> None: + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert FILE_CACHE.maxsize == original_size + + +def test_keep_attrs() -> None: + with pytest.raises(ValueError): + xarray.set_options(keep_attrs="invalid_str") + with xarray.set_options(keep_attrs=True): + assert OPTIONS["keep_attrs"] + with xarray.set_options(keep_attrs=False): + assert not OPTIONS["keep_attrs"] + with xarray.set_options(keep_attrs="default"): + assert _get_keep_attrs(default=True) + assert not _get_keep_attrs(default=False) + + +def test_nested_options() -> None: + original = OPTIONS["display_width"] + with xarray.set_options(display_width=1): + assert OPTIONS["display_width"] == 1 + with xarray.set_options(display_width=2): + assert OPTIONS["display_width"] == 2 + assert OPTIONS["display_width"] == 1 + assert OPTIONS["display_width"] == original + + +def test_display_style() -> None: + original = "html" + assert OPTIONS["display_style"] == original + with pytest.raises(ValueError): + xarray.set_options(display_style="invalid_str") + with xarray.set_options(display_style="text"): + assert OPTIONS["display_style"] == "text" + assert OPTIONS["display_style"] == original + + +def create_test_dataset_attrs(seed=0): + ds = create_test_data(seed) + ds.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} + return ds + + +def create_test_dataarray_attrs(seed=0, var="var1"): + da = create_test_data(seed)[var] + da.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} + return da + + +class TestAttrRetention: + def test_dataset_attr_retention(self) -> None: + # Use .mean() for all tests: a typical reduction operation + ds = create_test_dataset_attrs() + original_attrs = ds.attrs + + # Test default behaviour + result = ds.mean() + assert result.attrs == {} + with xarray.set_options(keep_attrs="default"): + result = ds.mean() + assert result.attrs == {} + + with xarray.set_options(keep_attrs=True): + result = ds.mean() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = ds.mean() + assert result.attrs == {} + + def test_dataarray_attr_retention(self) -> None: + # Use .mean() for all tests: a typical reduction operation + da = create_test_dataarray_attrs() + original_attrs = da.attrs + + # Test default behaviour + result = da.mean() + assert result.attrs == {} + with xarray.set_options(keep_attrs="default"): + result = da.mean() + assert result.attrs == {} + + with xarray.set_options(keep_attrs=True): + result = da.mean() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = da.mean() + assert result.attrs == {} + + def test_groupby_attr_retention(self) -> None: + da = xarray.DataArray([1, 2, 3], [("x", [1, 1, 2])]) + da.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} + original_attrs = da.attrs + + # Test default behaviour + result = da.groupby("x").sum(keep_attrs=True) + assert result.attrs == original_attrs + with xarray.set_options(keep_attrs="default"): + result = da.groupby("x").sum(keep_attrs=True) + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=True): + result1 = da.groupby("x") + result = result1.sum() + assert result.attrs == original_attrs + + with xarray.set_options(keep_attrs=False): + result = da.groupby("x").sum() + assert result.attrs == {} + + def test_concat_attr_retention(self) -> None: + ds1 = create_test_dataset_attrs() + ds2 = create_test_dataset_attrs() + ds2.attrs = {"wrong": "attributes"} + original_attrs = ds1.attrs + + # Test default behaviour of keeping the attrs of the first + # dataset in the supplied list + # global keep_attrs option current doesn't affect concat + result = concat([ds1, ds2], dim="dim1") + assert result.attrs == original_attrs + + def test_merge_attr_retention(self) -> None: + da1 = create_test_dataarray_attrs(var="var1") + da2 = create_test_dataarray_attrs(var="var2") + da2.attrs = {"wrong": "attributes"} + original_attrs = da1.attrs + + # merge currently discards attrs, and the global keep_attrs + # option doesn't affect this + result = merge([da1, da2]) + assert result.attrs == original_attrs + + def test_display_style_text(self) -> None: + ds = create_test_dataset_attrs() + with xarray.set_options(display_style="text"): + text = ds._repr_html_() + assert text.startswith("
    ")
    +            assert "'nested'" in text
    +            assert "<xarray.Dataset>" in text
    +
    +    def test_display_style_html(self) -> None:
    +        ds = create_test_dataset_attrs()
    +        with xarray.set_options(display_style="html"):
    +            html = ds._repr_html_()
    +            assert html.startswith("
    ") + assert "'nested'" in html + + def test_display_dataarray_style_text(self) -> None: + da = create_test_dataarray_attrs() + with xarray.set_options(display_style="text"): + text = da._repr_html_() + assert text.startswith("
    ")
    +            assert "<xarray.DataArray 'var1'" in text
    +
    +    def test_display_dataarray_style_html(self) -> None:
    +        da = create_test_dataarray_attrs()
    +        with xarray.set_options(display_style="html"):
    +            html = da._repr_html_()
    +            assert html.startswith("
    ") + assert "#x27;nested'" in html + + +@pytest.mark.parametrize( + "set_value", + [("left"), ("exact")], +) +def test_get_options_retention(set_value): + """Test to check if get_options will return changes made by set_options""" + with xarray.set_options(arithmetic_join=set_value): + get_options = xarray.get_options() + assert get_options["arithmetic_join"] == set_value diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_parallelcompat.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_parallelcompat.py new file mode 100644 index 0000000..dbe40be --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_parallelcompat.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +from importlib.metadata import EntryPoint +from typing import Any + +import numpy as np +import pytest + +from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks +from xarray.namedarray._typing import _Chunks +from xarray.namedarray.daskmanager import DaskManager +from xarray.namedarray.parallelcompat import ( + ChunkManagerEntrypoint, + get_chunked_array_type, + guess_chunkmanager, + list_chunkmanagers, + load_chunkmanagers, +) +from xarray.tests import has_dask, requires_dask + + +class DummyChunkedArray(np.ndarray): + """ + Mock-up of a chunked array class. + + Adds a (non-functional) .chunks attribute by following this example in the numpy docs + https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray + """ + + chunks: T_NormalizedChunks + + def __new__( + cls, + shape, + dtype=float, + buffer=None, + offset=0, + strides=None, + order=None, + chunks=None, + ): + obj = super().__new__(cls, shape, dtype, buffer, offset, strides, order) + obj.chunks = chunks + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.chunks = getattr(obj, "chunks", None) + + def rechunk(self, chunks, **kwargs): + copied = self.copy() + copied.chunks = chunks + return copied + + +class DummyChunkManager(ChunkManagerEntrypoint): + """Mock-up of ChunkManager class for DummyChunkedArray""" + + def __init__(self): + self.array_cls = DummyChunkedArray + + def is_chunked_array(self, data: Any) -> bool: + return isinstance(data, DummyChunkedArray) + + def chunks(self, data: DummyChunkedArray) -> T_NormalizedChunks: + return data.chunks + + def normalize_chunks( + self, + chunks: T_Chunks | T_NormalizedChunks, + shape: tuple[int, ...] | None = None, + limit: int | None = None, + dtype: np.dtype | None = None, + previous_chunks: T_NormalizedChunks | None = None, + ) -> T_NormalizedChunks: + from dask.array.core import normalize_chunks + + return normalize_chunks(chunks, shape, limit, dtype, previous_chunks) + + def from_array( + self, data: T_DuckArray | np.typing.ArrayLike, chunks: _Chunks, **kwargs + ) -> DummyChunkedArray: + from dask import array as da + + return da.from_array(data, chunks, **kwargs) + + def rechunk(self, data: DummyChunkedArray, chunks, **kwargs) -> DummyChunkedArray: + return data.rechunk(chunks, **kwargs) + + def compute(self, *data: DummyChunkedArray, **kwargs) -> tuple[np.ndarray, ...]: + from dask.array import compute + + return compute(*data, **kwargs) + + def apply_gufunc( + self, + func, + signature, + *args, + axes=None, + axis=None, + keepdims=False, + output_dtypes=None, + output_sizes=None, + vectorize=None, + allow_rechunk=False, + meta=None, + **kwargs, + ): + from dask.array.gufunc import apply_gufunc + + return apply_gufunc( + func, + signature, + *args, + axes=axes, + axis=axis, + keepdims=keepdims, + output_dtypes=output_dtypes, + output_sizes=output_sizes, + vectorize=vectorize, + allow_rechunk=allow_rechunk, + meta=meta, + **kwargs, + ) + + +@pytest.fixture +def register_dummy_chunkmanager(monkeypatch): + """ + Mocks the registering of an additional ChunkManagerEntrypoint. + + This preserves the presence of the existing DaskManager, so a test that relies on this and DaskManager both being + returned from list_chunkmanagers() at once would still work. + + The monkeypatching changes the behavior of list_chunkmanagers when called inside xarray.namedarray.parallelcompat, + but not when called from this tests file. + """ + # Should include DaskManager iff dask is available to be imported + preregistered_chunkmanagers = list_chunkmanagers() + + monkeypatch.setattr( + "xarray.namedarray.parallelcompat.list_chunkmanagers", + lambda: {"dummy": DummyChunkManager()} | preregistered_chunkmanagers, + ) + yield + + +class TestGetChunkManager: + def test_get_chunkmanger(self, register_dummy_chunkmanager) -> None: + chunkmanager = guess_chunkmanager("dummy") + assert isinstance(chunkmanager, DummyChunkManager) + + def test_fail_on_nonexistent_chunkmanager(self) -> None: + with pytest.raises(ValueError, match="unrecognized chunk manager foo"): + guess_chunkmanager("foo") + + @requires_dask + def test_get_dask_if_installed(self) -> None: + chunkmanager = guess_chunkmanager(None) + assert isinstance(chunkmanager, DaskManager) + + @pytest.mark.skipif(has_dask, reason="requires dask not to be installed") + def test_dont_get_dask_if_not_installed(self) -> None: + with pytest.raises(ValueError, match="unrecognized chunk manager dask"): + guess_chunkmanager("dask") + + @requires_dask + def test_choose_dask_over_other_chunkmanagers( + self, register_dummy_chunkmanager + ) -> None: + chunk_manager = guess_chunkmanager(None) + assert isinstance(chunk_manager, DaskManager) + + +class TestGetChunkedArrayType: + def test_detect_chunked_arrays(self, register_dummy_chunkmanager) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + chunk_manager = get_chunked_array_type(dummy_arr) + assert isinstance(chunk_manager, DummyChunkManager) + + def test_ignore_inmemory_arrays(self, register_dummy_chunkmanager) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + chunk_manager = get_chunked_array_type(*[dummy_arr, 1.0, np.array([5, 6])]) + assert isinstance(chunk_manager, DummyChunkManager) + + with pytest.raises(TypeError, match="Expected a chunked array"): + get_chunked_array_type(5.0) + + def test_raise_if_no_arrays_chunked(self, register_dummy_chunkmanager) -> None: + with pytest.raises(TypeError, match="Expected a chunked array "): + get_chunked_array_type(*[1.0, np.array([5, 6])]) + + def test_raise_if_no_matching_chunkmanagers(self) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + with pytest.raises( + TypeError, match="Could not find a Chunk Manager which recognises" + ): + get_chunked_array_type(dummy_arr) + + @requires_dask + def test_detect_dask_if_installed(self) -> None: + import dask.array as da + + dask_arr = da.from_array([1, 2, 3], chunks=(1,)) + + chunk_manager = get_chunked_array_type(dask_arr) + assert isinstance(chunk_manager, DaskManager) + + @requires_dask + def test_raise_on_mixed_array_types(self, register_dummy_chunkmanager) -> None: + import dask.array as da + + dummy_arr = DummyChunkedArray([1, 2, 3]) + dask_arr = da.from_array([1, 2, 3], chunks=(1,)) + + with pytest.raises(TypeError, match="received multiple types"): + get_chunked_array_type(*[dask_arr, dummy_arr]) + + +def test_bogus_entrypoint() -> None: + # Create a bogus entry-point as if the user broke their setup.cfg + # or is actively developing their new chunk manager + entry_point = EntryPoint( + "bogus", "xarray.bogus.doesnotwork", "xarray.chunkmanagers" + ) + with pytest.warns(UserWarning, match="Failed to load chunk manager"): + assert len(load_chunkmanagers([entry_point])) == 0 diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_plot.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_plot.py new file mode 100644 index 0000000..a44b621 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_plot.py @@ -0,0 +1,3408 @@ +from __future__ import annotations + +import contextlib +import inspect +import math +from collections.abc import Generator, Hashable +from copy import copy +from datetime import date, timedelta +from typing import Any, Callable, Literal + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +import xarray.plot as xplt +from xarray import DataArray, Dataset +from xarray.namedarray.utils import module_available +from xarray.plot.dataarray_plot import _infer_interval_breaks +from xarray.plot.dataset_plot import _infer_meta_data +from xarray.plot.utils import ( + _assert_valid_xy, + _build_discrete_cmap, + _color_palette, + _determine_cmap_params, + _maybe_gca, + get_axis, + label_from_attrs, +) +from xarray.tests import ( + assert_array_equal, + assert_equal, + assert_no_warnings, + requires_cartopy, + requires_cftime, + requires_matplotlib, + requires_seaborn, +) + +# this should not be imported to test if the automatic lazy import works +has_nc_time_axis = module_available("nc_time_axis") + +# import mpl and change the backend before other mpl imports +try: + import matplotlib as mpl + import matplotlib.dates + import matplotlib.pyplot as plt + import mpl_toolkits +except ImportError: + pass + +try: + import cartopy +except ImportError: + pass + + +@contextlib.contextmanager +def figure_context(*args, **kwargs): + """context manager which autocloses a figure (even if the test failed)""" + + try: + yield None + finally: + plt.close("all") + + +@pytest.fixture(scope="function", autouse=True) +def test_all_figures_closed(): + """meta-test to ensure all figures are closed at the end of a test + + Notes: Scope is kept to module (only invoke this function once per test + module) else tests cannot be run in parallel (locally). Disadvantage: only + catches one open figure per run. May still give a false positive if tests + are run in parallel. + """ + yield None + + open_figs = len(plt.get_fignums()) + if open_figs: + raise RuntimeError( + f"tests did not close all figures ({open_figs} figures open)" + ) + + +@pytest.mark.flaky +@pytest.mark.skip(reason="maybe flaky") +def text_in_fig() -> set[str]: + """ + Return the set of all text in the figure + """ + return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error? + + +def find_possible_colorbars() -> list[mpl.collections.QuadMesh]: + # nb. this function also matches meshes from pcolormesh + return plt.gcf().findobj(mpl.collections.QuadMesh) # type: ignore[return-value] # mpl error? + + +def substring_in_axes(substring: str, ax: mpl.axes.Axes) -> bool: + """ + Return True if a substring is found anywhere in an axes + """ + alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error? + for txt in alltxt: + if substring in txt: + return True + return False + + +def substring_not_in_axes(substring: str, ax: mpl.axes.Axes) -> bool: + """ + Return True if a substring is not found anywhere in an axes + """ + alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error? + check = [(substring not in txt) for txt in alltxt] + return all(check) + + +def property_in_axes_text( + property, property_str, target_txt, ax: mpl.axes.Axes +) -> bool: + """ + Return True if the specified text in an axes + has the property assigned to property_str + """ + alltxt: list[mpl.text.Text] = ax.findobj(mpl.text.Text) # type: ignore[assignment] + check = [] + for t in alltxt: + if t.get_text() == target_txt: + check.append(plt.getp(t, property) == property_str) + return all(check) + + +def easy_array(shape: tuple[int, ...], start: float = 0, stop: float = 1) -> np.ndarray: + """ + Make an array with desired shape using np.linspace + + shape is a tuple like (2, 3) + """ + a = np.linspace(start, stop, num=math.prod(shape)) + return a.reshape(shape) + + +def get_colorbar_label(colorbar) -> str: + if colorbar.orientation == "vertical": + return colorbar.ax.get_ylabel() + else: + return colorbar.ax.get_xlabel() + + +@requires_matplotlib +class PlotTestCase: + @pytest.fixture(autouse=True) + def setup(self) -> Generator: + yield + # Remove all matplotlib figures + plt.close("all") + + def pass_in_axis(self, plotmethod, subplot_kw=None) -> None: + fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw) + plotmethod(ax=axs[0]) + assert axs[0].has_data() + + @pytest.mark.slow + def imshow_called(self, plotmethod) -> bool: + plotmethod() + images = plt.gca().findobj(mpl.image.AxesImage) + return len(images) > 0 + + def contourf_called(self, plotmethod) -> bool: + plotmethod() + + # Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8 + def matchfunc(x) -> bool: + return isinstance( + x, (mpl.collections.PathCollection, mpl.contour.QuadContourSet) + ) + + paths = plt.gca().findobj(matchfunc) + return len(paths) > 0 + + +class TestPlot(PlotTestCase): + @pytest.fixture(autouse=True) + def setup_array(self) -> None: + self.darray = DataArray(easy_array((2, 3, 4))) + + def test_accessor(self) -> None: + from xarray.plot.accessor import DataArrayPlotAccessor + + assert DataArray.plot is DataArrayPlotAccessor + assert isinstance(self.darray.plot, DataArrayPlotAccessor) + + def test_label_from_attrs(self) -> None: + da = self.darray.copy() + assert "" == label_from_attrs(da) + + da.name = 0 + assert "0" == label_from_attrs(da) + + da.name = "a" + da.attrs["units"] = "a_units" + da.attrs["long_name"] = "a_long_name" + da.attrs["standard_name"] = "a_standard_name" + assert "a_long_name [a_units]" == label_from_attrs(da) + + da.attrs.pop("long_name") + assert "a_standard_name [a_units]" == label_from_attrs(da) + da.attrs.pop("units") + assert "a_standard_name" == label_from_attrs(da) + + da.attrs["units"] = "a_units" + da.attrs.pop("standard_name") + assert "a [a_units]" == label_from_attrs(da) + + da.attrs.pop("units") + assert "a" == label_from_attrs(da) + + # Latex strings can be longer without needing a new line: + long_latex_name = r"$Ra_s = \mathrm{mean}(\epsilon_k) / \mu M^2_\infty$" + da.attrs = dict(long_name=long_latex_name) + assert label_from_attrs(da) == long_latex_name + + def test1d(self) -> None: + self.darray[:, 0, 0].plot() + + with pytest.raises(ValueError, match=r"x must be one of None, 'dim_0'"): + self.darray[:, 0, 0].plot(x="dim_1") + + with pytest.raises(TypeError, match=r"complex128"): + (self.darray[:, 0, 0] + 1j).plot() + + def test_1d_bool(self) -> None: + xr.ones_like(self.darray[:, 0, 0], dtype=bool).plot() + + def test_1d_x_y_kw(self) -> None: + z = np.arange(10) + da = DataArray(np.cos(z), dims=["z"], coords=[z], name="f") + + xy: list[list[None | str]] = [[None, None], [None, "z"], ["z", None]] + + f, ax = plt.subplots(3, 1) + for aa, (x, y) in enumerate(xy): + da.plot(x=x, y=y, ax=ax.flat[aa]) + + with pytest.raises(ValueError, match=r"Cannot specify both"): + da.plot(x="z", y="z") + + error_msg = "must be one of None, 'z'" + with pytest.raises(ValueError, match=rf"x {error_msg}"): + da.plot(x="f") + + with pytest.raises(ValueError, match=rf"y {error_msg}"): + da.plot(y="f") + + def test_multiindex_level_as_coord(self) -> None: + da = xr.DataArray( + np.arange(5), + dims="x", + coords=dict(a=("x", np.arange(5)), b=("x", np.arange(5, 10))), + ) + da = da.set_index(x=["a", "b"]) + + for x in ["a", "b"]: + h = da.plot(x=x)[0] + assert_array_equal(h.get_xdata(), da[x].values) + + for y in ["a", "b"]: + h = da.plot(y=y)[0] + assert_array_equal(h.get_ydata(), da[y].values) + + # Test for bug in GH issue #2725 + def test_infer_line_data(self) -> None: + current = DataArray( + name="I", + data=np.array([5, 8]), + dims=["t"], + coords={ + "t": (["t"], np.array([0.1, 0.2])), + "V": (["t"], np.array([100, 200])), + }, + ) + + # Plot current against voltage + line = current.plot.line(x="V")[0] + assert_array_equal(line.get_xdata(), current.coords["V"].values) + + # Plot current against time + line = current.plot.line()[0] + assert_array_equal(line.get_xdata(), current.coords["t"].values) + + def test_line_plot_along_1d_coord(self) -> None: + # Test for bug in GH #3334 + x_coord = xr.DataArray(data=[0.1, 0.2], dims=["x"]) + t_coord = xr.DataArray(data=[10, 20], dims=["t"]) + + da = xr.DataArray( + data=np.array([[0, 1], [5, 9]]), + dims=["x", "t"], + coords={"x": x_coord, "time": t_coord}, + ) + + line = da.plot(x="time", hue="x")[0] + assert_array_equal(line.get_xdata(), da.coords["time"].values) + + line = da.plot(y="time", hue="x")[0] + assert_array_equal(line.get_ydata(), da.coords["time"].values) + + def test_line_plot_wrong_hue(self) -> None: + da = xr.DataArray( + data=np.array([[0, 1], [5, 9]]), + dims=["x", "t"], + ) + + with pytest.raises(ValueError, match="hue must be one of"): + da.plot(x="t", hue="wrong_coord") + + def test_2d_line(self) -> None: + with pytest.raises(ValueError, match=r"hue"): + self.darray[:, :, 0].plot.line() + + self.darray[:, :, 0].plot.line(hue="dim_1") + self.darray[:, :, 0].plot.line(x="dim_1") + self.darray[:, :, 0].plot.line(y="dim_1") + self.darray[:, :, 0].plot.line(x="dim_0", hue="dim_1") + self.darray[:, :, 0].plot.line(y="dim_0", hue="dim_1") + + with pytest.raises(ValueError, match=r"Cannot"): + self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1") + + def test_2d_line_accepts_legend_kw(self) -> None: + self.darray[:, :, 0].plot.line(x="dim_0", add_legend=False) + assert not plt.gca().get_legend() + plt.cla() + self.darray[:, :, 0].plot.line(x="dim_0", add_legend=True) + assert plt.gca().get_legend() + # check whether legend title is set + assert plt.gca().get_legend().get_title().get_text() == "dim_1" + + def test_2d_line_accepts_x_kw(self) -> None: + self.darray[:, :, 0].plot.line(x="dim_0") + assert plt.gca().get_xlabel() == "dim_0" + plt.cla() + self.darray[:, :, 0].plot.line(x="dim_1") + assert plt.gca().get_xlabel() == "dim_1" + + def test_2d_line_accepts_hue_kw(self) -> None: + self.darray[:, :, 0].plot.line(hue="dim_0") + assert plt.gca().get_legend().get_title().get_text() == "dim_0" + plt.cla() + self.darray[:, :, 0].plot.line(hue="dim_1") + assert plt.gca().get_legend().get_title().get_text() == "dim_1" + + def test_2d_coords_line_plot(self) -> None: + lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4)) + lon += lat / 10 + lat += lon / 10 + da = xr.DataArray( + np.arange(20).reshape(4, 5), + dims=["y", "x"], + coords={"lat": (("y", "x"), lat), "lon": (("y", "x"), lon)}, + ) + + with figure_context(): + hdl = da.plot.line(x="lon", hue="x") + assert len(hdl) == 5 + + with figure_context(): + hdl = da.plot.line(x="lon", hue="y") + assert len(hdl) == 4 + + with pytest.raises(ValueError, match="For 2D inputs, hue must be a dimension"): + da.plot.line(x="lon", hue="lat") + + def test_2d_coord_line_plot_coords_transpose_invariant(self) -> None: + # checks for bug reported in GH #3933 + x = np.arange(10) + y = np.arange(20) + ds = xr.Dataset(coords={"x": x, "y": y}) + + for z in [ds.y + ds.x, ds.x + ds.y]: + ds = ds.assign_coords(z=z) + ds["v"] = ds.x + ds.y + ds["v"].plot.line(y="z", hue="x") + + def test_2d_before_squeeze(self) -> None: + a = DataArray(easy_array((1, 5))) + a.plot() + + def test2d_uniform_calls_imshow(self) -> None: + assert self.imshow_called(self.darray[:, :, 0].plot.imshow) + + @pytest.mark.slow + def test2d_nonuniform_calls_contourf(self) -> None: + a = self.darray[:, :, 0] + a.coords["dim_1"] = [2, 1, 89] + assert self.contourf_called(a.plot.contourf) + + def test2d_1d_2d_coordinates_contourf(self) -> None: + sz = (20, 10) + depth = easy_array(sz) + a = DataArray( + easy_array(sz), + dims=["z", "time"], + coords={"depth": (["z", "time"], depth), "time": np.linspace(0, 1, sz[1])}, + ) + + a.plot.contourf(x="time", y="depth") + a.plot.contourf(x="depth", y="time") + + def test2d_1d_2d_coordinates_pcolormesh(self) -> None: + # Test with equal coordinates to catch bug from #5097 + sz = 10 + y2d, x2d = np.meshgrid(np.arange(sz), np.arange(sz)) + a = DataArray( + easy_array((sz, sz)), + dims=["x", "y"], + coords={"x2d": (["x", "y"], x2d), "y2d": (["x", "y"], y2d)}, + ) + + for x, y in [ + ("x", "y"), + ("y", "x"), + ("x2d", "y"), + ("y", "x2d"), + ("x", "y2d"), + ("y2d", "x"), + ("x2d", "y2d"), + ("y2d", "x2d"), + ]: + p = a.plot.pcolormesh(x=x, y=y) + v = p.get_paths()[0].vertices + assert isinstance(v, np.ndarray) + + # Check all vertices are different, except last vertex which should be the + # same as the first + _, unique_counts = np.unique(v[:-1], axis=0, return_counts=True) + assert np.all(unique_counts == 1) + + def test_str_coordinates_pcolormesh(self) -> None: + # test for #6775 + x = DataArray( + [[1, 2, 3], [4, 5, 6]], + dims=("a", "b"), + coords={"a": [1, 2], "b": ["a", "b", "c"]}, + ) + x.plot.pcolormesh() + x.T.plot.pcolormesh() + + def test_contourf_cmap_set(self) -> None: + a = DataArray(easy_array((4, 4)), dims=["z", "time"]) + + cmap_expected = mpl.colormaps["viridis"] + + # use copy to ensure cmap is not changed by contourf() + # Set vmin and vmax so that _build_discrete_colormap is called with + # extend='both'. extend is passed to + # mpl.colors.from_levels_and_colors(), which returns a result with + # sensible under and over values if extend='both', but not if + # extend='neither' (but if extend='neither' the under and over values + # would not be used because the data would all be within the plotted + # range) + pl = a.plot.contourf(cmap=copy(cmap_expected), vmin=0.1, vmax=0.9) + + # check the set_bad color + cmap = pl.cmap + assert cmap is not None + assert_array_equal( + cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], + ) + + # check the set_under color + assert cmap(-np.inf) == cmap_expected(-np.inf) + + # check the set_over color + assert cmap(np.inf) == cmap_expected(np.inf) + + def test_contourf_cmap_set_with_bad_under_over(self) -> None: + a = DataArray(easy_array((4, 4)), dims=["z", "time"]) + + # make a copy here because we want a local cmap that we will modify. + cmap_expected = copy(mpl.colormaps["viridis"]) + + cmap_expected.set_bad("w") + # check we actually changed the set_bad color + assert np.all( + cmap_expected(np.ma.masked_invalid([np.nan]))[0] + != mpl.colormaps["viridis"](np.ma.masked_invalid([np.nan]))[0] + ) + + cmap_expected.set_under("r") + # check we actually changed the set_under color + assert cmap_expected(-np.inf) != mpl.colormaps["viridis"](-np.inf) + + cmap_expected.set_over("g") + # check we actually changed the set_over color + assert cmap_expected(np.inf) != mpl.colormaps["viridis"](-np.inf) + + # copy to ensure cmap is not changed by contourf() + pl = a.plot.contourf(cmap=copy(cmap_expected)) + cmap = pl.cmap + assert cmap is not None + + # check the set_bad color has been kept + assert_array_equal( + cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], + ) + + # check the set_under color has been kept + assert cmap(-np.inf) == cmap_expected(-np.inf) + + # check the set_over color has been kept + assert cmap(np.inf) == cmap_expected(np.inf) + + def test3d(self) -> None: + self.darray.plot() + + def test_can_pass_in_axis(self) -> None: + self.pass_in_axis(self.darray.plot) + + def test__infer_interval_breaks(self) -> None: + assert_array_equal([-0.5, 0.5, 1.5], _infer_interval_breaks([0, 1])) + assert_array_equal( + [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) + ) + assert_array_equal( + pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), + _infer_interval_breaks(pd.date_range("20000101", periods=3)), + ) + + # make a bounded 2D array that we will center and re-infer + xref, yref = np.meshgrid(np.arange(6), np.arange(5)) + cx = (xref[1:, 1:] + xref[:-1, :-1]) / 2 + cy = (yref[1:, 1:] + yref[:-1, :-1]) / 2 + x = _infer_interval_breaks(cx, axis=1) + x = _infer_interval_breaks(x, axis=0) + y = _infer_interval_breaks(cy, axis=1) + y = _infer_interval_breaks(y, axis=0) + np.testing.assert_allclose(xref, x) + np.testing.assert_allclose(yref, y) + + # test that ValueError is raised for non-monotonic 1D inputs + with pytest.raises(ValueError): + _infer_interval_breaks(np.array([0, 2, 1]), check_monotonic=True) + + def test__infer_interval_breaks_logscale(self) -> None: + """ + Check if interval breaks are defined in the logspace if scale="log" + """ + # Check for 1d arrays + x = np.logspace(-4, 3, 8) + expected_interval_breaks = 10 ** np.linspace(-4.5, 3.5, 9) + np.testing.assert_allclose( + _infer_interval_breaks(x, scale="log"), expected_interval_breaks + ) + + # Check for 2d arrays + x = np.logspace(-4, 3, 8) + y = np.linspace(-5, 5, 11) + x, y = np.meshgrid(x, y) + expected_interval_breaks = np.vstack([10 ** np.linspace(-4.5, 3.5, 9)] * 12) + x = _infer_interval_breaks(x, axis=1, scale="log") + x = _infer_interval_breaks(x, axis=0, scale="log") + np.testing.assert_allclose(x, expected_interval_breaks) + + def test__infer_interval_breaks_logscale_invalid_coords(self) -> None: + """ + Check error is raised when passing non-positive coordinates with logscale + """ + # Check if error is raised after a zero value in the array + x = np.linspace(0, 5, 6) + with pytest.raises(ValueError): + _infer_interval_breaks(x, scale="log") + # Check if error is raised after negative values in the array + x = np.linspace(-5, 5, 11) + with pytest.raises(ValueError): + _infer_interval_breaks(x, scale="log") + + def test_geo_data(self) -> None: + # Regression test for gh2250 + # Realistic coordinates taken from the example dataset + lat = np.array( + [ + [16.28, 18.48, 19.58, 19.54, 18.35], + [28.07, 30.52, 31.73, 31.68, 30.37], + [39.65, 42.27, 43.56, 43.51, 42.11], + [50.52, 53.22, 54.55, 54.50, 53.06], + ] + ) + lon = np.array( + [ + [-126.13, -113.69, -100.92, -88.04, -75.29], + [-129.27, -115.62, -101.54, -87.32, -73.26], + [-133.10, -118.00, -102.31, -86.42, -70.76], + [-137.85, -120.99, -103.28, -85.28, -67.62], + ] + ) + data = np.sqrt(lon**2 + lat**2) + da = DataArray( + data, + dims=("y", "x"), + coords={"lon": (("y", "x"), lon), "lat": (("y", "x"), lat)}, + ) + da.plot(x="lon", y="lat") + ax = plt.gca() + assert ax.has_data() + da.plot(x="lat", y="lon") + ax = plt.gca() + assert ax.has_data() + + def test_datetime_dimension(self) -> None: + nrow = 3 + ncol = 4 + time = pd.date_range("2000-01-01", periods=nrow) + a = DataArray( + easy_array((nrow, ncol)), coords=[("time", time), ("y", range(ncol))] + ) + a.plot() + ax = plt.gca() + assert ax.has_data() + + def test_date_dimension(self) -> None: + nrow = 3 + ncol = 4 + start = date(2000, 1, 1) + time = [start + timedelta(days=i) for i in range(nrow)] + a = DataArray( + easy_array((nrow, ncol)), coords=[("time", time), ("y", range(ncol))] + ) + a.plot() + ax = plt.gca() + assert ax.has_data() + + @pytest.mark.slow + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self) -> None: + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + d.coords["z"] = list("abcd") + g = d.plot(x="x", y="y", col="z", col_wrap=2, cmap="cool") + + assert_array_equal(g.axs.shape, [2, 2]) + for ax in g.axs.flat: + assert ax.has_data() + + with pytest.raises(ValueError, match=r"[Ff]acet"): + d.plot(x="x", y="y", col="z", ax=plt.gca()) + + with pytest.raises(ValueError, match=r"[Ff]acet"): + d[0].plot(x="x", y="y", col="z", ax=plt.gca()) + + @pytest.mark.slow + def test_subplot_kws(self) -> None: + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + d.coords["z"] = list("abcd") + g = d.plot( + x="x", + y="y", + col="z", + col_wrap=2, + cmap="cool", + subplot_kws=dict(facecolor="r"), + ) + for ax in g.axs.flat: + # mpl V2 + assert ax.get_facecolor()[0:3] == mpl.colors.to_rgb("r") + + @pytest.mark.slow + def test_plot_size(self) -> None: + self.darray[:, 0, 0].plot(figsize=(13, 5)) + assert tuple(plt.gcf().get_size_inches()) == (13, 5) + + self.darray.plot(figsize=(13, 5)) + assert tuple(plt.gcf().get_size_inches()) == (13, 5) + + self.darray.plot(size=5) + assert plt.gcf().get_size_inches()[1] == 5 + + self.darray.plot(size=5, aspect=2) + assert tuple(plt.gcf().get_size_inches()) == (10, 5) + + with pytest.raises(ValueError, match=r"cannot provide both"): + self.darray.plot(ax=plt.gca(), figsize=(3, 4)) + + with pytest.raises(ValueError, match=r"cannot provide both"): + self.darray.plot(size=5, figsize=(3, 4)) + + with pytest.raises(ValueError, match=r"cannot provide both"): + self.darray.plot(size=5, ax=plt.gca()) + + with pytest.raises(ValueError, match=r"cannot provide `aspect`"): + self.darray.plot(aspect=1) + + @pytest.mark.slow + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid_4d(self) -> None: + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = d.plot(x="x", y="y", col="columns", row="rows") + + assert_array_equal(g.axs.shape, [3, 2]) + for ax in g.axs.flat: + assert ax.has_data() + + with pytest.raises(ValueError, match=r"[Ff]acet"): + d.plot(x="x", y="y", col="columns", ax=plt.gca()) + + def test_coord_with_interval(self) -> None: + """Test line plot with intervals.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot() + + def test_coord_with_interval_x(self) -> None: + """Test line plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") + + def test_coord_with_interval_y(self) -> None: + """Test line plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") + + def test_coord_with_interval_xy(self) -> None: + """Test line plot with intervals on both x and y axes.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot() + + @pytest.mark.parametrize("dim", ("x", "y")) + def test_labels_with_units_with_interval(self, dim) -> None: + """Test line plot with intervals and a units attribute.""" + bins = [-1, 0, 1, 2] + arr = self.darray.groupby_bins("dim_0", bins).mean(...) + arr.dim_0_bins.attrs["units"] = "m" + + (mappable,) = arr.plot(**{dim: "dim_0_bins"}) + ax = mappable.figure.gca() + actual = getattr(ax, f"get_{dim}label")() + + expected = "dim_0_bins_center [m]" + assert actual == expected + + def test_multiplot_over_length_one_dim(self) -> None: + a = easy_array((3, 1, 1, 1)) + d = DataArray(a, dims=("x", "col", "row", "hue")) + d.plot(col="col") + d.plot(row="row") + d.plot(hue="hue") + + +class TestPlot1D(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + d = [0, 1.1, 0, 2] + self.darray = DataArray(d, coords={"period": range(len(d))}, dims="period") + self.darray.period.attrs["units"] = "s" + + def test_xlabel_is_index_name(self) -> None: + self.darray.plot() + assert "period [s]" == plt.gca().get_xlabel() + + def test_no_label_name_on_x_axis(self) -> None: + self.darray.plot(y="period") + assert "" == plt.gca().get_xlabel() + + def test_no_label_name_on_y_axis(self) -> None: + self.darray.plot() + assert "" == plt.gca().get_ylabel() + + def test_ylabel_is_data_name(self) -> None: + self.darray.name = "temperature" + self.darray.attrs["units"] = "degrees_Celsius" + self.darray.plot() + assert "temperature [degrees_Celsius]" == plt.gca().get_ylabel() + + def test_xlabel_is_data_name(self) -> None: + self.darray.name = "temperature" + self.darray.attrs["units"] = "degrees_Celsius" + self.darray.plot(y="period") + assert "temperature [degrees_Celsius]" == plt.gca().get_xlabel() + + def test_format_string(self) -> None: + self.darray.plot.line("ro") + + def test_can_pass_in_axis(self) -> None: + self.pass_in_axis(self.darray.plot.line) + + def test_nonnumeric_index(self) -> None: + a = DataArray([1, 2, 3], {"letter": ["a", "b", "c"]}, dims="letter") + a.plot.line() + + def test_primitive_returned(self) -> None: + p = self.darray.plot.line() + assert isinstance(p[0], mpl.lines.Line2D) + + @pytest.mark.slow + def test_plot_nans(self) -> None: + self.darray[1] = np.nan + self.darray.plot.line() + + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.arange(len(time)), [("t", time)]) + a.plot.line() + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + + def test_xyincrease_false_changes_axes(self) -> None: + self.darray.plot.line(xincrease=False, yincrease=False) + xlim = plt.gca().get_xlim() + ylim = plt.gca().get_ylim() + diffs = xlim[1] - xlim[0], ylim[1] - ylim[0] + assert all(x < 0 for x in diffs) + + def test_slice_in_title(self) -> None: + self.darray.coords["d"] = 10.009 + self.darray.plot.line() + title = plt.gca().get_title() + assert "d = 10.01" == title + + def test_slice_in_title_single_item_array(self) -> None: + """Edge case for data of shape (1, N) or (N, 1).""" + darray = self.darray.expand_dims({"d": np.array([10.009])}) + darray.plot.line(x="period") + title = plt.gca().get_title() + assert "d = 10.01" == title + + +class TestPlotStep(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + self.darray = DataArray(easy_array((2, 3, 4))) + + def test_step(self) -> None: + hdl = self.darray[0, 0].plot.step() + assert "steps" in hdl[0].get_drawstyle() + + @pytest.mark.parametrize("where", ["pre", "post", "mid"]) + def test_step_with_where(self, where) -> None: + hdl = self.darray[0, 0].plot.step(where=where) + assert hdl[0].get_drawstyle() == f"steps-{where}" + + def test_step_with_hue(self) -> None: + hdl = self.darray[0].plot.step(hue="dim_2") + assert hdl[0].get_drawstyle() == "steps-pre" + + @pytest.mark.parametrize("where", ["pre", "post", "mid"]) + def test_step_with_hue_and_where(self, where) -> None: + hdl = self.darray[0].plot.step(hue="dim_2", where=where) + assert hdl[0].get_drawstyle() == f"steps-{where}" + + def test_drawstyle_steps(self) -> None: + hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps") + assert hdl[0].get_drawstyle() == "steps" + + @pytest.mark.parametrize("where", ["pre", "post", "mid"]) + def test_drawstyle_steps_with_where(self, where) -> None: + hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}") + assert hdl[0].get_drawstyle() == f"steps-{where}" + + def test_coord_with_interval_step(self) -> None: + """Test step plot with intervals.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) + + def test_coord_with_interval_step_x(self) -> None: + """Test step plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) + + def test_coord_with_interval_step_y(self) -> None: + """Test step plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) + + def test_coord_with_interval_step_x_and_y_raises_valueeerror(self) -> None: + """Test that step plot with intervals both on x and y axes raises an error.""" + arr = xr.DataArray( + [pd.Interval(0, 1), pd.Interval(1, 2)], + coords=[("x", [pd.Interval(0, 1), pd.Interval(1, 2)])], + ) + with pytest.raises(TypeError, match="intervals against intervals"): + arr.plot.step() + + +class TestPlotHistogram(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + self.darray = DataArray(easy_array((2, 3, 4))) + + def test_3d_array(self) -> None: + self.darray.plot.hist() + + def test_xlabel_uses_name(self) -> None: + self.darray.name = "testpoints" + self.darray.attrs["units"] = "testunits" + self.darray.plot.hist() + assert "testpoints [testunits]" == plt.gca().get_xlabel() + + def test_title_is_histogram(self) -> None: + self.darray.coords["d"] = 10 + self.darray.plot.hist() + assert "d = 10" == plt.gca().get_title() + + def test_can_pass_in_kwargs(self) -> None: + nbins = 5 + self.darray.plot.hist(bins=nbins) + assert nbins == len(plt.gca().patches) + + def test_can_pass_in_axis(self) -> None: + self.pass_in_axis(self.darray.plot.hist) + + def test_primitive_returned(self) -> None: + n, bins, patches = self.darray.plot.hist() + assert isinstance(n, np.ndarray) + assert isinstance(bins, np.ndarray) + assert isinstance(patches, mpl.container.BarContainer) + assert isinstance(patches[0], mpl.patches.Rectangle) + + @pytest.mark.slow + def test_plot_nans(self) -> None: + self.darray[0, 0, 0] = np.nan + self.darray.plot.hist() + + def test_hist_coord_with_interval(self) -> None: + ( + self.darray.groupby_bins("dim_0", [-1, 0, 1, 2]) + .mean(...) + .plot.hist(range=(-1, 2)) + ) + + +@requires_matplotlib +class TestDetermineCmapParams: + @pytest.fixture(autouse=True) + def setUp(self) -> None: + self.data = np.linspace(0, 1, num=100) + + def test_robust(self) -> None: + cmap_params = _determine_cmap_params(self.data, robust=True) + assert cmap_params["vmin"] == np.percentile(self.data, 2) + assert cmap_params["vmax"] == np.percentile(self.data, 98) + assert cmap_params["cmap"] == "viridis" + assert cmap_params["extend"] == "both" + assert cmap_params["levels"] is None + assert cmap_params["norm"] is None + + def test_center(self) -> None: + cmap_params = _determine_cmap_params(self.data, center=0.5) + assert cmap_params["vmax"] - 0.5 == 0.5 - cmap_params["vmin"] + assert cmap_params["cmap"] == "RdBu_r" + assert cmap_params["extend"] == "neither" + assert cmap_params["levels"] is None + assert cmap_params["norm"] is None + + def test_cmap_sequential_option(self) -> None: + with xr.set_options(cmap_sequential="magma"): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params["cmap"] == "magma" + + def test_cmap_sequential_explicit_option(self) -> None: + with xr.set_options(cmap_sequential=mpl.colormaps["magma"]): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params["cmap"] == mpl.colormaps["magma"] + + def test_cmap_divergent_option(self) -> None: + with xr.set_options(cmap_divergent="magma"): + cmap_params = _determine_cmap_params(self.data, center=0.5) + assert cmap_params["cmap"] == "magma" + + def test_nan_inf_are_ignored(self) -> None: + cmap_params1 = _determine_cmap_params(self.data) + data = self.data + data[50:55] = np.nan + data[56:60] = np.inf + cmap_params2 = _determine_cmap_params(data) + assert cmap_params1["vmin"] == cmap_params2["vmin"] + assert cmap_params1["vmax"] == cmap_params2["vmax"] + + @pytest.mark.slow + def test_integer_levels(self) -> None: + data = self.data + 1 + + # default is to cover full data range but with no guarantee on Nlevels + for level in np.arange(2, 10, dtype=int): + cmap_params = _determine_cmap_params(data, levels=level) + assert cmap_params["vmin"] is None + assert cmap_params["vmax"] is None + assert cmap_params["norm"].vmin == cmap_params["levels"][0] + assert cmap_params["norm"].vmax == cmap_params["levels"][-1] + assert cmap_params["extend"] == "neither" + + # with min max we are more strict + cmap_params = _determine_cmap_params( + data, levels=5, vmin=0, vmax=5, cmap="Blues" + ) + assert cmap_params["vmin"] is None + assert cmap_params["vmax"] is None + assert cmap_params["norm"].vmin == 0 + assert cmap_params["norm"].vmax == 5 + assert cmap_params["norm"].vmin == cmap_params["levels"][0] + assert cmap_params["norm"].vmax == cmap_params["levels"][-1] + assert cmap_params["cmap"].name == "Blues" + assert cmap_params["extend"] == "neither" + assert cmap_params["cmap"].N == 4 + assert cmap_params["norm"].N == 5 + + cmap_params = _determine_cmap_params(data, levels=5, vmin=0.5, vmax=1.5) + assert cmap_params["cmap"].name == "viridis" + assert cmap_params["extend"] == "max" + + cmap_params = _determine_cmap_params(data, levels=5, vmin=1.5) + assert cmap_params["cmap"].name == "viridis" + assert cmap_params["extend"] == "min" + + cmap_params = _determine_cmap_params(data, levels=5, vmin=1.3, vmax=1.5) + assert cmap_params["cmap"].name == "viridis" + assert cmap_params["extend"] == "both" + + def test_list_levels(self) -> None: + data = self.data + 1 + + orig_levels = [0, 1, 2, 3, 4, 5] + # vmin and vmax should be ignored if levels are explicitly provided + cmap_params = _determine_cmap_params(data, levels=orig_levels, vmin=0, vmax=3) + assert cmap_params["vmin"] is None + assert cmap_params["vmax"] is None + assert cmap_params["norm"].vmin == 0 + assert cmap_params["norm"].vmax == 5 + assert cmap_params["cmap"].N == 5 + assert cmap_params["norm"].N == 6 + + for wrap_levels in [list, np.array, pd.Index, DataArray]: + cmap_params = _determine_cmap_params(data, levels=wrap_levels(orig_levels)) + assert_array_equal(cmap_params["levels"], orig_levels) + + def test_divergentcontrol(self) -> None: + neg = self.data - 0.1 + pos = self.data + + # Default with positive data will be a normal cmap + cmap_params = _determine_cmap_params(pos) + assert cmap_params["vmin"] == 0 + assert cmap_params["vmax"] == 1 + assert cmap_params["cmap"] == "viridis" + + # Default with negative data will be a divergent cmap + cmap_params = _determine_cmap_params(neg) + assert cmap_params["vmin"] == -0.9 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "RdBu_r" + + # Setting vmin or vmax should prevent this only if center is false + cmap_params = _determine_cmap_params(neg, vmin=-0.1, center=False) + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "viridis" + cmap_params = _determine_cmap_params(neg, vmax=0.5, center=False) + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.5 + assert cmap_params["cmap"] == "viridis" + + # Setting center=False too + cmap_params = _determine_cmap_params(neg, center=False) + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "viridis" + + # However, I should still be able to set center and have a div cmap + cmap_params = _determine_cmap_params(neg, center=0) + assert cmap_params["vmin"] == -0.9 + assert cmap_params["vmax"] == 0.9 + assert cmap_params["cmap"] == "RdBu_r" + + # Setting vmin or vmax alone will force symmetric bounds around center + cmap_params = _determine_cmap_params(neg, vmin=-0.1) + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.1 + assert cmap_params["cmap"] == "RdBu_r" + cmap_params = _determine_cmap_params(neg, vmax=0.5) + assert cmap_params["vmin"] == -0.5 + assert cmap_params["vmax"] == 0.5 + assert cmap_params["cmap"] == "RdBu_r" + cmap_params = _determine_cmap_params(neg, vmax=0.6, center=0.1) + assert cmap_params["vmin"] == -0.4 + assert cmap_params["vmax"] == 0.6 + assert cmap_params["cmap"] == "RdBu_r" + + # But this is only true if vmin or vmax are negative + cmap_params = _determine_cmap_params(pos, vmin=-0.1) + assert cmap_params["vmin"] == -0.1 + assert cmap_params["vmax"] == 0.1 + assert cmap_params["cmap"] == "RdBu_r" + cmap_params = _determine_cmap_params(pos, vmin=0.1) + assert cmap_params["vmin"] == 0.1 + assert cmap_params["vmax"] == 1 + assert cmap_params["cmap"] == "viridis" + cmap_params = _determine_cmap_params(pos, vmax=0.5) + assert cmap_params["vmin"] == 0 + assert cmap_params["vmax"] == 0.5 + assert cmap_params["cmap"] == "viridis" + + # If both vmin and vmax are provided, output is non-divergent + cmap_params = _determine_cmap_params(neg, vmin=-0.2, vmax=0.6) + assert cmap_params["vmin"] == -0.2 + assert cmap_params["vmax"] == 0.6 + assert cmap_params["cmap"] == "viridis" + + # regression test for GH3524 + # infer diverging colormap from divergent levels + cmap_params = _determine_cmap_params(pos, levels=[-0.1, 0, 1]) + # specifying levels makes cmap a Colormap object + assert cmap_params["cmap"].name == "RdBu_r" + + def test_norm_sets_vmin_vmax(self) -> None: + vmin = self.data.min() + vmax = self.data.max() + + for norm, extend, levels in zip( + [ + mpl.colors.Normalize(), + mpl.colors.Normalize(), + mpl.colors.Normalize(vmin + 0.1, vmax - 0.1), + mpl.colors.Normalize(None, vmax - 0.1), + mpl.colors.Normalize(vmin + 0.1, None), + ], + ["neither", "neither", "both", "max", "min"], + [7, None, None, None, None], + ): + test_min = vmin if norm.vmin is None else norm.vmin + test_max = vmax if norm.vmax is None else norm.vmax + + cmap_params = _determine_cmap_params(self.data, norm=norm, levels=levels) + assert cmap_params["vmin"] is None + assert cmap_params["vmax"] is None + assert cmap_params["norm"].vmin == test_min + assert cmap_params["norm"].vmax == test_max + assert cmap_params["extend"] == extend + assert cmap_params["norm"] == norm + + +@requires_matplotlib +class TestDiscreteColorMap: + @pytest.fixture(autouse=True) + def setUp(self): + x = np.arange(start=0, stop=10, step=2) + y = np.arange(start=9, stop=-7, step=-3) + xy = np.dstack(np.meshgrid(x, y)) + distance = np.linalg.norm(xy, axis=2) + self.darray = DataArray(distance, list(zip(("y", "x"), (y, x)))) + self.data_min = distance.min() + self.data_max = distance.max() + yield + # Remove all matplotlib figures + plt.close("all") + + @pytest.mark.slow + def test_recover_from_seaborn_jet_exception(self) -> None: + pal = _color_palette("jet", 4) + assert type(pal) == np.ndarray + assert len(pal) == 4 + + @pytest.mark.slow + def test_build_discrete_cmap(self) -> None: + for cmap, levels, extend, filled in [ + ("jet", [0, 1], "both", False), + ("hot", [-4, 4], "max", True), + ]: + ncmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled) + assert ncmap.N == len(levels) - 1 + assert len(ncmap.colors) == len(levels) - 1 + assert cnorm.N == len(levels) + assert_array_equal(cnorm.boundaries, levels) + assert max(levels) == cnorm.vmax + assert min(levels) == cnorm.vmin + if filled: + assert ncmap.colorbar_extend == extend + else: + assert ncmap.colorbar_extend == "max" + + @pytest.mark.slow + def test_discrete_colormap_list_of_levels(self) -> None: + for extend, levels in [ + ("max", [-1, 2, 4, 8, 10]), + ("both", [2, 5, 10, 11]), + ("neither", [0, 5, 10, 15]), + ("min", [2, 5, 10, 15]), + ]: + for kind in ["imshow", "pcolormesh", "contourf", "contour"]: + primitive = getattr(self.darray.plot, kind)(levels=levels) + assert_array_equal(levels, primitive.norm.boundaries) + assert max(levels) == primitive.norm.vmax + assert min(levels) == primitive.norm.vmin + if kind != "contour": + assert extend == primitive.cmap.colorbar_extend + else: + assert "max" == primitive.cmap.colorbar_extend + assert len(levels) - 1 == len(primitive.cmap.colors) + + @pytest.mark.slow + def test_discrete_colormap_int_levels(self) -> None: + for extend, levels, vmin, vmax, cmap in [ + ("neither", 7, None, None, None), + ("neither", 7, None, 20, mpl.colormaps["RdBu"]), + ("both", 7, 4, 8, None), + ("min", 10, 4, 15, None), + ]: + for kind in ["imshow", "pcolormesh", "contourf", "contour"]: + primitive = getattr(self.darray.plot, kind)( + levels=levels, vmin=vmin, vmax=vmax, cmap=cmap + ) + assert levels >= len(primitive.norm.boundaries) - 1 + if vmax is None: + assert primitive.norm.vmax >= self.data_max + else: + assert primitive.norm.vmax >= vmax + if vmin is None: + assert primitive.norm.vmin <= self.data_min + else: + assert primitive.norm.vmin <= vmin + if kind != "contour": + assert extend == primitive.cmap.colorbar_extend + else: + assert "max" == primitive.cmap.colorbar_extend + assert levels >= len(primitive.cmap.colors) + + def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None: + levels = [0, 5, 10, 15] + primitive = self.darray.plot(levels=levels, vmin=-3, vmax=20) + assert primitive.norm.vmax == max(levels) + assert primitive.norm.vmin == min(levels) + + def test_discrete_colormap_provided_boundary_norm(self) -> None: + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + np.testing.assert_allclose(list(primitive.levels), norm.boundaries) + + def test_discrete_colormap_provided_boundary_norm_matching_cmap_levels( + self, + ) -> None: + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + cbar = primitive.colorbar + assert cbar is not None + assert cbar.norm.Ncmap == cbar.norm.N # type: ignore[attr-defined] # Exists, debatable if public though. + + +class Common2dMixin: + """ + Common tests for 2d plotting go here. + + These tests assume that a staticmethod for `self.plotfunc` exists. + Should have the same name as the method. + """ + + darray: DataArray + plotfunc: staticmethod + pass_in_axis: Callable + + # Needs to be overridden in TestSurface for facet grid plots + subplot_kws: dict[Any, Any] | None = None + + @pytest.fixture(autouse=True) + def setUp(self) -> None: + da = DataArray( + easy_array((10, 15), start=-1), + dims=["y", "x"], + coords={"y": np.arange(10), "x": np.arange(15)}, + ) + # add 2d coords + ds = da.to_dataset(name="testvar") + x, y = np.meshgrid(da.x.values, da.y.values) + ds["x2d"] = DataArray(x, dims=["y", "x"]) + ds["y2d"] = DataArray(y, dims=["y", "x"]) + ds = ds.set_coords(["x2d", "y2d"]) + # set darray and plot method + self.darray: DataArray = ds.testvar + + # Add CF-compliant metadata + self.darray.attrs["long_name"] = "a_long_name" + self.darray.attrs["units"] = "a_units" + self.darray.x.attrs["long_name"] = "x_long_name" + self.darray.x.attrs["units"] = "x_units" + self.darray.y.attrs["long_name"] = "y_long_name" + self.darray.y.attrs["units"] = "y_units" + + self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) + + def test_label_names(self) -> None: + self.plotmethod() + assert "x_long_name [x_units]" == plt.gca().get_xlabel() + assert "y_long_name [y_units]" == plt.gca().get_ylabel() + + def test_1d_raises_valueerror(self) -> None: + with pytest.raises(ValueError, match=r"DataArray must be 2d"): + self.plotfunc(self.darray[0, :]) + + def test_bool(self) -> None: + xr.ones_like(self.darray, dtype=bool).plot() + + def test_complex_raises_typeerror(self) -> None: + with pytest.raises(TypeError, match=r"complex128"): + (self.darray + 1j).plot() + + def test_3d_raises_valueerror(self) -> None: + a = DataArray(easy_array((2, 3, 4))) + if self.plotfunc.__name__ == "imshow": + pytest.skip() + with pytest.raises(ValueError, match=r"DataArray must be 2d"): + self.plotfunc(a) + + def test_nonnumeric_index(self) -> None: + a = DataArray(easy_array((3, 2)), coords=[["a", "b", "c"], ["d", "e"]]) + if self.plotfunc.__name__ == "surface": + # ax.plot_surface errors with nonnumerics: + with pytest.raises(Exception): + self.plotfunc(a) + else: + self.plotfunc(a) + + def test_multiindex_raises_typeerror(self) -> None: + a = DataArray( + easy_array((3, 2)), + dims=("x", "y"), + coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])), + ) + a = a.set_index(y=("a", "b")) + with pytest.raises(TypeError, match=r"[Pp]lot"): + self.plotfunc(a) + + def test_can_pass_in_axis(self) -> None: + self.pass_in_axis(self.plotmethod) + + def test_xyincrease_defaults(self) -> None: + # With default settings the axis must be ordered regardless + # of the coords order. + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], [1, 2]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + # Inverted coords + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], [2, 1]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + + def test_xyincrease_false_changes_axes(self) -> None: + self.plotmethod(xincrease=False, yincrease=False) + xlim = plt.gca().get_xlim() + ylim = plt.gca().get_ylim() + diffs = xlim[0] - 14, xlim[1] - 0, ylim[0] - 9, ylim[1] - 0 + assert all(abs(x) < 1 for x in diffs) + + def test_xyincrease_true_changes_axes(self) -> None: + self.plotmethod(xincrease=True, yincrease=True) + xlim = plt.gca().get_xlim() + ylim = plt.gca().get_ylim() + diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9 + assert all(abs(x) < 1 for x in diffs) + + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + + def test_plot_nans(self) -> None: + x1 = self.darray[:5] + x2 = self.darray.copy() + x2[5:] = np.nan + + clim1 = self.plotfunc(x1).get_clim() + clim2 = self.plotfunc(x2).get_clim() + assert clim1 == clim2 + + @pytest.mark.filterwarnings("ignore::UserWarning") + @pytest.mark.filterwarnings("ignore:invalid value encountered") + def test_can_plot_all_nans(self) -> None: + # regression test for issue #1780 + self.plotfunc(DataArray(np.full((2, 2), np.nan))) + + @pytest.mark.filterwarnings("ignore: Attempting to set") + def test_can_plot_axis_size_one(self) -> None: + if self.plotfunc.__name__ not in ("contour", "contourf"): + self.plotfunc(DataArray(np.ones((1, 1)))) + + def test_disallows_rgb_arg(self) -> None: + with pytest.raises(ValueError): + # Always invalid for most plots. Invalid for imshow with 2D data. + self.plotfunc(DataArray(np.ones((2, 2))), rgb="not None") + + def test_viridis_cmap(self) -> None: + cmap_name = self.plotmethod(cmap="viridis").get_cmap().name + assert "viridis" == cmap_name + + def test_default_cmap(self) -> None: + cmap_name = self.plotmethod().get_cmap().name + assert "RdBu_r" == cmap_name + + cmap_name = self.plotfunc(abs(self.darray)).get_cmap().name + assert "viridis" == cmap_name + + @requires_seaborn + def test_seaborn_palette_as_cmap(self) -> None: + cmap_name = self.plotmethod(levels=2, cmap="husl").get_cmap().name + assert "husl" == cmap_name + + def test_can_change_default_cmap(self) -> None: + cmap_name = self.plotmethod(cmap="Blues").get_cmap().name + assert "Blues" == cmap_name + + def test_diverging_color_limits(self) -> None: + artist = self.plotmethod() + vmin, vmax = artist.get_clim() + assert round(abs(-vmin - vmax), 7) == 0 + + def test_xy_strings(self) -> None: + self.plotmethod(x="y", y="x") + ax = plt.gca() + assert "y_long_name [y_units]" == ax.get_xlabel() + assert "x_long_name [x_units]" == ax.get_ylabel() + + def test_positional_coord_string(self) -> None: + self.plotmethod(y="x") + ax = plt.gca() + assert "x_long_name [x_units]" == ax.get_ylabel() + assert "y_long_name [y_units]" == ax.get_xlabel() + + self.plotmethod(x="x") + ax = plt.gca() + assert "x_long_name [x_units]" == ax.get_xlabel() + assert "y_long_name [y_units]" == ax.get_ylabel() + + def test_bad_x_string_exception(self) -> None: + with pytest.raises(ValueError, match=r"x and y cannot be equal."): + self.plotmethod(x="y", y="y") + + error_msg = "must be one of None, 'x', 'x2d', 'y', 'y2d'" + with pytest.raises(ValueError, match=rf"x {error_msg}"): + self.plotmethod(x="not_a_real_dim", y="y") + with pytest.raises(ValueError, match=rf"x {error_msg}"): + self.plotmethod(x="not_a_real_dim") + with pytest.raises(ValueError, match=rf"y {error_msg}"): + self.plotmethod(y="not_a_real_dim") + self.darray.coords["z"] = 100 + + def test_coord_strings(self) -> None: + # 1d coords (same as dims) + assert {"x", "y"} == set(self.darray.dims) + self.plotmethod(y="y", x="x") + + def test_non_linked_coords(self) -> None: + # plot with coordinate names that are not dimensions + self.darray.coords["newy"] = self.darray.y + 150 + # Normal case, without transpose + self.plotfunc(self.darray, x="x", y="newy") + ax = plt.gca() + assert "x_long_name [x_units]" == ax.get_xlabel() + assert "newy" == ax.get_ylabel() + # ax limits might change between plotfuncs + # simply ensure that these high coords were passed over + assert np.min(ax.get_ylim()) > 100.0 + + def test_non_linked_coords_transpose(self) -> None: + # plot with coordinate names that are not dimensions, + # and with transposed y and x axes + # This used to raise an error with pcolormesh and contour + # https://github.com/pydata/xarray/issues/788 + self.darray.coords["newy"] = self.darray.y + 150 + self.plotfunc(self.darray, x="newy", y="x") + ax = plt.gca() + assert "newy" == ax.get_xlabel() + assert "x_long_name [x_units]" == ax.get_ylabel() + # ax limits might change between plotfuncs + # simply ensure that these high coords were passed over + assert np.min(ax.get_xlim()) > 100.0 + + def test_multiindex_level_as_coord(self) -> None: + da = DataArray( + easy_array((3, 2)), + dims=("x", "y"), + coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])), + ) + da = da.set_index(y=["a", "b"]) + + for x, y in (("a", "x"), ("b", "x"), ("x", "a"), ("x", "b")): + self.plotfunc(da, x=x, y=y) + + ax = plt.gca() + assert x == ax.get_xlabel() + assert y == ax.get_ylabel() + + with pytest.raises(ValueError, match=r"levels of the same MultiIndex"): + self.plotfunc(da, x="a", y="b") + + with pytest.raises(ValueError, match=r"y must be one of None, 'a', 'b', 'x'"): + self.plotfunc(da, x="a", y="y") + + def test_default_title(self) -> None: + a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"]) + a.coords["c"] = [0, 1] + a.coords["d"] = "foo" + self.plotfunc(a.isel(c=1)) + title = plt.gca().get_title() + assert "c = 1, d = foo" == title or "d = foo, c = 1" == title + + def test_colorbar_default_label(self) -> None: + self.plotmethod(add_colorbar=True) + assert "a_long_name [a_units]" in text_in_fig() + + def test_no_labels(self) -> None: + self.darray.name = "testvar" + self.darray.attrs["units"] = "test_units" + self.plotmethod(add_labels=False) + alltxt = text_in_fig() + for string in [ + "x_long_name [x_units]", + "y_long_name [y_units]", + "testvar [test_units]", + ]: + assert string not in alltxt + + def test_colorbar_kwargs(self) -> None: + # replace label + self.darray.attrs.pop("long_name") + self.darray.attrs["units"] = "test_units" + # check default colorbar label + self.plotmethod(add_colorbar=True) + alltxt = text_in_fig() + assert "testvar [test_units]" in alltxt + self.darray.attrs.pop("units") + + self.darray.name = "testvar" + self.plotmethod(add_colorbar=True, cbar_kwargs={"label": "MyLabel"}) + alltxt = text_in_fig() + assert "MyLabel" in alltxt + assert "testvar" not in alltxt + # you can use anything accepted by the dict constructor as well + self.plotmethod(add_colorbar=True, cbar_kwargs=(("label", "MyLabel"),)) + alltxt = text_in_fig() + assert "MyLabel" in alltxt + assert "testvar" not in alltxt + # change cbar ax + fig, (ax, cax) = plt.subplots(1, 2) + self.plotmethod( + ax=ax, cbar_ax=cax, add_colorbar=True, cbar_kwargs={"label": "MyBar"} + ) + assert ax.has_data() + assert cax.has_data() + alltxt = text_in_fig() + assert "MyBar" in alltxt + assert "testvar" not in alltxt + # note that there are two ways to achieve this + fig, (ax, cax) = plt.subplots(1, 2) + self.plotmethod( + ax=ax, add_colorbar=True, cbar_kwargs={"label": "MyBar", "cax": cax} + ) + assert ax.has_data() + assert cax.has_data() + alltxt = text_in_fig() + assert "MyBar" in alltxt + assert "testvar" not in alltxt + # see that no colorbar is respected + self.plotmethod(add_colorbar=False) + assert "testvar" not in text_in_fig() + # check that error is raised + pytest.raises( + ValueError, + self.plotmethod, + add_colorbar=False, + cbar_kwargs={"label": "label"}, + ) + + def test_verbose_facetgrid(self) -> None: + a = easy_array((10, 15, 3)) + d = DataArray(a, dims=["y", "x", "z"]) + g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws) + g.map_dataarray(self.plotfunc, "x", "y") + for ax in g.axs.flat: + assert ax.has_data() + + def test_2d_function_and_method_signature_same(self) -> None: + func_sig = inspect.signature(self.plotfunc) + method_sig = inspect.signature(self.plotmethod) + for argname, param in method_sig.parameters.items(): + assert func_sig.parameters[argname] == param + + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self) -> None: + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + + assert_array_equal(g.axs.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axs): + assert ax.has_data() + if x == 0: + assert "y" == ax.get_ylabel() + else: + assert "" == ax.get_ylabel() + if y == 1: + assert "x" == ax.get_xlabel() + else: + assert "" == ax.get_xlabel() + + # Inferring labels + g = self.plotfunc(d, col="z", col_wrap=2) + assert_array_equal(g.axs.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axs): + assert ax.has_data() + if x == 0: + assert "y" == ax.get_ylabel() + else: + assert "" == ax.get_ylabel() + if y == 1: + assert "x" == ax.get_xlabel() + else: + assert "" == ax.get_xlabel() + + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid_4d(self) -> None: + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = self.plotfunc(d, x="x", y="y", col="columns", row="rows") + + assert_array_equal(g.axs.shape, [3, 2]) + for ax in g.axs.flat: + assert ax.has_data() + + @pytest.mark.filterwarnings("ignore:This figure includes") + def test_facetgrid_map_only_appends_mappables(self) -> None: + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = self.plotfunc(d, x="x", y="y", col="columns", row="rows") + + expected = g._mappables + + g.map(lambda: plt.plot(1, 1)) + actual = g._mappables + + assert expected == actual + + def test_facetgrid_cmap(self) -> None: + # Regression test for GH592 + data = np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12) + d = DataArray(data, dims=["x", "y", "time"]) + fg = d.plot.pcolormesh(col="time") + # check that all color limits are the same + assert len({m.get_clim() for m in fg._mappables}) == 1 + # check that all colormaps are the same + assert len({m.get_cmap().name for m in fg._mappables}) == 1 + + def test_facetgrid_cbar_kwargs(self) -> None: + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + g = self.plotfunc( + d, + x="x", + y="y", + col="columns", + row="rows", + cbar_kwargs={"label": "test_label"}, + ) + + # catch contour case + if g.cbar is not None: + assert get_colorbar_label(g.cbar) == "test_label" + + def test_facetgrid_no_cbar_ax(self) -> None: + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=["y", "x", "columns", "rows"]) + with pytest.raises(ValueError): + self.plotfunc(d, x="x", y="y", col="columns", row="rows", cbar_ax=1) + + def test_cmap_and_color_both(self) -> None: + with pytest.raises(ValueError): + self.plotmethod(colors="k", cmap="RdBu") + + def test_2d_coord_with_interval(self) -> None: + for dim in self.darray.dims: + gp = self.darray.groupby_bins(dim, range(15), restore_coord_dims=True).mean( + [dim] + ) + for kind in ["imshow", "pcolormesh", "contourf", "contour"]: + getattr(gp.plot, kind)() + + def test_colormap_error_norm_and_vmin_vmax(self) -> None: + norm = mpl.colors.LogNorm(0.1, 1e1) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmin=2) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmax=2) + + +@pytest.mark.slow +class TestContourf(Common2dMixin, PlotTestCase): + plotfunc = staticmethod(xplt.contourf) + + @pytest.mark.slow + def test_contourf_called(self) -> None: + # Having both statements ensures the test works properly + assert not self.contourf_called(self.darray.plot.imshow) + assert self.contourf_called(self.darray.plot.contourf) + + def test_primitive_artist_returned(self) -> None: + artist = self.plotmethod() + assert isinstance(artist, mpl.contour.QuadContourSet) + + @pytest.mark.slow + def test_extend(self) -> None: + artist = self.plotmethod() + assert artist.extend == "neither" + + self.darray[0, 0] = -100 + self.darray[-1, -1] = 100 + artist = self.plotmethod(robust=True) + assert artist.extend == "both" + + self.darray[0, 0] = 0 + self.darray[-1, -1] = 0 + artist = self.plotmethod(vmin=-0, vmax=10) + assert artist.extend == "min" + + artist = self.plotmethod(vmin=-10, vmax=0) + assert artist.extend == "max" + + @pytest.mark.slow + def test_2d_coord_names(self) -> None: + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + + @pytest.mark.slow + def test_levels(self) -> None: + artist = self.plotmethod(levels=[-0.5, -0.4, 0.1]) + assert artist.extend == "both" + + artist = self.plotmethod(levels=3) + assert artist.extend == "neither" + + +@pytest.mark.slow +class TestContour(Common2dMixin, PlotTestCase): + plotfunc = staticmethod(xplt.contour) + + # matplotlib cmap.colors gives an rgbA ndarray + # when seaborn is used, instead we get an rgb tuple + @staticmethod + def _color_as_tuple(c: Any) -> tuple[Any, Any, Any]: + return c[0], c[1], c[2] + + def test_colors(self) -> None: + # with single color, we don't want rgb array + artist = self.plotmethod(colors="k") + assert artist.cmap.colors[0] == "k" + + artist = self.plotmethod(colors=["k", "b"]) + assert self._color_as_tuple(artist.cmap.colors[1]) == (0.0, 0.0, 1.0) + + artist = self.darray.plot.contour( + levels=[-0.5, 0.0, 0.5, 1.0], colors=["k", "r", "w", "b"] + ) + assert self._color_as_tuple(artist.cmap.colors[1]) == (1.0, 0.0, 0.0) + assert self._color_as_tuple(artist.cmap.colors[2]) == (1.0, 1.0, 1.0) + # the last color is now under "over" + assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) + + def test_colors_np_levels(self) -> None: + # https://github.com/pydata/xarray/issues/3284 + levels = np.array([-0.5, 0.0, 0.5, 1.0]) + artist = self.darray.plot.contour(levels=levels, colors=["k", "r", "w", "b"]) + cmap = artist.cmap + assert isinstance(cmap, mpl.colors.ListedColormap) + colors = cmap.colors + assert isinstance(colors, list) + + assert self._color_as_tuple(colors[1]) == (1.0, 0.0, 0.0) + assert self._color_as_tuple(colors[2]) == (1.0, 1.0, 1.0) + # the last color is now under "over" + assert hasattr(cmap, "_rgba_over") + assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) + + def test_cmap_and_color_both(self) -> None: + with pytest.raises(ValueError): + self.plotmethod(colors="k", cmap="RdBu") + + def list_of_colors_in_cmap_raises_error(self) -> None: + with pytest.raises(ValueError, match=r"list of colors"): + self.plotmethod(cmap=["k", "b"]) + + @pytest.mark.slow + def test_2d_coord_names(self) -> None: + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + + def test_single_level(self) -> None: + # this used to raise an error, but not anymore since + # add_colorbar defaults to false + self.plotmethod(levels=[0.1]) + self.plotmethod(levels=1) + + +class TestPcolormesh(Common2dMixin, PlotTestCase): + plotfunc = staticmethod(xplt.pcolormesh) + + def test_primitive_artist_returned(self) -> None: + artist = self.plotmethod() + assert isinstance(artist, mpl.collections.QuadMesh) + + def test_everything_plotted(self) -> None: + artist = self.plotmethod() + assert artist.get_array().size == self.darray.size + + @pytest.mark.slow + def test_2d_coord_names(self) -> None: + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + + def test_dont_infer_interval_breaks_for_cartopy(self) -> None: + # Regression for GH 781 + ax = plt.gca() + # Simulate a Cartopy Axis + setattr(ax, "projection", True) + artist = self.plotmethod(x="x2d", y="y2d", ax=ax) + assert isinstance(artist, mpl.collections.QuadMesh) + # Let cartopy handle the axis limits and artist size + arr = artist.get_array() + assert arr is not None + assert arr.size <= self.darray.size + + +class TestPcolormeshLogscale(PlotTestCase): + """ + Test pcolormesh axes when x and y are in logscale + """ + + plotfunc = staticmethod(xplt.pcolormesh) + + @pytest.fixture(autouse=True) + def setUp(self) -> None: + self.boundaries = (-1, 9, -4, 3) + shape = (8, 11) + x = np.logspace(self.boundaries[0], self.boundaries[1], shape[1]) + y = np.logspace(self.boundaries[2], self.boundaries[3], shape[0]) + da = DataArray( + easy_array(shape, start=-1), + dims=["y", "x"], + coords={"y": y, "x": x}, + name="testvar", + ) + self.darray = da + + def test_interval_breaks_logspace(self) -> None: + """ + Check if the outer vertices of the pcolormesh are the expected values + + Checks bugfix for #5333 + """ + artist = self.darray.plot.pcolormesh(xscale="log", yscale="log") + + # Grab the coordinates of the vertices of the Patches + x_vertices = [p.vertices[:, 0] for p in artist.properties()["paths"]] + y_vertices = [p.vertices[:, 1] for p in artist.properties()["paths"]] + + # Get the maximum and minimum values for each set of vertices + xmin, xmax = np.min(x_vertices), np.max(x_vertices) + ymin, ymax = np.min(y_vertices), np.max(y_vertices) + + # Check if they are equal to 10 to the power of the outer value of its + # corresponding axis plus or minus the interval in the logspace + log_interval = 0.5 + np.testing.assert_allclose(xmin, 10 ** (self.boundaries[0] - log_interval)) + np.testing.assert_allclose(xmax, 10 ** (self.boundaries[1] + log_interval)) + np.testing.assert_allclose(ymin, 10 ** (self.boundaries[2] - log_interval)) + np.testing.assert_allclose(ymax, 10 ** (self.boundaries[3] + log_interval)) + + +@pytest.mark.slow +class TestImshow(Common2dMixin, PlotTestCase): + plotfunc = staticmethod(xplt.imshow) + + @pytest.mark.xfail( + reason=( + "Failing inside matplotlib. Should probably be fixed upstream because " + "other plot functions can handle it. " + "Remove this test when it works, already in Common2dMixin" + ) + ) + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + + @pytest.mark.slow + def test_imshow_called(self) -> None: + # Having both statements ensures the test works properly + assert not self.imshow_called(self.darray.plot.contourf) + assert self.imshow_called(self.darray.plot.imshow) + + def test_xy_pixel_centered(self) -> None: + self.darray.plot.imshow(yincrease=False) + assert np.allclose([-0.5, 14.5], plt.gca().get_xlim()) + assert np.allclose([9.5, -0.5], plt.gca().get_ylim()) + + def test_default_aspect_is_auto(self) -> None: + self.darray.plot.imshow() + assert "auto" == plt.gca().get_aspect() + + @pytest.mark.slow + def test_cannot_change_mpl_aspect(self) -> None: + with pytest.raises(ValueError, match=r"not available in xarray"): + self.darray.plot.imshow(aspect="equal") + + # with numbers we fall back to fig control + self.darray.plot.imshow(size=5, aspect=2) + assert "auto" == plt.gca().get_aspect() + assert tuple(plt.gcf().get_size_inches()) == (10, 5) + + @pytest.mark.slow + def test_primitive_artist_returned(self) -> None: + artist = self.plotmethod() + assert isinstance(artist, mpl.image.AxesImage) + + @pytest.mark.slow + @requires_seaborn + def test_seaborn_palette_needs_levels(self) -> None: + with pytest.raises(ValueError): + self.plotmethod(cmap="husl") + + def test_2d_coord_names(self) -> None: + with pytest.raises(ValueError, match=r"requires 1D coordinates"): + self.plotmethod(x="x2d", y="y2d") + + def test_plot_rgb_image(self) -> None: + DataArray( + easy_array((10, 15, 3), start=0), dims=["y", "x", "band"] + ).plot.imshow() + assert 0 == len(find_possible_colorbars()) + + def test_plot_rgb_image_explicit(self) -> None: + DataArray( + easy_array((10, 15, 3), start=0), dims=["y", "x", "band"] + ).plot.imshow(y="y", x="x", rgb="band") + assert 0 == len(find_possible_colorbars()) + + def test_plot_rgb_faceted(self) -> None: + DataArray( + easy_array((2, 2, 10, 15, 3), start=0), dims=["a", "b", "y", "x", "band"] + ).plot.imshow(row="a", col="b") + assert 0 == len(find_possible_colorbars()) + + def test_plot_rgba_image_transposed(self) -> None: + # We can handle the color axis being in any position + DataArray( + easy_array((4, 10, 15), start=0), dims=["band", "y", "x"] + ).plot.imshow() + + def test_warns_ambigious_dim(self) -> None: + arr = DataArray(easy_array((3, 3, 3)), dims=["y", "x", "band"]) + with pytest.warns(UserWarning): + arr.plot.imshow() + # but doesn't warn if dimensions specified + arr.plot.imshow(rgb="band") + arr.plot.imshow(x="x", y="y") + + def test_rgb_errors_too_many_dims(self) -> None: + arr = DataArray(easy_array((3, 3, 3, 3)), dims=["y", "x", "z", "band"]) + with pytest.raises(ValueError): + arr.plot.imshow(rgb="band") + + def test_rgb_errors_bad_dim_sizes(self) -> None: + arr = DataArray(easy_array((5, 5, 5)), dims=["y", "x", "band"]) + with pytest.raises(ValueError): + arr.plot.imshow(rgb="band") + + @pytest.mark.parametrize( + ["vmin", "vmax", "robust"], + [ + (-1, None, False), + (None, 2, False), + (-1, 1, False), + (0, 0, False), + (0, None, True), + (None, -1, True), + ], + ) + def test_normalize_rgb_imshow( + self, vmin: float | None, vmax: float | None, robust: bool + ) -> None: + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + arr = da.plot.imshow(vmin=vmin, vmax=vmax, robust=robust).get_array() + assert arr is not None + assert 0 <= arr.min() <= arr.max() <= 1 + + def test_normalize_rgb_one_arg_error(self) -> None: + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + # If passed one bound that implies all out of range, error: + for vmin, vmax in ((None, -1), (2, None)): + with pytest.raises(ValueError): + da.plot.imshow(vmin=vmin, vmax=vmax) + # If passed two that's just moving the range, *not* an error: + for vmin2, vmax2 in ((-1.2, -1), (2, 2.1)): + da.plot.imshow(vmin=vmin2, vmax=vmax2) + + @pytest.mark.parametrize("dtype", [np.uint8, np.int8, np.int16]) + def test_imshow_rgb_values_in_valid_range(self, dtype) -> None: + da = DataArray(np.arange(75, dtype=dtype).reshape((5, 5, 3))) + _, ax = plt.subplots() + out = da.plot.imshow(ax=ax).get_array() + assert out is not None + actual_dtype = out.dtype + assert actual_dtype is not None + assert actual_dtype == np.uint8 + assert (out[..., :3] == da.values).all() # Compare without added alpha + assert (out[..., -1] == 255).all() # Compare alpha + + @pytest.mark.filterwarnings("ignore:Several dimensions of this array") + def test_regression_rgb_imshow_dim_size_one(self) -> None: + # Regression: https://github.com/pydata/xarray/issues/1966 + da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) + da.plot.imshow() + + def test_origin_overrides_xyincrease(self) -> None: + da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) + with figure_context(): + da.plot.imshow(origin="upper") + assert plt.xlim()[0] < 0 + assert plt.ylim()[1] < 0 + + with figure_context(): + da.plot.imshow(origin="lower") + assert plt.xlim()[0] < 0 + assert plt.ylim()[0] < 0 + + +class TestSurface(Common2dMixin, PlotTestCase): + plotfunc = staticmethod(xplt.surface) + subplot_kws = {"projection": "3d"} + + @pytest.mark.xfail( + reason=( + "Failing inside matplotlib. Should probably be fixed upstream because " + "other plot functions can handle it. " + "Remove this test when it works, already in Common2dMixin" + ) + ) + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + + def test_primitive_artist_returned(self) -> None: + artist = self.plotmethod() + assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) + + @pytest.mark.slow + def test_2d_coord_names(self) -> None: + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() + + def test_xyincrease_false_changes_axes(self) -> None: + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") + + def test_xyincrease_true_changes_axes(self) -> None: + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") + + def test_can_pass_in_axis(self) -> None: + self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) + + def test_default_cmap(self) -> None: + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_diverging_color_limits(self) -> None: + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_colorbar_kwargs(self) -> None: + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_cmap_and_color_both(self) -> None: + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_seaborn_palette_as_cmap(self) -> None: + # seaborn does not work with mpl_toolkits.mplot3d + with pytest.raises(ValueError): + super().test_seaborn_palette_as_cmap() + + # Need to modify this test for surface(), because all subplots should have labels, + # not just left and bottom + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self) -> None: + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015 + + assert_array_equal(g.axs.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axs): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + # Inferring labels + g = self.plotfunc(d, col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015 + assert_array_equal(g.axs.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axs): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + def test_viridis_cmap(self) -> None: + return super().test_viridis_cmap() + + def test_can_change_default_cmap(self) -> None: + return super().test_can_change_default_cmap() + + def test_colorbar_default_label(self) -> None: + return super().test_colorbar_default_label() + + def test_facetgrid_map_only_appends_mappables(self) -> None: + return super().test_facetgrid_map_only_appends_mappables() + + +class TestFacetGrid(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + d = easy_array((10, 15, 3)) + self.darray = DataArray(d, dims=["y", "x", "z"], coords={"z": ["a", "b", "c"]}) + self.g = xplt.FacetGrid(self.darray, col="z") + + @pytest.mark.slow + def test_no_args(self) -> None: + self.g.map_dataarray(xplt.contourf, "x", "y") + + # Don't want colorbar labeled with 'None' + alltxt = text_in_fig() + assert "None" not in alltxt + + for ax in self.g.axs.flat: + assert ax.has_data() + + @pytest.mark.slow + def test_names_appear_somewhere(self) -> None: + self.darray.name = "testvar" + self.g.map_dataarray(xplt.contourf, "x", "y") + for k, ax in zip("abc", self.g.axs.flat): + assert f"z = {k}" == ax.get_title() + + alltxt = text_in_fig() + assert self.darray.name in alltxt + for label in ["x", "y"]: + assert label in alltxt + + @pytest.mark.slow + def test_text_not_super_long(self) -> None: + self.darray.coords["z"] = [100 * letter for letter in "abc"] + g = xplt.FacetGrid(self.darray, col="z") + g.map_dataarray(xplt.contour, "x", "y") + alltxt = text_in_fig() + maxlen = max(len(txt) for txt in alltxt) + assert maxlen < 50 + + t0 = g.axs[0, 0].get_title() + assert t0.endswith("...") + + @pytest.mark.slow + def test_colorbar(self) -> None: + vmin = self.darray.values.min() + vmax = self.darray.values.max() + expected = np.array((vmin, vmax)) + + self.g.map_dataarray(xplt.imshow, "x", "y") + + for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) + clim = np.array(image.get_clim()) + assert np.allclose(expected, clim) + + assert 1 == len(find_possible_colorbars()) + + def test_colorbar_scatter(self) -> None: + ds = Dataset({"a": (("x", "y"), np.arange(4).reshape(2, 2))}) + fg: xplt.FacetGrid = ds.plot.scatter(x="a", y="a", row="x", hue="a") + cbar = fg.cbar + assert cbar is not None + assert hasattr(cbar, "vmin") + assert cbar.vmin == 0 + assert hasattr(cbar, "vmax") + assert cbar.vmax == 3 + + @pytest.mark.slow + def test_empty_cell(self) -> None: + g = xplt.FacetGrid(self.darray, col="z", col_wrap=2) + g.map_dataarray(xplt.imshow, "x", "y") + + bottomright = g.axs[-1, -1] + assert not bottomright.has_data() + assert not bottomright.get_visible() + + @pytest.mark.slow + def test_norow_nocol_error(self) -> None: + with pytest.raises(ValueError, match=r"[Rr]ow"): + xplt.FacetGrid(self.darray) + + @pytest.mark.slow + def test_groups(self) -> None: + self.g.map_dataarray(xplt.imshow, "x", "y") + upperleft_dict = self.g.name_dicts[0, 0] + upperleft_array = self.darray.loc[upperleft_dict] + z0 = self.darray.isel(z=0) + + assert_equal(upperleft_array, z0) + + @pytest.mark.slow + def test_float_index(self) -> None: + self.darray.coords["z"] = [0.1, 0.2, 0.4] + g = xplt.FacetGrid(self.darray, col="z") + g.map_dataarray(xplt.imshow, "x", "y") + + @pytest.mark.slow + def test_nonunique_index_error(self) -> None: + self.darray.coords["z"] = [0.1, 0.2, 0.2] + with pytest.raises(ValueError, match=r"[Uu]nique"): + xplt.FacetGrid(self.darray, col="z") + + @pytest.mark.slow + def test_robust(self) -> None: + z = np.zeros((20, 20, 2)) + darray = DataArray(z, dims=["y", "x", "z"]) + darray[:, :, 1] = 1 + darray[2, 0, 0] = -1000 + darray[3, 0, 0] = 1000 + g = xplt.FacetGrid(darray, col="z") + g.map_dataarray(xplt.imshow, "x", "y", robust=True) + + # Color limits should be 0, 1 + # The largest number displayed in the figure should be less than 21 + numbers = set() + alltxt = text_in_fig() + for txt in alltxt: + try: + numbers.add(float(txt)) + except ValueError: + pass + largest = max(abs(x) for x in numbers) + assert largest < 21 + + @pytest.mark.slow + def test_can_set_vmin_vmax(self) -> None: + vmin, vmax = 50.0, 1000.0 + expected = np.array((vmin, vmax)) + self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) + + for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) + clim = np.array(image.get_clim()) + assert np.allclose(expected, clim) + + @pytest.mark.slow + def test_vmin_vmax_equal(self) -> None: + # regression test for GH3734 + fg = self.g.map_dataarray(xplt.imshow, "x", "y", vmin=50, vmax=50) + for mappable in fg._mappables: + assert mappable.norm.vmin != mappable.norm.vmax + + @pytest.mark.slow + @pytest.mark.filterwarnings("ignore") + def test_can_set_norm(self) -> None: + norm = mpl.colors.SymLogNorm(0.1) + self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) + for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) + assert image.norm is norm + + @pytest.mark.slow + def test_figure_size(self) -> None: + assert_array_equal(self.g.fig.get_size_inches(), (10, 3)) + + g = xplt.FacetGrid(self.darray, col="z", size=6) + assert_array_equal(g.fig.get_size_inches(), (19, 6)) + + g = self.darray.plot.imshow(col="z", size=6) + assert_array_equal(g.fig.get_size_inches(), (19, 6)) + + g = xplt.FacetGrid(self.darray, col="z", size=4, aspect=0.5) + assert_array_equal(g.fig.get_size_inches(), (7, 4)) + + g = xplt.FacetGrid(self.darray, col="z", figsize=(9, 4)) + assert_array_equal(g.fig.get_size_inches(), (9, 4)) + + with pytest.raises(ValueError, match=r"cannot provide both"): + g = xplt.plot(self.darray, row=2, col="z", figsize=(6, 4), size=6) + + with pytest.raises(ValueError, match=r"Can't use"): + g = xplt.plot(self.darray, row=2, col="z", ax=plt.gca(), size=6) + + @pytest.mark.slow + def test_num_ticks(self) -> None: + nticks = 99 + maxticks = nticks + 1 + self.g.map_dataarray(xplt.imshow, "x", "y") + self.g.set_ticks(max_xticks=nticks, max_yticks=nticks) + + for ax in self.g.axs.flat: + xticks = len(ax.get_xticks()) + yticks = len(ax.get_yticks()) + assert xticks <= maxticks + assert yticks <= maxticks + assert xticks >= nticks / 2.0 + assert yticks >= nticks / 2.0 + + @pytest.mark.slow + def test_map(self) -> None: + assert self.g._finalized is False + self.g.map(plt.contourf, "x", "y", ...) + assert self.g._finalized is True + self.g.map(lambda: None) + + @pytest.mark.slow + def test_map_dataset(self) -> None: + g = xplt.FacetGrid(self.darray.to_dataset(name="foo"), col="z") + g.map(plt.contourf, "x", "y", "foo") + + alltxt = text_in_fig() + for label in ["x", "y"]: + assert label in alltxt + # everything has a label + assert "None" not in alltxt + + # colorbar can't be inferred automatically + assert "foo" not in alltxt + assert 0 == len(find_possible_colorbars()) + + g.add_colorbar(label="colors!") + assert "colors!" in text_in_fig() + assert 1 == len(find_possible_colorbars()) + + @pytest.mark.slow + def test_set_axis_labels(self) -> None: + g = self.g.map_dataarray(xplt.contourf, "x", "y") + g.set_axis_labels("longitude", "latitude") + alltxt = text_in_fig() + for label in ["longitude", "latitude"]: + assert label in alltxt + + @pytest.mark.slow + def test_facetgrid_colorbar(self) -> None: + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"], name="foo") + + d.plot.imshow(x="x", y="y", col="z") + assert 1 == len(find_possible_colorbars()) + + d.plot.imshow(x="x", y="y", col="z", add_colorbar=True) + assert 1 == len(find_possible_colorbars()) + + d.plot.imshow(x="x", y="y", col="z", add_colorbar=False) + assert 0 == len(find_possible_colorbars()) + + @pytest.mark.slow + def test_facetgrid_polar(self) -> None: + # test if polar projection in FacetGrid does not raise an exception + self.darray.plot.pcolormesh( + col="z", subplot_kws=dict(projection="polar"), sharex=False, sharey=False + ) + + +@pytest.mark.filterwarnings("ignore:tight_layout cannot") +class TestFacetGrid4d(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + a = easy_array((10, 15, 3, 2)) + darray = DataArray(a, dims=["y", "x", "col", "row"]) + darray.coords["col"] = np.array( + ["col" + str(x) for x in darray.coords["col"].values] + ) + darray.coords["row"] = np.array( + ["row" + str(x) for x in darray.coords["row"].values] + ) + + self.darray = darray + + def test_title_kwargs(self) -> None: + g = xplt.FacetGrid(self.darray, col="col", row="row") + g.set_titles(template="{value}", weight="bold") + + # Rightmost column titles should be bold + for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + assert property_in_axes_text("weight", "bold", label, ax) + + # Top row titles should be bold + for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + assert property_in_axes_text("weight", "bold", label, ax) + + @pytest.mark.slow + def test_default_labels(self) -> None: + g = xplt.FacetGrid(self.darray, col="col", row="row") + assert (2, 3) == g.axs.shape + + g.map_dataarray(xplt.imshow, "x", "y") + + # Rightmost column should be labeled + for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + assert substring_in_axes(label, ax) + + # Top row should be labeled + for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + assert substring_in_axes(label, ax) + + # ensure that row & col labels can be changed + g.set_titles("abc={value}") + for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + assert substring_in_axes(f"abc={label}", ax) + # previous labels were "row=row0" etc. + assert substring_not_in_axes("row=", ax) + + for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + assert substring_in_axes(f"abc={label}", ax) + # previous labels were "col=row0" etc. + assert substring_not_in_axes("col=", ax) + + +@pytest.mark.filterwarnings("ignore:tight_layout cannot") +class TestFacetedLinePlotsLegend(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + self.darray = xr.tutorial.scatter_example_dataset() + + def test_legend_labels(self) -> None: + fg = self.darray.A.plot.line(col="x", row="w", hue="z") + all_legend_labels = [t.get_text() for t in fg.figlegend.texts] + # labels in legend should be ['0', '1', '2', '3'] + assert sorted(all_legend_labels) == ["0", "1", "2", "3"] + + +@pytest.mark.filterwarnings("ignore:tight_layout cannot") +class TestFacetedLinePlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + self.darray = DataArray( + np.random.randn(10, 6, 3, 4), + dims=["hue", "x", "col", "row"], + coords=[range(10), range(6), range(3), ["A", "B", "C", "C++"]], + name="Cornelius Ortega the 1st", + ) + + self.darray.hue.name = "huename" + self.darray.hue.attrs["units"] = "hunits" + self.darray.x.attrs["units"] = "xunits" + self.darray.col.attrs["units"] = "colunits" + self.darray.row.attrs["units"] = "rowunits" + + def test_facetgrid_shape(self) -> None: + g = self.darray.plot(row="row", col="col", hue="hue") + assert g.axs.shape == (len(self.darray.row), len(self.darray.col)) + + g = self.darray.plot(row="col", col="row", hue="hue") + assert g.axs.shape == (len(self.darray.col), len(self.darray.row)) + + def test_unnamed_args(self) -> None: + g = self.darray.plot.line("o--", row="row", col="col", hue="hue") + lines = [ + q for q in g.axs.flat[0].get_children() if isinstance(q, mpl.lines.Line2D) + ] + # passing 'o--' as argument should set marker and linestyle + assert lines[0].get_marker() == "o" + assert lines[0].get_linestyle() == "--" + + def test_default_labels(self) -> None: + g = self.darray.plot(row="row", col="col", hue="hue") + # Rightmost column should be labeled + for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + assert substring_in_axes(label, ax) + + # Top row should be labeled + for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + assert substring_in_axes(str(label), ax) + + # Leftmost column should have array name + for ax in g.axs[:, 0]: + assert substring_in_axes(str(self.darray.name), ax) + + def test_test_empty_cell(self) -> None: + g = ( + self.darray.isel(row=1) + .drop_vars("row") + .plot(col="col", hue="hue", col_wrap=2) + ) + bottomright = g.axs[-1, -1] + assert not bottomright.has_data() + assert not bottomright.get_visible() + + def test_set_axis_labels(self) -> None: + g = self.darray.plot(row="row", col="col", hue="hue") + g.set_axis_labels("longitude", "latitude") + alltxt = text_in_fig() + + assert "longitude" in alltxt + assert "latitude" in alltxt + + def test_axes_in_faceted_plot(self) -> None: + with pytest.raises(ValueError): + self.darray.plot.line(row="row", col="col", x="x", ax=plt.axes()) + + def test_figsize_and_size(self) -> None: + with pytest.raises(ValueError): + self.darray.plot.line(row="row", col="col", x="x", size=3, figsize=(4, 3)) + + def test_wrong_num_of_dimensions(self) -> None: + with pytest.raises(ValueError): + self.darray.plot(row="row", hue="hue") + self.darray.plot.line(row="row", hue="hue") + + +@requires_matplotlib +class TestDatasetQuiverPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + das = [ + DataArray( + np.random.randn(3, 3, 4, 4), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 4, 4]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_quiver(self) -> None: + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.quiver.Quiver) + with pytest.raises(ValueError, match=r"specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u") + + with pytest.raises(ValueError, match=r"hue_style"): + self.ds.isel(row=0, col=0).plot.quiver( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self) -> None: + with figure_context(): + fg = self.ds.plot.quiver( + x="x", y="y", u="u", v="v", row="row", col="col", scale=1, hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.quiver.Quiver) + assert fg.quiverkey is not None + assert "uunits" in fg.quiverkey.text.get_text() + + with figure_context(): + fg = self.ds.plot.quiver( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + scale=1, + hue="mag", + add_guide=False, + ) + assert fg.quiverkey is None + with pytest.raises(ValueError, match=r"Please provide scale"): + self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") + + @pytest.mark.parametrize( + "add_guide, hue_style, legend, colorbar", + [ + (None, None, False, True), + (False, None, False, False), + (True, None, False, True), + (True, "continuous", False, True), + ], + ) + def test_add_guide(self, add_guide, hue_style, legend, colorbar) -> None: + meta_data = _infer_meta_data( + self.ds, + x="x", + y="y", + hue="mag", + hue_style=hue_style, + add_guide=add_guide, + funcname="quiver", + ) + assert meta_data["add_legend"] is legend + assert meta_data["add_colorbar"] is colorbar + + +@requires_matplotlib +class TestDatasetStreamplotPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + das = [ + DataArray( + np.random.randn(3, 3, 2, 2), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 2, 2]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_streamline(self) -> None: + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.collections.LineCollection) + with pytest.raises(ValueError, match=r"specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u") + + with pytest.raises(ValueError, match=r"hue_style"): + self.ds.isel(row=0, col=0).plot.streamplot( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self) -> None: + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", y="y", u="u", v="v", row="row", col="col", hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.collections.LineCollection) + + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + hue="mag", + add_guide=False, + ) + + +@requires_matplotlib +class TestDatasetScatterPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + das = [ + DataArray( + np.random.randn(3, 3, 4, 4), + dims=["x", "row", "col", "hue"], + coords=[range(k) for k in [3, 3, 4, 4]], + ) + for _ in [1, 2] + ] + ds = Dataset({"A": das[0], "B": das[1]}) + ds.hue.name = "huename" + ds.hue.attrs["units"] = "hunits" + ds.x.attrs["units"] = "xunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.A.attrs["units"] = "Aunits" + ds.B.attrs["units"] = "Bunits" + self.ds = ds + + def test_accessor(self) -> None: + from xarray.plot.accessor import DatasetPlotAccessor + + assert Dataset.plot is DatasetPlotAccessor + assert isinstance(self.ds.plot, DatasetPlotAccessor) + + @pytest.mark.parametrize( + "add_guide, hue_style, legend, colorbar", + [ + (None, None, False, True), + (False, None, False, False), + (True, None, False, True), + (True, "continuous", False, True), + (False, "discrete", False, False), + (True, "discrete", True, False), + ], + ) + def test_add_guide( + self, + add_guide: bool | None, + hue_style: Literal["continuous", "discrete", None], + legend: bool, + colorbar: bool, + ) -> None: + meta_data = _infer_meta_data( + self.ds, + x="A", + y="B", + hue="hue", + hue_style=hue_style, + add_guide=add_guide, + funcname="scatter", + ) + assert meta_data["add_legend"] is legend + assert meta_data["add_colorbar"] is colorbar + + def test_facetgrid_shape(self) -> None: + g = self.ds.plot.scatter(x="A", y="B", row="row", col="col") + assert g.axs.shape == (len(self.ds.row), len(self.ds.col)) + + g = self.ds.plot.scatter(x="A", y="B", row="col", col="row") + assert g.axs.shape == (len(self.ds.col), len(self.ds.row)) + + def test_default_labels(self) -> None: + g = self.ds.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") + + # Top row should be labeled + for label, ax in zip(self.ds.coords["col"].values, g.axs[0, :]): + assert substring_in_axes(str(label), ax) + + # Bottom row should have name of x array name and units + for ax in g.axs[-1, :]: + assert ax.get_xlabel() == "A [Aunits]" + + # Leftmost column should have name of y array name and units + for ax in g.axs[:, 0]: + assert ax.get_ylabel() == "B [Bunits]" + + def test_axes_in_faceted_plot(self) -> None: + with pytest.raises(ValueError): + self.ds.plot.scatter(x="A", y="B", row="row", ax=plt.axes()) + + def test_figsize_and_size(self) -> None: + with pytest.raises(ValueError): + self.ds.plot.scatter(x="A", y="B", row="row", size=3, figsize=(4, 3)) + + @pytest.mark.parametrize( + "x, y, hue, add_legend, add_colorbar, error_type", + [ + pytest.param( + "A", "The Spanish Inquisition", None, None, None, KeyError, id="bad_y" + ), + pytest.param( + "The Spanish Inquisition", "B", None, None, True, ValueError, id="bad_x" + ), + ], + ) + def test_bad_args( + self, + x: Hashable, + y: Hashable, + hue: Hashable | None, + add_legend: bool | None, + add_colorbar: bool | None, + error_type: type[Exception], + ) -> None: + with pytest.raises(error_type): + self.ds.plot.scatter( + x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar + ) + + def test_datetime_hue(self) -> None: + ds2 = self.ds.copy() + + # TODO: Currently plots as categorical, should it behave as numerical? + ds2["hue"] = pd.date_range("2000-1-1", periods=4) + ds2.plot.scatter(x="A", y="B", hue="hue") + + ds2["hue"] = pd.timedelta_range("-1D", periods=4, freq="D") + ds2.plot.scatter(x="A", y="B", hue="hue") + + def test_facetgrid_hue_style(self) -> None: + ds2 = self.ds.copy() + + # Numbers plots as continuous: + g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) + + # Datetimes plots as categorical: + # TODO: Currently plots as categorical, should it behave as numerical? + ds2["hue"] = pd.date_range("2000-1-1", periods=4) + g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) + + # Strings plots as categorical: + ds2["hue"] = ["a", "a", "b", "b"] + g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) + + @pytest.mark.parametrize( + ["x", "y", "hue", "markersize"], + [("A", "B", "x", "col"), ("x", "row", "A", "B")], + ) + def test_scatter( + self, x: Hashable, y: Hashable, hue: Hashable, markersize: Hashable + ) -> None: + self.ds.plot.scatter(x=x, y=y, hue=hue, markersize=markersize) + + with pytest.raises(ValueError, match=r"u, v"): + self.ds.plot.scatter(x=x, y=y, u="col", v="row") + + def test_non_numeric_legend(self) -> None: + ds2 = self.ds.copy() + ds2["hue"] = ["a", "b", "c", "d"] + pc = ds2.plot.scatter(x="A", y="B", markersize="hue") + axes = pc.axes + assert axes is not None + # should make a discrete legend + assert hasattr(axes, "legend_") + assert axes.legend_ is not None + + def test_legend_labels(self) -> None: + # regression test for #4126: incorrect legend labels + ds2 = self.ds.copy() + ds2["hue"] = ["a", "a", "b", "b"] + pc = ds2.plot.scatter(x="A", y="B", markersize="hue") + axes = pc.axes + assert axes is not None + actual = [t.get_text() for t in axes.get_legend().texts] + expected = ["hue", "a", "b"] + assert actual == expected + + def test_legend_labels_facetgrid(self) -> None: + ds2 = self.ds.copy() + ds2["hue"] = ["d", "a", "c", "b"] + g = ds2.plot.scatter(x="A", y="B", hue="hue", markersize="x", col="col") + legend = g.figlegend + assert legend is not None + actual = tuple(t.get_text() for t in legend.texts) + expected = ( + "x [xunits]", + "$\\mathdefault{0}$", + "$\\mathdefault{1}$", + "$\\mathdefault{2}$", + ) + assert actual == expected + + def test_add_legend_by_default(self) -> None: + sc = self.ds.plot.scatter(x="A", y="B", hue="hue") + fig = sc.figure + assert fig is not None + assert len(fig.axes) == 2 + + +class TestDatetimePlot(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + """ + Create a DataArray with a time-axis that contains datetime objects. + """ + month = np.arange(1, 13, 1) + data = np.sin(2 * np.pi * month / 12.0) + times = pd.date_range(start="2017-01-01", freq="MS", periods=12) + darray = DataArray(data, dims=["time"], coords=[times]) + + self.darray = darray + + def test_datetime_line_plot(self) -> None: + # test if line plot raises no Exception + self.darray.plot.line() + + def test_datetime_units(self) -> None: + # test that matplotlib-native datetime works: + fig, ax = plt.subplots() + ax.plot(self.darray["time"], self.darray) + + # Make sure only mpl converters are used, use type() so only + # mpl.dates.AutoDateLocator passes and no other subclasses: + assert type(ax.xaxis.get_major_locator()) is mpl.dates.AutoDateLocator + + def test_datetime_plot1d(self) -> None: + # Test that matplotlib-native datetime works: + p = self.darray.plot.line() + ax = p[0].axes + + # Make sure only mpl converters are used, use type() so only + # mpl.dates.AutoDateLocator passes and no other subclasses: + assert type(ax.xaxis.get_major_locator()) is mpl.dates.AutoDateLocator + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_datetime_plot2d(self) -> None: + # Test that matplotlib-native datetime works: + da = DataArray( + np.arange(3 * 4).reshape(3, 4), + dims=("x", "y"), + coords={ + "x": [1, 2, 3], + "y": [np.datetime64(f"2000-01-{x:02d}") for x in range(1, 5)], + }, + ) + + p = da.plot.pcolormesh() + ax = p.axes + assert ax is not None + + # Make sure only mpl converters are used, use type() so only + # mpl.dates.AutoDateLocator passes and no other subclasses: + assert type(ax.xaxis.get_major_locator()) is mpl.dates.AutoDateLocator + + +@pytest.mark.filterwarnings("ignore:setting an array element with a sequence") +@requires_cftime +@pytest.mark.skipif(not has_nc_time_axis, reason="nc_time_axis is not installed") +class TestCFDatetimePlot(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + """ + Create a DataArray with a time-axis that contains cftime.datetime + objects. + """ + # case for 1d array + data = np.random.rand(4, 12) + time = xr.cftime_range(start="2017", periods=12, freq="1ME", calendar="noleap") + darray = DataArray(data, dims=["x", "time"]) + darray.coords["time"] = time + + self.darray = darray + + def test_cfdatetime_line_plot(self) -> None: + self.darray.isel(x=0).plot.line() + + def test_cfdatetime_pcolormesh_plot(self) -> None: + self.darray.plot.pcolormesh() + + def test_cfdatetime_contour_plot(self) -> None: + self.darray.plot.contour() + + +@requires_cftime +@pytest.mark.skipif(has_nc_time_axis, reason="nc_time_axis is installed") +class TestNcAxisNotInstalled(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self) -> None: + """ + Create a DataArray with a time-axis that contains cftime.datetime + objects. + """ + month = np.arange(1, 13, 1) + data = np.sin(2 * np.pi * month / 12.0) + darray = DataArray(data, dims=["time"]) + darray.coords["time"] = xr.cftime_range( + start="2017", periods=12, freq="1ME", calendar="noleap" + ) + + self.darray = darray + + def test_ncaxis_notinstalled_line_plot(self) -> None: + with pytest.raises(ImportError, match=r"optional `nc-time-axis`"): + self.darray.plot.line() + + +@requires_matplotlib +class TestAxesKwargs: + @pytest.fixture(params=[1, 2, 3]) + def data_array(self, request) -> DataArray: + """ + Return a simple DataArray + """ + dims = request.param + if dims == 1: + return DataArray(easy_array((10,))) + elif dims == 2: + return DataArray(easy_array((10, 3))) + elif dims == 3: + return DataArray(easy_array((10, 3, 2))) + else: + raise ValueError(f"No DataArray implemented for {dims=}.") + + @pytest.fixture(params=[1, 2]) + def data_array_logspaced(self, request) -> DataArray: + """ + Return a simple DataArray with logspaced coordinates + """ + dims = request.param + if dims == 1: + return DataArray( + np.arange(7), dims=("x",), coords={"x": np.logspace(-3, 3, 7)} + ) + elif dims == 2: + return DataArray( + np.arange(16).reshape(4, 4), + dims=("y", "x"), + coords={"x": np.logspace(-1, 2, 4), "y": np.logspace(-5, -1, 4)}, + ) + else: + raise ValueError(f"No DataArray implemented for {dims=}.") + + @pytest.mark.parametrize("xincrease", [True, False]) + def test_xincrease_kwarg(self, data_array, xincrease) -> None: + with figure_context(): + data_array.plot(xincrease=xincrease) + assert plt.gca().xaxis_inverted() == (not xincrease) + + @pytest.mark.parametrize("yincrease", [True, False]) + def test_yincrease_kwarg(self, data_array, yincrease) -> None: + with figure_context(): + data_array.plot(yincrease=yincrease) + assert plt.gca().yaxis_inverted() == (not yincrease) + + @pytest.mark.parametrize("xscale", ["linear", "logit", "symlog"]) + def test_xscale_kwarg(self, data_array, xscale) -> None: + with figure_context(): + data_array.plot(xscale=xscale) + assert plt.gca().get_xscale() == xscale + + @pytest.mark.parametrize("yscale", ["linear", "logit", "symlog"]) + def test_yscale_kwarg(self, data_array, yscale) -> None: + with figure_context(): + data_array.plot(yscale=yscale) + assert plt.gca().get_yscale() == yscale + + def test_xscale_log_kwarg(self, data_array_logspaced) -> None: + xscale = "log" + with figure_context(): + data_array_logspaced.plot(xscale=xscale) + assert plt.gca().get_xscale() == xscale + + def test_yscale_log_kwarg(self, data_array_logspaced) -> None: + yscale = "log" + with figure_context(): + data_array_logspaced.plot(yscale=yscale) + assert plt.gca().get_yscale() == yscale + + def test_xlim_kwarg(self, data_array) -> None: + with figure_context(): + expected = (0.0, 1000.0) + data_array.plot(xlim=[0, 1000]) + assert plt.gca().get_xlim() == expected + + def test_ylim_kwarg(self, data_array) -> None: + with figure_context(): + data_array.plot(ylim=[0, 1000]) + expected = (0.0, 1000.0) + assert plt.gca().get_ylim() == expected + + def test_xticks_kwarg(self, data_array) -> None: + with figure_context(): + data_array.plot(xticks=np.arange(5)) + expected = np.arange(5).tolist() + assert_array_equal(plt.gca().get_xticks(), expected) + + def test_yticks_kwarg(self, data_array) -> None: + with figure_context(): + data_array.plot(yticks=np.arange(5)) + expected = np.arange(5) + assert_array_equal(plt.gca().get_yticks(), expected) + + +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"]) +def test_plot_transposed_nondim_coord(plotfunc) -> None: + x = np.linspace(0, 10, 101) + h = np.linspace(3, 7, 101) + s = np.linspace(0, 1, 51) + z = s[:, np.newaxis] * h[np.newaxis, :] + da = xr.DataArray( + np.sin(x) * np.cos(z), + dims=["s", "x"], + coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)}, + ) + with figure_context(): + getattr(da.plot, plotfunc)(x="x", y="zt") + with figure_context(): + getattr(da.plot, plotfunc)(x="zt", y="x") + + +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "imshow"]) +def test_plot_transposes_properly(plotfunc) -> None: + # test that we aren't mistakenly transposing when the 2 dimensions have equal sizes. + da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x")) + with figure_context(): + hdl = getattr(da.plot, plotfunc)(x="x", y="y") + # get_array doesn't work for contour, contourf. It returns the colormap intervals. + # pcolormesh returns 1D array but imshow returns a 2D array so it is necessary + # to ravel() on the LHS + assert_array_equal(hdl.get_array().ravel(), da.to_masked_array().ravel()) + + +@requires_matplotlib +def test_facetgrid_single_contour() -> None: + # regression test for GH3569 + x, y = np.meshgrid(np.arange(12), np.arange(12)) + z = xr.DataArray(np.sqrt(x**2 + y**2)) + z2 = xr.DataArray(np.sqrt(x**2 + y**2) + 1) + ds = xr.concat([z, z2], dim="time") + ds["time"] = [0, 1] + + with figure_context(): + ds.plot.contour(col="time", levels=[4], colors=["k"]) + + +@requires_matplotlib +def test_get_axis_raises() -> None: + # test get_axis raises an error if trying to do invalid things + + # cannot provide both ax and figsize + with pytest.raises(ValueError, match="both `figsize` and `ax`"): + get_axis(figsize=[4, 4], size=None, aspect=None, ax="something") # type: ignore[arg-type] + + # cannot provide both ax and size + with pytest.raises(ValueError, match="both `size` and `ax`"): + get_axis(figsize=None, size=200, aspect=4 / 3, ax="something") # type: ignore[arg-type] + + # cannot provide both size and figsize + with pytest.raises(ValueError, match="both `figsize` and `size`"): + get_axis(figsize=[4, 4], size=200, aspect=None, ax=None) + + # cannot provide aspect and size + with pytest.raises(ValueError, match="`aspect` argument without `size`"): + get_axis(figsize=None, size=None, aspect=4 / 3, ax=None) + + # cannot provide axis and subplot_kws + with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"): + get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5) # type: ignore[arg-type] + + +@requires_matplotlib +@pytest.mark.parametrize( + ["figsize", "size", "aspect", "ax", "kwargs"], + [ + pytest.param((3, 2), None, None, False, {}, id="figsize"), + pytest.param( + (3.5, 2.5), None, None, False, {"label": "test"}, id="figsize_kwargs" + ), + pytest.param(None, 5, None, False, {}, id="size"), + pytest.param(None, 5.5, None, False, {"label": "test"}, id="size_kwargs"), + pytest.param(None, 5, 1, False, {}, id="size+aspect"), + pytest.param(None, 5, "auto", False, {}, id="auto_aspect"), + pytest.param(None, 5, "equal", False, {}, id="equal_aspect"), + pytest.param(None, None, None, True, {}, id="ax"), + pytest.param(None, None, None, False, {}, id="default"), + pytest.param(None, None, None, False, {"label": "test"}, id="default_kwargs"), + ], +) +def test_get_axis( + figsize: tuple[float, float] | None, + size: float | None, + aspect: float | None, + ax: bool, + kwargs: dict[str, Any], +) -> None: + with figure_context(): + inp_ax = plt.axes() if ax else None + out_ax = get_axis( + figsize=figsize, size=size, aspect=aspect, ax=inp_ax, **kwargs + ) + assert isinstance(out_ax, mpl.axes.Axes) + + +@requires_matplotlib +@requires_cartopy +@pytest.mark.parametrize( + ["figsize", "size", "aspect"], + [ + pytest.param((3, 2), None, None, id="figsize"), + pytest.param(None, 5, None, id="size"), + pytest.param(None, 5, 1, id="size+aspect"), + pytest.param(None, None, None, id="default"), + ], +) +def test_get_axis_cartopy( + figsize: tuple[float, float] | None, size: float | None, aspect: float | None +) -> None: + kwargs = {"projection": cartopy.crs.PlateCarree()} + with figure_context(): + out_ax = get_axis(figsize=figsize, size=size, aspect=aspect, **kwargs) + assert isinstance(out_ax, cartopy.mpl.geoaxes.GeoAxesSubplot) + + +@requires_matplotlib +def test_get_axis_current() -> None: + with figure_context(): + _, ax = plt.subplots() + out_ax = get_axis() + assert ax is out_ax + + +@requires_matplotlib +def test_maybe_gca() -> None: + with figure_context(): + ax = _maybe_gca(aspect=1) + + assert isinstance(ax, mpl.axes.Axes) + assert ax.get_aspect() == 1 + + with figure_context(): + # create figure without axes + plt.figure() + ax = _maybe_gca(aspect=1) + + assert isinstance(ax, mpl.axes.Axes) + assert ax.get_aspect() == 1 + + with figure_context(): + existing_axes = plt.axes() + ax = _maybe_gca(aspect=1) + + # re-uses the existing axes + assert existing_axes == ax + # kwargs are ignored when reusing axes + assert ax.get_aspect() == "auto" + + +@requires_matplotlib +@pytest.mark.parametrize( + "x, y, z, hue, markersize, row, col, add_legend, add_colorbar", + [ + ("A", "B", None, None, None, None, None, None, None), + ("B", "A", None, "w", None, None, None, True, None), + ("A", "B", None, "y", "x", None, None, True, True), + ("A", "B", "z", None, None, None, None, None, None), + ("B", "A", "z", "w", None, None, None, True, None), + ("A", "B", "z", "y", "x", None, None, True, True), + ("A", "B", "z", "y", "x", "w", None, True, True), + ], +) +def test_datarray_scatter( + x, y, z, hue, markersize, row, col, add_legend, add_colorbar +) -> None: + """Test datarray scatter. Merge with TestPlot1D eventually.""" + ds = xr.tutorial.scatter_example_dataset() + + extra_coords = [v for v in [x, hue, markersize] if v is not None] + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray: + coords.update({v: ds[v] for v in extra_coords}) + + darray = xr.DataArray(ds[y], coords=coords) + + with figure_context(): + darray.plot.scatter( + x=x, + z=z, + hue=hue, + markersize=markersize, + add_legend=add_legend, + add_colorbar=add_colorbar, + ) + + +@requires_matplotlib +def test_assert_valid_xy() -> None: + ds = xr.tutorial.scatter_example_dataset() + darray = ds.A + + # x is valid and should not error: + _assert_valid_xy(darray=darray, xy="x", name="x") + + # None should be valid as well even though it isn't in the valid list: + _assert_valid_xy(darray=darray, xy=None, name="x") + + # A hashable that is not valid should error: + with pytest.raises(ValueError, match="x must be one of"): + _assert_valid_xy(darray=darray, xy="error_now", name="x") + + +@requires_matplotlib +@pytest.mark.parametrize( + "val", [pytest.param([], id="empty"), pytest.param(0, id="scalar")] +) +@pytest.mark.parametrize( + "method", + [ + "__call__", + "line", + "step", + "contour", + "contourf", + "hist", + "imshow", + "pcolormesh", + "scatter", + "surface", + ], +) +def test_plot_empty_raises(val: list | float, method: str) -> None: + da = xr.DataArray(val) + with pytest.raises(TypeError, match="No numeric data"): + getattr(da.plot, method)() + + +@requires_matplotlib +def test_facetgrid_axes_raises_deprecation_warning() -> None: + with pytest.warns( + DeprecationWarning, + match=( + "self.axes is deprecated since 2022.11 in order to align with " + "matplotlibs plt.subplots, use self.axs instead." + ), + ): + with figure_context(): + ds = xr.tutorial.scatter_example_dataset() + g = ds.plot.scatter(x="A", y="B", col="x") + g.axes + + +@requires_matplotlib +def test_plot1d_default_rcparams() -> None: + import matplotlib as mpl + + ds = xr.tutorial.scatter_example_dataset(seed=42) + + with figure_context(): + # scatter markers should by default have white edgecolor to better + # see overlapping markers: + fig, ax = plt.subplots(1, 1) + ds.plot.scatter(x="A", y="B", marker="o", ax=ax) + np.testing.assert_allclose( + ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") + ) + + # Facetgrids should have the default value as well: + fg = ds.plot.scatter(x="A", y="B", col="x", marker="o") + ax = fg.axs.ravel()[0] + np.testing.assert_allclose( + ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("w") + ) + + # scatter should not emit any warnings when using unfilled markers: + with assert_no_warnings(): + fig, ax = plt.subplots(1, 1) + ds.plot.scatter(x="A", y="B", ax=ax, marker="x") + + # Prioritize edgecolor argument over default plot1d values: + fig, ax = plt.subplots(1, 1) + ds.plot.scatter(x="A", y="B", marker="o", ax=ax, edgecolor="k") + np.testing.assert_allclose( + ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k") + ) + + +@requires_matplotlib +def test_plot1d_filtered_nulls() -> None: + ds = xr.tutorial.scatter_example_dataset(seed=42) + y = ds.y.where(ds.y > 0.2) + expected = y.notnull().sum().item() + + with figure_context(): + pc = y.plot.scatter() + actual = pc.get_offsets().shape[0] + + assert expected == actual diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_plugins.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_plugins.py new file mode 100644 index 0000000..46051f0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_plugins.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import sys +from importlib.metadata import EntryPoint + +if sys.version_info >= (3, 10): + from importlib.metadata import EntryPoints +else: + EntryPoints = list[EntryPoint] +from unittest import mock + +import pytest + +from xarray.backends import common, plugins +from xarray.tests import ( + has_h5netcdf, + has_netCDF4, + has_pydap, + has_scipy, + has_zarr, +) + +# Do not import list_engines here, this will break the lazy tests + +importlib_metadata_mock = "importlib.metadata" + + +class DummyBackendEntrypointArgs(common.BackendEntrypoint): + def open_dataset(filename_or_obj, *args): + pass + + +class DummyBackendEntrypointKwargs(common.BackendEntrypoint): + def open_dataset(filename_or_obj, **kwargs): + pass + + +class DummyBackendEntrypoint1(common.BackendEntrypoint): + def open_dataset(self, filename_or_obj, *, decoder): + pass + + +class DummyBackendEntrypoint2(common.BackendEntrypoint): + def open_dataset(self, filename_or_obj, *, decoder): + pass + + +@pytest.fixture +def dummy_duplicated_entrypoints(): + specs = [ + ["engine1", "xarray.tests.test_plugins:backend_1", "xarray.backends"], + ["engine1", "xarray.tests.test_plugins:backend_2", "xarray.backends"], + ["engine2", "xarray.tests.test_plugins:backend_1", "xarray.backends"], + ["engine2", "xarray.tests.test_plugins:backend_2", "xarray.backends"], + ] + eps = [EntryPoint(name, value, group) for name, value, group in specs] + return eps + + +@pytest.mark.filterwarnings("ignore:Found") +def test_remove_duplicates(dummy_duplicated_entrypoints) -> None: + with pytest.warns(RuntimeWarning): + entrypoints = plugins.remove_duplicates(dummy_duplicated_entrypoints) + assert len(entrypoints) == 2 + + +def test_broken_plugin() -> None: + broken_backend = EntryPoint( + "broken_backend", + "xarray.tests.test_plugins:backend_1", + "xarray.backends", + ) + with pytest.warns(RuntimeWarning) as record: + _ = plugins.build_engines(EntryPoints([broken_backend])) + assert len(record) == 1 + message = str(record[0].message) + assert "Engine 'broken_backend'" in message + + +def test_remove_duplicates_warnings(dummy_duplicated_entrypoints) -> None: + with pytest.warns(RuntimeWarning) as record: + _ = plugins.remove_duplicates(dummy_duplicated_entrypoints) + + assert len(record) == 2 + message0 = str(record[0].message) + message1 = str(record[1].message) + assert "entrypoints" in message0 + assert "entrypoints" in message1 + + +@mock.patch( + f"{importlib_metadata_mock}.EntryPoint.load", mock.MagicMock(return_value=None) +) +def test_backends_dict_from_pkg() -> None: + specs = [ + ["engine1", "xarray.tests.test_plugins:backend_1", "xarray.backends"], + ["engine2", "xarray.tests.test_plugins:backend_2", "xarray.backends"], + ] + entrypoints = [EntryPoint(name, value, group) for name, value, group in specs] + engines = plugins.backends_dict_from_pkg(entrypoints) + assert len(engines) == 2 + assert engines.keys() == {"engine1", "engine2"} + + +def test_set_missing_parameters() -> None: + backend_1 = DummyBackendEntrypoint1 + backend_2 = DummyBackendEntrypoint2 + backend_2.open_dataset_parameters = ("filename_or_obj",) + engines = {"engine_1": backend_1, "engine_2": backend_2} + plugins.set_missing_parameters(engines) + + assert len(engines) == 2 + assert backend_1.open_dataset_parameters == ("filename_or_obj", "decoder") + assert backend_2.open_dataset_parameters == ("filename_or_obj",) + + backend_kwargs = DummyBackendEntrypointKwargs + backend_kwargs.open_dataset_parameters = ("filename_or_obj", "decoder") + plugins.set_missing_parameters({"engine": backend_kwargs}) + assert backend_kwargs.open_dataset_parameters == ("filename_or_obj", "decoder") + + backend_args = DummyBackendEntrypointArgs + backend_args.open_dataset_parameters = ("filename_or_obj", "decoder") + plugins.set_missing_parameters({"engine": backend_args}) + assert backend_args.open_dataset_parameters == ("filename_or_obj", "decoder") + + # reset + backend_1.open_dataset_parameters = None + backend_1.open_dataset_parameters = None + backend_kwargs.open_dataset_parameters = None + backend_args.open_dataset_parameters = None + + +def test_set_missing_parameters_raise_error() -> None: + backend = DummyBackendEntrypointKwargs + with pytest.raises(TypeError): + plugins.set_missing_parameters({"engine": backend}) + + backend_args = DummyBackendEntrypointArgs + with pytest.raises(TypeError): + plugins.set_missing_parameters({"engine": backend_args}) + + +@mock.patch( + f"{importlib_metadata_mock}.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) +def test_build_engines() -> None: + dummy_pkg_entrypoint = EntryPoint( + "dummy", "xarray.tests.test_plugins:backend_1", "xarray_backends" + ) + backend_entrypoints = plugins.build_engines(EntryPoints([dummy_pkg_entrypoint])) + + assert isinstance(backend_entrypoints["dummy"], DummyBackendEntrypoint1) + assert backend_entrypoints["dummy"].open_dataset_parameters == ( + "filename_or_obj", + "decoder", + ) + + +@mock.patch( + f"{importlib_metadata_mock}.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) +def test_build_engines_sorted() -> None: + dummy_pkg_entrypoints = EntryPoints( + [ + EntryPoint( + "dummy2", "xarray.tests.test_plugins:backend_1", "xarray.backends" + ), + EntryPoint( + "dummy1", "xarray.tests.test_plugins:backend_1", "xarray.backends" + ), + ] + ) + backend_entrypoints = list(plugins.build_engines(dummy_pkg_entrypoints)) + + indices = [] + for be in plugins.STANDARD_BACKENDS_ORDER: + try: + index = backend_entrypoints.index(be) + backend_entrypoints.pop(index) + indices.append(index) + except ValueError: + pass + + assert set(indices) < {0, -1} + assert list(backend_entrypoints) == sorted(backend_entrypoints) + + +@mock.patch( + "xarray.backends.plugins.list_engines", + mock.MagicMock(return_value={"dummy": DummyBackendEntrypointArgs()}), +) +def test_no_matching_engine_found() -> None: + with pytest.raises(ValueError, match=r"did not find a match in any"): + plugins.guess_engine("not-valid") + + with pytest.raises(ValueError, match=r"found the following matches with the input"): + plugins.guess_engine("foo.nc") + + +@mock.patch( + "xarray.backends.plugins.list_engines", + mock.MagicMock(return_value={}), +) +def test_engines_not_installed() -> None: + with pytest.raises(ValueError, match=r"xarray is unable to open"): + plugins.guess_engine("not-valid") + + with pytest.raises(ValueError, match=r"found the following matches with the input"): + plugins.guess_engine("foo.nc") + + +def test_lazy_import() -> None: + """Test that some modules are imported in a lazy manner. + + When importing xarray these should not be imported as well. + Only when running code for the first time that requires them. + """ + deny_list = [ + "cubed", + "cupy", + # "dask", # TODO: backends.locks is not lazy yet :( + "dask.array", + "dask.distributed", + "flox", + "h5netcdf", + "matplotlib", + "nc_time_axis", + "netCDF4", + "numbagg", + "pint", + "pydap", + "scipy", + "sparse", + "zarr", + ] + # ensure that none of the above modules has been imported before + modules_backup = {} + for pkg in list(sys.modules.keys()): + for mod in deny_list + ["xarray"]: + if pkg.startswith(mod): + modules_backup[pkg] = sys.modules[pkg] + del sys.modules[pkg] + break + + try: + import xarray # noqa: F401 + from xarray.backends import list_engines + + list_engines() + + # ensure that none of the modules that are supposed to be + # lazy loaded are loaded when importing xarray + is_imported = set() + for pkg in sys.modules: + for mod in deny_list: + if pkg.startswith(mod): + is_imported.add(mod) + break + assert ( + len(is_imported) == 0 + ), f"{is_imported} have been imported but should be lazy" + + finally: + # restore original + sys.modules.update(modules_backup) + + +def test_list_engines() -> None: + from xarray.backends import list_engines + + engines = list_engines() + assert list_engines.cache_info().currsize == 1 + + assert ("scipy" in engines) == has_scipy + assert ("h5netcdf" in engines) == has_h5netcdf + assert ("netcdf4" in engines) == has_netCDF4 + assert ("pydap" in engines) == has_pydap + assert ("zarr" in engines) == has_zarr + assert "store" in engines + + +def test_refresh_engines() -> None: + from xarray.backends import list_engines, refresh_engines + + EntryPointMock1 = mock.MagicMock() + EntryPointMock1.name = "test1" + EntryPointMock1.load.return_value = DummyBackendEntrypoint1 + + if sys.version_info >= (3, 10): + return_value = EntryPoints([EntryPointMock1]) + else: + return_value = {"xarray.backends": [EntryPointMock1]} + + with mock.patch("xarray.backends.plugins.entry_points", return_value=return_value): + list_engines.cache_clear() + engines = list_engines() + assert "test1" in engines + assert isinstance(engines["test1"], DummyBackendEntrypoint1) + + EntryPointMock2 = mock.MagicMock() + EntryPointMock2.name = "test2" + EntryPointMock2.load.return_value = DummyBackendEntrypoint2 + + if sys.version_info >= (3, 10): + return_value2 = EntryPoints([EntryPointMock2]) + else: + return_value2 = {"xarray.backends": [EntryPointMock2]} + + with mock.patch("xarray.backends.plugins.entry_points", return_value=return_value2): + refresh_engines() + engines = list_engines() + assert "test1" not in engines + assert "test2" in engines + assert isinstance(engines["test2"], DummyBackendEntrypoint2) + + # reset to original + refresh_engines() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_print_versions.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_print_versions.py new file mode 100644 index 0000000..f964eb8 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_print_versions.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import io + +import xarray + + +def test_show_versions() -> None: + f = io.StringIO() + xarray.show_versions(file=f) + assert "INSTALLED VERSIONS" in f.getvalue() diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_rolling.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_rolling.py new file mode 100644 index 0000000..89f6ebb --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_rolling.py @@ -0,0 +1,887 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Dataset, set_options +from xarray.tests import ( + assert_allclose, + assert_equal, + assert_identical, + has_dask, + requires_dask, + requires_numbagg, +) + +pytestmark = [ + pytest.mark.filterwarnings("error:Mean of empty slice"), + pytest.mark.filterwarnings("error:All-NaN (slice|axis) encountered"), +] + + +@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("min_periods", [1, 10]) +def test_cumulative(d, func, min_periods) -> None: + # One dim + result = getattr(d.cumulative("z", min_periods=min_periods), func)() + expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)() + assert_identical(result, expected) + + # Multiple dim + result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)() + expected = getattr( + d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods), + func, + )() + assert_identical(result, expected) + + +def test_cumulative_vs_cum(d) -> None: + result = d.cumulative("z").sum() + expected = d.cumsum("z") + # cumsum drops the coord of the dimension; cumulative doesn't + expected = expected.assign_coords(z=result["z"]) + assert_identical(result, expected) + + +class TestDataArrayRolling: + @pytest.mark.parametrize("da", (1, 2), indirect=True) + @pytest.mark.parametrize("center", [True, False]) + @pytest.mark.parametrize("size", [1, 2, 3, 7]) + def test_rolling_iter(self, da: DataArray, center: bool, size: int) -> None: + rolling_obj = da.rolling(time=size, center=center) + rolling_obj_mean = rolling_obj.mean() + + assert len(rolling_obj.window_labels) == len(da["time"]) + assert_identical(rolling_obj.window_labels, da["time"]) + + for i, (label, window_da) in enumerate(rolling_obj): + assert label == da["time"].isel(time=i) + + actual = rolling_obj_mean.isel(time=i) + expected = window_da.mean("time") + + np.testing.assert_allclose(actual.values, expected.values) + + @pytest.mark.parametrize("da", (1,), indirect=True) + def test_rolling_repr(self, da) -> None: + rolling_obj = da.rolling(time=7) + assert repr(rolling_obj) == "DataArrayRolling [time->7]" + rolling_obj = da.rolling(time=7, center=True) + assert repr(rolling_obj) == "DataArrayRolling [time->7(center)]" + rolling_obj = da.rolling(time=7, x=3, center=True) + assert repr(rolling_obj) == "DataArrayRolling [time->7(center),x->3(center)]" + + @requires_dask + def test_repeated_rolling_rechunks(self) -> None: + # regression test for GH3277, GH2514 + dat = DataArray(np.random.rand(7653, 300), dims=("day", "item")) + dat_chunk = dat.chunk({"item": 20}) + dat_chunk.rolling(day=10).mean().rolling(day=250).std() + + def test_rolling_doc(self, da) -> None: + rolling_obj = da.rolling(time=7) + + # argument substitution worked + assert "`mean`" in rolling_obj.mean.__doc__ + + def test_rolling_properties(self, da) -> None: + rolling_obj = da.rolling(time=4) + + assert rolling_obj.obj.get_axis_num("time") == 1 + + # catching invalid args + with pytest.raises(ValueError, match="window must be > 0"): + da.rolling(time=-2) + + with pytest.raises(ValueError, match="min_periods must be greater than zero"): + da.rolling(time=2, min_periods=0) + + with pytest.raises( + KeyError, + match=r"\('foo',\) not found in DataArray dimensions", + ): + da.rolling(foo=2) + + @pytest.mark.parametrize( + "name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax") + ) + @pytest.mark.parametrize("center", (True, False, None)) + @pytest.mark.parametrize("min_periods", (1, None)) + @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + def test_rolling_wrapped_bottleneck( + self, da, name, center, min_periods, compute_backend + ) -> None: + bn = pytest.importorskip("bottleneck", minversion="1.1") + # Test all bottleneck functions + rolling_obj = da.rolling(time=7, min_periods=min_periods) + + func_name = f"move_{name}" + actual = getattr(rolling_obj, name)() + window = 7 + expected = getattr(bn, func_name)( + da.values, window=window, axis=1, min_count=min_periods + ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func_name in ["move_argmin", "move_argmax"]: + expected = window - 1 - expected + + # Using assert_allclose because we get tiny (1e-17) differences in numbagg. + np.testing.assert_allclose(actual.values, expected) + + with pytest.warns(DeprecationWarning, match="Reductions are applied"): + getattr(rolling_obj, name)(dim="time") + + # Test center + rolling_obj = da.rolling(time=7, center=center) + actual = getattr(rolling_obj, name)()["time"] + # Using assert_allclose because we get tiny (1e-17) differences in numbagg. + assert_allclose(actual, da["time"]) + + @requires_dask + @pytest.mark.parametrize("name", ("mean", "count")) + @pytest.mark.parametrize("center", (True, False, None)) + @pytest.mark.parametrize("min_periods", (1, None)) + @pytest.mark.parametrize("window", (7, 8)) + @pytest.mark.parametrize("backend", ["dask"], indirect=True) + def test_rolling_wrapped_dask(self, da, name, center, min_periods, window) -> None: + # dask version + rolling_obj = da.rolling(time=window, min_periods=min_periods, center=center) + actual = getattr(rolling_obj, name)().load() + if name != "count": + with pytest.warns(DeprecationWarning, match="Reductions are applied"): + getattr(rolling_obj, name)(dim="time") + # numpy version + rolling_obj = da.load().rolling( + time=window, min_periods=min_periods, center=center + ) + expected = getattr(rolling_obj, name)() + + # using all-close because rolling over ghost cells introduces some + # precision errors + assert_allclose(actual, expected) + + # with zero chunked array GH:2113 + rolling_obj = da.chunk().rolling( + time=window, min_periods=min_periods, center=center + ) + actual = getattr(rolling_obj, name)().load() + assert_allclose(actual, expected) + + @pytest.mark.parametrize("center", (True, None)) + def test_rolling_wrapped_dask_nochunk(self, center) -> None: + # GH:2113 + pytest.importorskip("dask.array") + + da_day_clim = xr.DataArray( + np.arange(1, 367), coords=[np.arange(1, 367)], dims="dayofyear" + ) + expected = da_day_clim.rolling(dayofyear=31, center=center).mean() + actual = da_day_clim.chunk().rolling(dayofyear=31, center=center).mean() + assert_allclose(actual, expected) + + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_pandas_compat( + self, center, window, min_periods, compute_backend + ) -> None: + s = pd.Series(np.arange(10)) + da = DataArray.from_series(s) + + if min_periods is not None and window < min_periods: + min_periods = window + + s_rolling = s.rolling(window, center=center, min_periods=min_periods).mean() + da_rolling = da.rolling( + index=window, center=center, min_periods=min_periods + ).mean() + da_rolling_np = da.rolling( + index=window, center=center, min_periods=min_periods + ).reduce(np.nanmean) + + np.testing.assert_allclose(s_rolling.values, da_rolling.values) + np.testing.assert_allclose(s_rolling.index, da_rolling["index"]) + np.testing.assert_allclose(s_rolling.values, da_rolling_np.values) + np.testing.assert_allclose(s_rolling.index, da_rolling_np["index"]) + + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct(self, center: bool, window: int) -> None: + s = pd.Series(np.arange(10)) + da = DataArray.from_series(s) + + s_rolling = s.rolling(window, center=center, min_periods=1).mean() + da_rolling = da.rolling(index=window, center=center, min_periods=1) + + da_rolling_mean = da_rolling.construct("window").mean("window") + np.testing.assert_allclose(s_rolling.values, da_rolling_mean.values) + np.testing.assert_allclose(s_rolling.index, da_rolling_mean["index"]) + + # with stride + da_rolling_mean = da_rolling.construct("window", stride=2).mean("window") + np.testing.assert_allclose(s_rolling.values[::2], da_rolling_mean.values) + np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean["index"]) + + # with fill_value + da_rolling_mean = da_rolling.construct("window", stride=2, fill_value=0.0).mean( + "window" + ) + assert da_rolling_mean.isnull().sum() == 0 + assert (da_rolling_mean == 0.0).sum() >= 0 + + @pytest.mark.parametrize("da", (1, 2), indirect=True) + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + @pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) + def test_rolling_reduce( + self, da, center, min_periods, window, name, compute_backend + ) -> None: + if min_periods is not None and window < min_periods: + min_periods = window + + if da.isnull().sum() > 1 and window == 1: + # this causes all nan slices + window = 2 + + rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods) + + # add nan prefix to numpy methods to get similar # behavior as bottleneck + actual = rolling_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(rolling_obj, name)() + assert_allclose(actual, expected) + assert actual.sizes == expected.sizes + + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + @pytest.mark.parametrize("name", ("sum", "max")) + def test_rolling_reduce_nonnumeric( + self, center, min_periods, window, name, compute_backend + ) -> None: + da = DataArray( + [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" + ).isnull() + + if min_periods is not None and window < min_periods: + min_periods = window + + rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods) + + # add nan prefix to numpy methods to get similar behavior as bottleneck + actual = rolling_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(rolling_obj, name)() + assert_allclose(actual, expected) + assert actual.sizes == expected.sizes + + def test_rolling_count_correct(self, compute_backend) -> None: + da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") + + kwargs: list[dict[str, Any]] = [ + {"time": 11, "min_periods": 1}, + {"time": 11, "min_periods": None}, + {"time": 7, "min_periods": 2}, + ] + expecteds = [ + DataArray([1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims="time"), + DataArray( + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + ], + dims="time", + ), + DataArray([np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims="time"), + ] + + for kwarg, expected in zip(kwargs, expecteds): + result = da.rolling(**kwarg).count() + assert_equal(result, expected) + + result = da.to_dataset(name="var1").rolling(**kwarg).count()["var1"] + assert_equal(result, expected) + + @pytest.mark.parametrize("da", (1,), indirect=True) + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1)) + @pytest.mark.parametrize("name", ("sum", "mean", "max")) + def test_ndrolling_reduce( + self, da, center, min_periods, name, compute_backend + ) -> None: + rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + da.rolling(time=3, center=center, min_periods=min_periods), name + )().rolling(x=2, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.sizes == expected.sizes + + if name in ["mean"]: + # test our reimplementation of nanmean using np.nanmean + expected = getattr(rolling_obj.construct({"time": "tw", "x": "xw"}), name)( + ["tw", "xw"] + ) + count = rolling_obj.count() + if min_periods is None: + min_periods = 1 + assert_allclose(actual, expected.where(count >= min_periods)) + + @pytest.mark.parametrize("center", (True, False, (True, False))) + @pytest.mark.parametrize("fill_value", (np.nan, 0.0)) + def test_ndrolling_construct(self, center, fill_value) -> None: + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + actual = da.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + da.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + @pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ("construct", ("window_dim",)), + ("count", ()), + ], + ) + def test_rolling_keep_attrs(self, funcname, argument) -> None: + attrs_da = {"da_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + da = DataArray( + data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da, name="name" + ) + + # attrs are now kept per default + func = getattr(da.rolling(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == attrs_da + assert result.name == "name" + + # discard attrs + func = getattr(da.rolling(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.name == "name" + + # test discard attrs using global option + func = getattr(da.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + assert result.attrs == {} + assert result.name == "name" + + # keyword takes precedence over global option + func = getattr(da.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + assert result.attrs == attrs_da + assert result.name == "name" + + func = getattr(da.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.name == "name" + + @requires_dask + @pytest.mark.parametrize("dtype", ["int", "float32", "float64"]) + def test_rolling_dask_dtype(self, dtype) -> None: + data = DataArray( + np.array([1, 2, 3], dtype=dtype), dims="x", coords={"x": [1, 2, 3]} + ) + unchunked_result = data.rolling(x=3, min_periods=1).mean() + chunked_result = data.chunk({"x": 1}).rolling(x=3, min_periods=1).mean() + assert chunked_result.dtype == unchunked_result.dtype + + +@requires_numbagg +class TestDataArrayRollingExp: + @pytest.mark.parametrize("dim", ["time", "x"]) + @pytest.mark.parametrize( + "window_type, window", + [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]], + ) + @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize("func", ["mean", "sum", "var", "std"]) + def test_rolling_exp_runs(self, da, dim, window_type, window, func) -> None: + da = da.where(da > 0.2) + + rolling_exp = da.rolling_exp(window_type=window_type, **{dim: window}) + result = getattr(rolling_exp, func)() + assert isinstance(result, DataArray) + + @pytest.mark.parametrize("dim", ["time", "x"]) + @pytest.mark.parametrize( + "window_type, window", + [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]], + ) + @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + def test_rolling_exp_mean_pandas(self, da, dim, window_type, window) -> None: + da = da.isel(a=0).where(lambda x: x > 0.2) + + result = da.rolling_exp(window_type=window_type, **{dim: window}).mean() + assert isinstance(result, DataArray) + + pandas_array = da.to_pandas() + assert pandas_array.index.name == "time" + if dim == "x": + pandas_array = pandas_array.T + expected = xr.DataArray( + pandas_array.ewm(**{window_type: window}).mean() + ).transpose(*da.dims) + + assert_allclose(expected.variable, result.variable) + + @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize("func", ["mean", "sum"]) + def test_rolling_exp_keep_attrs(self, da, func) -> None: + attrs = {"attrs": "da"} + da.attrs = attrs + + # Equivalent of `da.rolling_exp(time=10).mean` + rolling_exp_func = getattr(da.rolling_exp(time=10), func) + + # attrs are kept per default + result = rolling_exp_func() + assert result.attrs == attrs + + # discard attrs + result = rolling_exp_func(keep_attrs=False) + assert result.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = rolling_exp_func() + assert result.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = rolling_exp_func(keep_attrs=True) + assert result.attrs == attrs + + with set_options(keep_attrs=True): + result = rolling_exp_func(keep_attrs=False) + assert result.attrs == {} + + with pytest.warns( + UserWarning, + match="Passing ``keep_attrs`` to ``rolling_exp`` has no effect.", + ): + da.rolling_exp(time=10, keep_attrs=True) + + +class TestDatasetRolling: + @pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ("construct", ("window_dim",)), + ("count", ()), + ], + ) + def test_rolling_keep_attrs(self, funcname, argument) -> None: + global_attrs = {"units": "test", "long_name": "testing"} + da_attrs = {"da_attr": "test"} + da_not_rolled_attrs = {"da_not_rolled_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + ds = Dataset( + data_vars={"da": ("coord", data), "da_not_rolled": ("no_coord", data)}, + coords={"coord": coords}, + attrs=global_attrs, + ) + ds.da.attrs = da_attrs + ds.da_not_rolled.attrs = da_not_rolled_attrs + + # attrs are now kept per default + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_rolled.attrs == da_not_rolled_attrs + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + # discard attrs + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_rolled.attrs == {} + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + # test discard attrs using global option + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_rolled.attrs == {} + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + # keyword takes precedence over global option + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_rolled.attrs == da_not_rolled_attrs + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_rolled.attrs == {} + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + def test_rolling_properties(self, ds) -> None: + # catching invalid args + with pytest.raises(ValueError, match="window must be > 0"): + ds.rolling(time=-2) + with pytest.raises(ValueError, match="min_periods must be greater than zero"): + ds.rolling(time=2, min_periods=0) + with pytest.raises(KeyError, match="time2"): + ds.rolling(time2=2) + with pytest.raises( + KeyError, + match=r"\('foo',\) not found in Dataset dimensions", + ): + ds.rolling(foo=2) + + @pytest.mark.parametrize( + "name", ("sum", "mean", "std", "var", "min", "max", "median") + ) + @pytest.mark.parametrize("center", (True, False, None)) + @pytest.mark.parametrize("min_periods", (1, None)) + @pytest.mark.parametrize("key", ("z1", "z2")) + @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + def test_rolling_wrapped_bottleneck( + self, ds, name, center, min_periods, key, compute_backend + ) -> None: + bn = pytest.importorskip("bottleneck", minversion="1.1") + + # Test all bottleneck functions + rolling_obj = ds.rolling(time=7, min_periods=min_periods) + + func_name = f"move_{name}" + actual = getattr(rolling_obj, name)() + if key == "z1": # z1 does not depend on 'Time' axis. Stored as it is. + expected = ds[key] + elif key == "z2": + expected = getattr(bn, func_name)( + ds[key].values, window=7, axis=0, min_count=min_periods + ) + else: + raise ValueError + np.testing.assert_allclose(actual[key].values, expected) + + # Test center + rolling_obj = ds.rolling(time=7, center=center) + actual = getattr(rolling_obj, name)()["time"] + assert_allclose(actual, ds["time"]) + + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_pandas_compat(self, center, window, min_periods) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + + if min_periods is not None and window < min_periods: + min_periods = window + + df_rolling = df.rolling(window, center=center, min_periods=min_periods).mean() + ds_rolling = ds.rolling( + index=window, center=center, min_periods=min_periods + ).mean() + + np.testing.assert_allclose(df_rolling["x"].values, ds_rolling["x"].values) + np.testing.assert_allclose(df_rolling.index, ds_rolling["index"]) + + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct(self, center: bool, window: int) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + + ds = Dataset.from_dataframe(df) + df_rolling = df.rolling(window, center=center, min_periods=1).mean() + ds_rolling = ds.rolling(index=window, center=center) + + ds_rolling_mean = ds_rolling.construct("window").mean("window") + np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) + np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) + + # with fill_value + ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean( + "window" + ) + assert (ds_rolling_mean.isnull().sum() == 0).to_dataarray(dim="vars").all() + assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 + + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct_stride(self, center: bool, window: int) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean() + + # With an index (dimension coordinate) + ds_rolling = ds.rolling(index=window, center=center) + ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + ) + np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) + + # Without index (https://github.com/pydata/xarray/issues/7021) + ds2 = ds.drop_vars("index") + ds2_rolling = ds2.rolling(index=window, center=center) + ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + ) + + # Mixed coordinates, indexes and 2D coordinates + ds3 = xr.Dataset( + {"x": ("t", range(20)), "x2": ("y", range(5))}, + { + "t": range(20), + "y": ("y", range(5)), + "t2": ("t", range(20)), + "y2": ("y", range(5)), + "yt": (["t", "y"], np.ones((20, 5))), + }, + ) + ds3_rolling = ds3.rolling(t=window, center=center) + ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w") + for coord in ds3.coords: + assert coord in ds3_rolling_mean.coords + + @pytest.mark.slow + @pytest.mark.parametrize("ds", (1, 2), indirect=True) + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + @pytest.mark.parametrize( + "name", ("sum", "mean", "std", "var", "min", "max", "median") + ) + def test_rolling_reduce(self, ds, center, min_periods, window, name) -> None: + if min_periods is not None and window < min_periods: + min_periods = window + + if name == "std" and window == 1: + pytest.skip("std with window == 1 is unstable in bottleneck") + + rolling_obj = ds.rolling(time=window, center=center, min_periods=min_periods) + + # add nan prefix to numpy methods to get similar behavior as bottleneck + actual = rolling_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(rolling_obj, name)() + assert_allclose(actual, expected) + assert ds.sizes == actual.sizes + # make sure the order of data_var are not changed. + assert list(ds.data_vars.keys()) == list(actual.data_vars.keys()) + + # Make sure the dimension order is restored + for key, src_var in ds.data_vars.items(): + assert src_var.dims == actual[key].dims + + @pytest.mark.parametrize("ds", (2,), indirect=True) + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("min_periods", (None, 1)) + @pytest.mark.parametrize("name", ("sum", "max")) + @pytest.mark.parametrize("dask", (True, False)) + def test_ndrolling_reduce(self, ds, center, min_periods, name, dask) -> None: + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + rolling_obj = ds.rolling(time=4, x=3, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + ds.rolling(time=4, center=center, min_periods=min_periods), name + )().rolling(x=3, center=center, min_periods=min_periods), + name, + )() + assert_allclose(actual, expected) + assert actual.sizes == expected.sizes + + # Do it in the opposite order + expected = getattr( + getattr( + ds.rolling(x=3, center=center, min_periods=min_periods), name + )().rolling(time=4, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.sizes == expected.sizes + + @pytest.mark.parametrize("center", (True, False, (True, False))) + @pytest.mark.parametrize("fill_value", (np.nan, 0.0)) + @pytest.mark.parametrize("dask", (True, False)) + def test_ndrolling_construct(self, center, fill_value, dask) -> None: + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + ds = xr.Dataset({"da": da}) + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + actual = ds.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + ds.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + @requires_dask + @pytest.mark.filterwarnings("error") + @pytest.mark.parametrize("ds", (2,), indirect=True) + @pytest.mark.parametrize("name", ("mean", "max")) + def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: + """ + This is a puzzle — I can't easily find the source of the warning. It + requires `assert_allclose` to be run, for the `ds` param to be 2, and is + different for `mean` and `max`. `sum` raises no warning. + """ + + ds = ds.chunk({"x": 4}) + + rolling_obj = ds.rolling(time=4, x=3) + + actual = getattr(rolling_obj, name)() + expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)() + assert_allclose(actual, expected) + + +@requires_numbagg +class TestDatasetRollingExp: + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True + ) + def test_rolling_exp(self, ds) -> None: + result = ds.rolling_exp(time=10, window_type="span").mean() + assert isinstance(result, Dataset) + + @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + def test_rolling_exp_keep_attrs(self, ds) -> None: + attrs_global = {"attrs": "global"} + attrs_z1 = {"attr": "z1"} + + ds.attrs = attrs_global + ds.z1.attrs = attrs_z1 + + # attrs are kept per default + result = ds.rolling_exp(time=10).mean() + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + # discard attrs + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + # TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do + # that at the moment. We should change in `apply_func` rather than + # special-case it here. + # + # assert result.z1.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean() + assert result.attrs == {} + # See above + # assert result.z1.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + with set_options(keep_attrs=True): + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + # See above + # assert result.z1.attrs == {} + + with pytest.warns( + UserWarning, + match="Passing ``keep_attrs`` to ``rolling_exp`` has no effect.", + ): + ds.rolling_exp(time=10, keep_attrs=True) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_sparse.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_sparse.py new file mode 100644 index 0000000..f0a97fc --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_sparse.py @@ -0,0 +1,905 @@ +from __future__ import annotations + +import math +import pickle +from textwrap import dedent + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Variable +from xarray.namedarray.pycompat import array_type +from xarray.tests import assert_equal, assert_identical, requires_dask + +filterwarnings = pytest.mark.filterwarnings +param = pytest.param +xfail = pytest.mark.xfail + +sparse = pytest.importorskip("sparse") +sparse_array_type = array_type("sparse") + + +def assert_sparse_equal(a, b): + assert isinstance(a, sparse_array_type) + assert isinstance(b, sparse_array_type) + np.testing.assert_equal(a.todense(), b.todense()) + + +def make_ndarray(shape): + return np.arange(math.prod(shape)).reshape(shape) + + +def make_sparray(shape): + return sparse.random(shape, density=0.1, random_state=0) + + +def make_xrvar(dim_lengths): + return xr.Variable( + tuple(dim_lengths.keys()), make_sparray(shape=tuple(dim_lengths.values())) + ) + + +def make_xrarray(dim_lengths, coords=None, name="test"): + if coords is None: + coords = {d: np.arange(n) for d, n in dim_lengths.items()} + return xr.DataArray( + make_sparray(shape=tuple(dim_lengths.values())), + dims=tuple(coords.keys()), + coords=coords, + name=name, + ) + + +class do: + def __init__(self, meth, *args, **kwargs): + self.meth = meth + self.args = args + self.kwargs = kwargs + + def __call__(self, obj): + # cannot pass np.sum when using pytest-xdist + kwargs = self.kwargs.copy() + if "func" in self.kwargs: + kwargs["func"] = getattr(np, kwargs["func"]) + + return getattr(obj, self.meth)(*self.args, **kwargs) + + def __repr__(self): + return f"obj.{self.meth}(*{self.args}, **{self.kwargs})" + + +@pytest.mark.parametrize( + "prop", + [ + "chunks", + "data", + "dims", + "dtype", + "encoding", + "imag", + "nbytes", + "ndim", + param("values", marks=xfail(reason="Coercion to dense")), + ], +) +def test_variable_property(prop): + var = make_xrvar({"x": 10, "y": 5}) + getattr(var, prop) + + +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("all"), False), + (do("any"), False), + (do("astype", dtype=int), True), + (do("clip", min=0, max=1), True), + (do("coarsen", windows={"x": 2}, func="sum"), True), + (do("compute"), True), + (do("conj"), True), + (do("copy"), True), + (do("count"), False), + (do("get_axis_num", dim="x"), False), + (do("isel", x=slice(2, 4)), True), + (do("isnull"), True), + (do("load"), True), + (do("mean"), False), + (do("notnull"), True), + (do("roll"), True), + (do("round"), True), + (do("set_dims", dim=("x", "y", "z")), True), + (do("stack", dim={"flat": ("x", "y")}), True), + (do("to_base_variable"), True), + (do("transpose"), True), + (do("unstack", dim={"x": {"x1": 5, "x2": 2}}), True), + (do("broadcast_equals", make_xrvar({"x": 10, "y": 5})), False), + (do("equals", make_xrvar({"x": 10, "y": 5})), False), + (do("identical", make_xrvar({"x": 10, "y": 5})), False), + param( + do("argmax"), + True, + marks=[ + xfail(reason="Missing implementation for np.argmin"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], + ), + param( + do("argmin"), + True, + marks=[ + xfail(reason="Missing implementation for np.argmax"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], + ), + param( + do("argsort"), + True, + marks=xfail(reason="'COO' object has no attribute 'argsort'"), + ), + param( + do( + "concat", + variables=[ + make_xrvar({"x": 10, "y": 5}), + make_xrvar({"x": 10, "y": 5}), + ], + ), + True, + ), + param( + do("conjugate"), + True, + marks=xfail(reason="'COO' object has no attribute 'conjugate'"), + ), + param( + do("cumprod"), + True, + marks=xfail(reason="Missing implementation for np.nancumprod"), + ), + param( + do("cumsum"), + True, + marks=xfail(reason="Missing implementation for np.nancumsum"), + ), + (do("fillna", 0), True), + param( + do("item", (1, 1)), + False, + marks=xfail(reason="'COO' object has no attribute 'item'"), + ), + param( + do("median"), + False, + marks=xfail(reason="Missing implementation for np.nanmedian"), + ), + param(do("max"), False), + param(do("min"), False), + param( + do("no_conflicts", other=make_xrvar({"x": 10, "y": 5})), + True, + marks=xfail(reason="mixed sparse-dense operation"), + ), + param( + do("pad", mode="constant", pad_widths={"x": (1, 1)}, fill_value=5), + True, + marks=xfail(reason="Missing implementation for np.pad"), + ), + (do("prod"), False), + param( + do("quantile", q=0.5), + True, + marks=xfail(reason="Missing implementation for np.nanpercentile"), + ), + param( + do("rank", dim="x"), + False, + marks=xfail(reason="Only implemented for NumPy arrays (via bottleneck)"), + ), + param( + do("reduce", func="sum", dim="x"), + True, + ), + param( + do("rolling_window", dim="x", window=2, window_dim="x_win"), + True, + marks=xfail(reason="Missing implementation for np.pad"), + ), + param( + do("shift", x=2), True, marks=xfail(reason="mixed sparse-dense operation") + ), + param( + do("std"), False, marks=xfail(reason="Missing implementation for np.nanstd") + ), + (do("sum"), False), + param( + do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar") + ), + param(do("to_dict"), False), + (do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), True), + ], + ids=repr, +) +def test_variable_method(func, sparse_output): + var_s = make_xrvar({"x": 10, "y": 5}) + var_d = xr.Variable(var_s.dims, var_s.data.todense()) + ret_s = func(var_s) + ret_d = func(var_d) + + # TODO: figure out how to verify the results of each method + if isinstance(ret_d, xr.Variable) and isinstance(ret_d.data, sparse.SparseArray): + ret_d = ret_d.copy(data=ret_d.data.todense()) + + if sparse_output: + assert isinstance(ret_s.data, sparse.SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) + else: + if func.meth != "to_dict": + assert np.allclose(ret_s, ret_d) + else: + # pop the arrays from the dict + arr_s, arr_d = ret_s.pop("data"), ret_d.pop("data") + + assert np.allclose(arr_s, arr_d) + assert ret_s == ret_d + + +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("squeeze"), True), + param(do("to_index"), False, marks=xfail(reason="Coercion to dense")), + param(do("to_index_variable"), False, marks=xfail(reason="Coercion to dense")), + param( + do("searchsorted", 0.5), + True, + marks=xfail(reason="'COO' object has no attribute 'searchsorted'"), + ), + ], +) +def test_1d_variable_method(func, sparse_output): + var_s = make_xrvar({"x": 10}) + var_d = xr.Variable(var_s.dims, var_s.data.todense()) + ret_s = func(var_s) + ret_d = func(var_d) + + if sparse_output: + assert isinstance(ret_s.data, sparse.SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data) + else: + assert np.allclose(ret_s, ret_d) + + +class TestSparseVariable: + @pytest.fixture(autouse=True) + def setUp(self): + self.data = sparse.random((4, 6), random_state=0, density=0.5) + self.var = xr.Variable(("x", "y"), self.data) + + def test_nbytes(self): + assert self.var.nbytes == self.data.nbytes + + def test_unary_op(self): + assert_sparse_equal(-self.var.data, -self.data) + assert_sparse_equal(abs(self.var).data, abs(self.data)) + assert_sparse_equal(self.var.round().data, self.data.round()) + + @pytest.mark.filterwarnings("ignore::FutureWarning") + def test_univariate_ufunc(self): + assert_sparse_equal(np.sin(self.data), np.sin(self.var).data) + + @pytest.mark.filterwarnings("ignore::FutureWarning") + def test_bivariate_ufunc(self): + assert_sparse_equal(np.maximum(self.data, 0), np.maximum(self.var, 0).data) + assert_sparse_equal(np.maximum(self.data, 0), np.maximum(0, self.var).data) + + def test_repr(self): + expected = dedent( + """\ + Size: 288B + """ + ) + assert expected == repr(self.var) + + def test_pickle(self): + v1 = self.var + v2 = pickle.loads(pickle.dumps(v1)) + assert_sparse_equal(v1.data, v2.data) + + def test_missing_values(self): + a = np.array([0, 1, np.nan, 3]) + s = sparse.COO.from_numpy(a) + var_s = Variable("x", s) + assert np.all(var_s.fillna(2).data.todense() == np.arange(4)) + assert np.all(var_s.count() == 3) + + +@pytest.mark.parametrize( + "prop", + [ + "attrs", + "chunks", + "coords", + "data", + "dims", + "dtype", + "encoding", + "imag", + "indexes", + "loc", + "name", + "nbytes", + "ndim", + "plot", + "real", + "shape", + "size", + "sizes", + "str", + "variable", + ], +) +def test_dataarray_property(prop): + arr = make_xrarray({"x": 10, "y": 5}) + getattr(arr, prop) + + +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("all"), False), + (do("any"), False), + (do("assign_attrs", {"foo": "bar"}), True), + (do("assign_coords", x=make_xrarray({"x": 10}).x + 1), True), + (do("astype", int), True), + (do("clip", min=0, max=1), True), + (do("compute"), True), + (do("conj"), True), + (do("copy"), True), + (do("count"), False), + (do("diff", "x"), True), + (do("drop_vars", "x"), True), + (do("expand_dims", {"z": 2}, axis=2), True), + (do("get_axis_num", "x"), False), + (do("get_index", "x"), False), + (do("identical", make_xrarray({"x": 5, "y": 5})), False), + (do("integrate", "x"), True), + (do("isel", {"x": slice(0, 3), "y": slice(2, 4)}), True), + (do("isnull"), True), + (do("load"), True), + (do("mean"), False), + (do("persist"), True), + (do("reindex", {"x": [1, 2, 3]}), True), + (do("rename", "foo"), True), + (do("reorder_levels"), True), + (do("reset_coords", drop=True), True), + (do("reset_index", "x"), True), + (do("round"), True), + (do("sel", x=[0, 1, 2]), True), + (do("shift"), True), + (do("sortby", "x", ascending=False), True), + (do("stack", z=["x", "y"]), True), + (do("transpose"), True), + # TODO + # set_index + # swap_dims + (do("broadcast_equals", make_xrvar({"x": 10, "y": 5})), False), + (do("equals", make_xrvar({"x": 10, "y": 5})), False), + param( + do("argmax"), + True, + marks=[ + xfail(reason="Missing implementation for np.argmax"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], + ), + param( + do("argmin"), + True, + marks=[ + xfail(reason="Missing implementation for np.argmin"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], + ), + param( + do("argsort"), + True, + marks=xfail(reason="'COO' object has no attribute 'argsort'"), + ), + param( + do("bfill", dim="x"), + False, + marks=xfail(reason="Missing implementation for np.flip"), + ), + (do("combine_first", make_xrarray({"x": 10, "y": 5})), True), + param( + do("conjugate"), + False, + marks=xfail(reason="'COO' object has no attribute 'conjugate'"), + ), + param( + do("cumprod"), + True, + marks=xfail(reason="Missing implementation for np.nancumprod"), + ), + param( + do("cumsum"), + True, + marks=xfail(reason="Missing implementation for np.nancumsum"), + ), + param( + do("differentiate", "x"), + False, + marks=xfail(reason="Missing implementation for np.gradient"), + ), + param( + do("dot", make_xrarray({"x": 10, "y": 5})), + True, + marks=xfail(reason="Missing implementation for np.einsum"), + ), + param(do("dropna", "x"), False, marks=xfail(reason="Coercion to dense")), + param(do("ffill", "x"), False, marks=xfail(reason="Coercion to dense")), + (do("fillna", 0), True), + param( + do("interp", coords={"x": np.arange(10) + 0.5}), + True, + marks=xfail(reason="Coercion to dense"), + ), + param( + do( + "interp_like", + make_xrarray( + {"x": 10, "y": 5}, + coords={"x": np.arange(10) + 0.5, "y": np.arange(5) + 0.5}, + ), + ), + True, + marks=xfail(reason="Indexing COO with more than one iterable index"), + ), + param(do("interpolate_na", "x"), True, marks=xfail(reason="Coercion to dense")), + param( + do("isin", [1, 2, 3]), + False, + marks=xfail(reason="Missing implementation for np.isin"), + ), + param( + do("item", (1, 1)), + False, + marks=xfail(reason="'COO' object has no attribute 'item'"), + ), + param(do("max"), False), + param(do("min"), False), + param( + do("median"), + False, + marks=xfail(reason="Missing implementation for np.nanmedian"), + ), + (do("notnull"), True), + (do("pipe", func="sum", axis=1), True), + (do("prod"), False), + param( + do("quantile", q=0.5), + False, + marks=xfail(reason="Missing implementation for np.nanpercentile"), + ), + param( + do("rank", "x"), + False, + marks=xfail(reason="Only implemented for NumPy arrays (via bottleneck)"), + ), + param( + do("reduce", func="sum", dim="x"), + False, + marks=xfail(reason="Coercion to dense"), + ), + param( + do( + "reindex_like", + make_xrarray( + {"x": 10, "y": 5}, + coords={"x": np.arange(10) + 0.5, "y": np.arange(5) + 0.5}, + ), + ), + True, + marks=xfail(reason="Indexing COO with more than one iterable index"), + ), + (do("roll", x=2, roll_coords=True), True), + param( + do("sel", x=[0, 1, 2], y=[2, 3]), + True, + marks=xfail(reason="Indexing COO with more than one iterable index"), + ), + param( + do("std"), False, marks=xfail(reason="Missing implementation for np.nanstd") + ), + (do("sum"), False), + param( + do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar") + ), + param( + do("where", make_xrarray({"x": 10, "y": 5}) > 0.5), + False, + marks=xfail(reason="Conversion of dense to sparse when using sparse mask"), + ), + ], + ids=repr, +) +def test_dataarray_method(func, sparse_output): + arr_s = make_xrarray( + {"x": 10, "y": 5}, coords={"x": np.arange(10), "y": np.arange(5)} + ) + arr_d = xr.DataArray(arr_s.data.todense(), coords=arr_s.coords, dims=arr_s.dims) + ret_s = func(arr_s) + ret_d = func(arr_d) + + if sparse_output: + assert isinstance(ret_s.data, sparse.SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) + else: + assert np.allclose(ret_s, ret_d, equal_nan=True) + + +@pytest.mark.parametrize( + "func,sparse_output", + [ + (do("squeeze"), True), + param( + do("searchsorted", [1, 2, 3]), + False, + marks=xfail(reason="'COO' object has no attribute 'searchsorted'"), + ), + ], +) +def test_datarray_1d_method(func, sparse_output): + arr_s = make_xrarray({"x": 10}, coords={"x": np.arange(10)}) + arr_d = xr.DataArray(arr_s.data.todense(), coords=arr_s.coords, dims=arr_s.dims) + ret_s = func(arr_s) + ret_d = func(arr_d) + + if sparse_output: + assert isinstance(ret_s.data, sparse.SparseArray) + assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) + else: + assert np.allclose(ret_s, ret_d, equal_nan=True) + + +class TestSparseDataArrayAndDataset: + @pytest.fixture(autouse=True) + def setUp(self): + self.sp_ar = sparse.random((4, 6), random_state=0, density=0.5) + self.sp_xr = xr.DataArray( + self.sp_ar, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + self.ds_ar = self.sp_ar.todense() + self.ds_xr = xr.DataArray( + self.ds_ar, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + + def test_to_dataset_roundtrip(self): + x = self.sp_xr + assert_equal(x, x.to_dataset("x").to_dataarray("x")) + + def test_align(self): + a1 = xr.DataArray( + sparse.COO.from_numpy(np.arange(4)), + dims=["x"], + coords={"x": ["a", "b", "c", "d"]}, + ) + b1 = xr.DataArray( + sparse.COO.from_numpy(np.arange(4)), + dims=["x"], + coords={"x": ["a", "b", "d", "e"]}, + ) + a2, b2 = xr.align(a1, b1, join="inner") + assert isinstance(a2.data, sparse.SparseArray) + assert isinstance(b2.data, sparse.SparseArray) + assert np.all(a2.coords["x"].data == ["a", "b", "d"]) + assert np.all(b2.coords["x"].data == ["a", "b", "d"]) + + @pytest.mark.xfail( + reason="COO objects currently do not accept more than one " + "iterable index at a time" + ) + def test_align_2d(self): + A1 = xr.DataArray( + self.sp_ar, + dims=["x", "y"], + coords={ + "x": np.arange(self.sp_ar.shape[0]), + "y": np.arange(self.sp_ar.shape[1]), + }, + ) + + A2 = xr.DataArray( + self.sp_ar, + dims=["x", "y"], + coords={ + "x": np.arange(1, self.sp_ar.shape[0] + 1), + "y": np.arange(1, self.sp_ar.shape[1] + 1), + }, + ) + + B1, B2 = xr.align(A1, A2, join="inner") + assert np.all(B1.coords["x"] == np.arange(1, self.sp_ar.shape[0])) + assert np.all(B1.coords["y"] == np.arange(1, self.sp_ar.shape[0])) + assert np.all(B1.coords["x"] == B2.coords["x"]) + assert np.all(B1.coords["y"] == B2.coords["y"]) + + def test_align_outer(self): + a1 = xr.DataArray( + sparse.COO.from_numpy(np.arange(4)), + dims=["x"], + coords={"x": ["a", "b", "c", "d"]}, + ) + b1 = xr.DataArray( + sparse.COO.from_numpy(np.arange(4)), + dims=["x"], + coords={"x": ["a", "b", "d", "e"]}, + ) + a2, b2 = xr.align(a1, b1, join="outer") + assert isinstance(a2.data, sparse.SparseArray) + assert isinstance(b2.data, sparse.SparseArray) + assert np.all(a2.coords["x"].data == ["a", "b", "c", "d", "e"]) + assert np.all(b2.coords["x"].data == ["a", "b", "c", "d", "e"]) + + def test_concat(self): + ds1 = xr.Dataset(data_vars={"d": self.sp_xr}) + ds2 = xr.Dataset(data_vars={"d": self.sp_xr}) + ds3 = xr.Dataset(data_vars={"d": self.sp_xr}) + out = xr.concat([ds1, ds2, ds3], dim="x") + assert_sparse_equal( + out["d"].data, + sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=0), + ) + + out = xr.concat([self.sp_xr, self.sp_xr, self.sp_xr], dim="y") + assert_sparse_equal( + out.data, sparse.concatenate([self.sp_ar, self.sp_ar, self.sp_ar], axis=1) + ) + + def test_stack(self): + arr = make_xrarray({"w": 2, "x": 3, "y": 4}) + stacked = arr.stack(z=("x", "y")) + + z = pd.MultiIndex.from_product([np.arange(3), np.arange(4)], names=["x", "y"]) + + expected = xr.DataArray( + arr.data.reshape((2, -1)), {"w": [0, 1], "z": z}, dims=["w", "z"] + ) + + assert_equal(expected, stacked) + + roundtripped = stacked.unstack() + assert_identical(arr, roundtripped) + + def test_dataarray_repr(self): + a = xr.DataArray( + sparse.COO.from_numpy(np.ones(4)), + dims=["x"], + coords={"y": ("x", sparse.COO.from_numpy(np.arange(4, dtype="i8")))}, + ) + expected = dedent( + """\ + Size: 64B + + Coordinates: + y (x) int64 48B + Dimensions without coordinates: x""" + ) + assert expected == repr(a) + + def test_dataset_repr(self): + ds = xr.Dataset( + data_vars={"a": ("x", sparse.COO.from_numpy(np.ones(4)))}, + coords={"y": ("x", sparse.COO.from_numpy(np.arange(4, dtype="i8")))}, + ) + expected = dedent( + """\ + Size: 112B + Dimensions: (x: 4) + Coordinates: + y (x) int64 48B + Dimensions without coordinates: x + Data variables: + a (x) float64 64B """ + ) + assert expected == repr(ds) + + @requires_dask + def test_sparse_dask_dataset_repr(self): + ds = xr.Dataset( + data_vars={"a": ("x", sparse.COO.from_numpy(np.ones(4)))} + ).chunk() + expected = dedent( + """\ + Size: 32B + Dimensions: (x: 4) + Dimensions without coordinates: x + Data variables: + a (x) float64 32B dask.array""" + ) + assert expected == repr(ds) + + def test_dataarray_pickle(self): + a1 = xr.DataArray( + sparse.COO.from_numpy(np.ones(4)), + dims=["x"], + coords={"y": ("x", sparse.COO.from_numpy(np.arange(4)))}, + ) + a2 = pickle.loads(pickle.dumps(a1)) + assert_identical(a1, a2) + + def test_dataset_pickle(self): + ds1 = xr.Dataset( + data_vars={"a": ("x", sparse.COO.from_numpy(np.ones(4)))}, + coords={"y": ("x", sparse.COO.from_numpy(np.arange(4)))}, + ) + ds2 = pickle.loads(pickle.dumps(ds1)) + assert_identical(ds1, ds2) + + def test_coarsen(self): + a1 = self.ds_xr + a2 = self.sp_xr + m1 = a1.coarsen(x=2, boundary="trim").mean() + m2 = a2.coarsen(x=2, boundary="trim").mean() + + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="No implementation of np.pad") + def test_rolling(self): + a1 = self.ds_xr + a2 = self.sp_xr + m1 = a1.rolling(x=2, center=True).mean() + m2 = a2.rolling(x=2, center=True).mean() + + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="Coercion to dense") + def test_rolling_exp(self): + a1 = self.ds_xr + a2 = self.sp_xr + m1 = a1.rolling_exp(x=2, center=True).mean() + m2 = a2.rolling_exp(x=2, center=True).mean() + + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="No implementation of np.einsum") + def test_dot(self): + a1 = self.xp_xr.dot(self.xp_xr[0]) + a2 = self.sp_ar.dot(self.sp_ar[0]) + assert_equal(a1, a2) + + @pytest.mark.xfail(reason="Groupby reductions produce dense output") + def test_groupby(self): + x1 = self.ds_xr + x2 = self.sp_xr + m1 = x1.groupby("x").mean(...) + m2 = x2.groupby("x").mean(...) + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="Groupby reductions produce dense output") + def test_groupby_first(self): + x = self.sp_xr.copy() + x.coords["ab"] = ("x", ["a", "a", "b", "b"]) + x.groupby("ab").first() + x.groupby("ab").first(skipna=False) + + @pytest.mark.xfail(reason="Groupby reductions produce dense output") + def test_groupby_bins(self): + x1 = self.ds_xr + x2 = self.sp_xr + m1 = x1.groupby_bins("x", bins=[0, 3, 7, 10]).sum(...) + m2 = x2.groupby_bins("x", bins=[0, 3, 7, 10]).sum(...) + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail(reason="Resample produces dense output") + def test_resample(self): + t1 = xr.DataArray( + np.linspace(0, 11, num=12), + coords=[ + pd.date_range("1999-12-15", periods=12, freq=pd.DateOffset(months=1)) + ], + dims="time", + ) + t2 = t1.copy() + t2.data = sparse.COO(t2.data) + m1 = t1.resample(time="QS-DEC").mean() + m2 = t2.resample(time="QS-DEC").mean() + assert isinstance(m2.data, sparse.SparseArray) + assert np.allclose(m1.data, m2.data.todense()) + + @pytest.mark.xfail + def test_reindex(self): + x1 = self.ds_xr + x2 = self.sp_xr + for kwargs in [ + {"x": [2, 3, 4]}, + {"x": [1, 100, 2, 101, 3]}, + {"x": [2.5, 3, 3.5], "y": [2, 2.5, 3]}, + ]: + m1 = x1.reindex(**kwargs) + m2 = x2.reindex(**kwargs) + assert np.allclose(m1, m2, equal_nan=True) + + @pytest.mark.xfail + def test_merge(self): + x = self.sp_xr + y = xr.merge([x, x.rename("bar")]).to_dataarray() + assert isinstance(y, sparse.SparseArray) + + @pytest.mark.xfail + def test_where(self): + a = np.arange(10) + cond = a > 3 + xr.DataArray(a).where(cond) + + s = sparse.COO.from_numpy(a) + cond = s > 3 + xr.DataArray(s).where(cond) + + x = xr.DataArray(s) + cond = x > 3 + x.where(cond) + + +class TestSparseCoords: + @pytest.mark.xfail(reason="Coercion of coords to dense") + def test_sparse_coords(self): + xr.DataArray( + sparse.COO.from_numpy(np.arange(4)), + dims=["x"], + coords={"x": sparse.COO.from_numpy([1, 2, 3, 4])}, + ) + + +@requires_dask +def test_chunk(): + s = sparse.COO.from_numpy(np.array([0, 0, 1, 2])) + a = DataArray(s) + ac = a.chunk(2) + assert ac.chunks == ((2, 2),) + assert isinstance(ac.data._meta, sparse.COO) + assert_identical(ac, a) + + ds = a.to_dataset(name="a") + dsc = ds.chunk(2) + assert dsc.chunks == {"dim_0": (2, 2)} + assert_identical(dsc, ds) + + +@requires_dask +def test_dask_token(): + import dask + + s = sparse.COO.from_numpy(np.array([0, 0, 1, 2])) + a = DataArray(s) + t1 = dask.base.tokenize(a) + t2 = dask.base.tokenize(a) + t3 = dask.base.tokenize(a + 1) + assert t1 == t2 + assert t3 != t2 + assert isinstance(a.data, sparse.COO) + + ac = a.chunk(2) + t4 = dask.base.tokenize(ac) + t5 = dask.base.tokenize(ac + 1) + assert t4 != t5 + assert isinstance(ac.data._meta, sparse.COO) + + +@requires_dask +def test_apply_ufunc_check_meta_coherence(): + s = sparse.COO.from_numpy(np.array([0, 0, 1, 2])) + a = DataArray(s) + ac = a.chunk(2) + sparse_meta = ac.data._meta + + result = xr.apply_ufunc(lambda x: x, ac, dask="parallelized").data._meta + + assert_sparse_equal(result, sparse_meta) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_strategies.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_strategies.py new file mode 100644 index 0000000..47f5438 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_strategies.py @@ -0,0 +1,280 @@ +import warnings + +import numpy as np +import numpy.testing as npt +import pytest +from packaging.version import Version + +pytest.importorskip("hypothesis") +# isort: split + +import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given +from hypothesis.extra.array_api import make_strategies_namespace + +from xarray.core.variable import Variable +from xarray.testing.strategies import ( + attrs, + dimension_names, + dimension_sizes, + supported_dtypes, + unique_subset_of, + variables, +) + +ALLOWED_ATTRS_VALUES_TYPES = (int, bool, str, np.ndarray) + + +class TestDimensionNamesStrategy: + @given(dimension_names()) + def test_types(self, dims): + assert isinstance(dims, list) + for d in dims: + assert isinstance(d, str) + + @given(dimension_names()) + def test_unique(self, dims): + assert len(set(dims)) == len(dims) + + @given(st.data(), st.tuples(st.integers(0, 10), st.integers(0, 10)).map(sorted)) + def test_number_of_dims(self, data, ndims): + min_dims, max_dims = ndims + dim_names = data.draw(dimension_names(min_dims=min_dims, max_dims=max_dims)) + assert isinstance(dim_names, list) + assert min_dims <= len(dim_names) <= max_dims + + +class TestDimensionSizesStrategy: + @given(dimension_sizes()) + def test_types(self, dims): + assert isinstance(dims, dict) + for d, n in dims.items(): + assert isinstance(d, str) + assert len(d) >= 1 + + assert isinstance(n, int) + assert n >= 0 + + @given(st.data(), st.tuples(st.integers(0, 10), st.integers(0, 10)).map(sorted)) + def test_number_of_dims(self, data, ndims): + min_dims, max_dims = ndims + dim_sizes = data.draw(dimension_sizes(min_dims=min_dims, max_dims=max_dims)) + assert isinstance(dim_sizes, dict) + assert min_dims <= len(dim_sizes) <= max_dims + + @given(st.data()) + def test_restrict_names(self, data): + capitalized_names = st.text(st.characters(), min_size=1).map(str.upper) + dim_sizes = data.draw(dimension_sizes(dim_names=capitalized_names)) + for dim in dim_sizes.keys(): + assert dim.upper() == dim + + +def check_dict_values(dictionary: dict, allowed_attrs_values_types) -> bool: + """Helper function to assert that all values in recursive dict match one of a set of types.""" + for key, value in dictionary.items(): + if isinstance(value, allowed_attrs_values_types) or value is None: + continue + elif isinstance(value, dict): + # If the value is a dictionary, recursively check it + if not check_dict_values(value, allowed_attrs_values_types): + return False + else: + # If the value is not an integer or a dictionary, it's not valid + return False + return True + + +class TestAttrsStrategy: + @given(attrs()) + def test_type(self, attrs): + assert isinstance(attrs, dict) + check_dict_values(attrs, ALLOWED_ATTRS_VALUES_TYPES) + + +class TestVariablesStrategy: + @given(variables()) + def test_given_nothing(self, var): + assert isinstance(var, Variable) + + @given(st.data()) + def test_given_incorrect_types(self, data): + with pytest.raises(TypeError, match="dims must be provided as a"): + data.draw(variables(dims=["x", "y"])) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="dtype must be provided as a"): + data.draw(variables(dtype=np.dtype("int32"))) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="attrs must be provided as a"): + data.draw(variables(attrs=dict())) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="Callable"): + data.draw(variables(array_strategy_fn=np.array([0]))) # type: ignore[arg-type] + + @given(st.data(), dimension_names()) + def test_given_fixed_dim_names(self, data, fixed_dim_names): + var = data.draw(variables(dims=st.just(fixed_dim_names))) + + assert list(var.dims) == fixed_dim_names + + @given(st.data(), dimension_sizes()) + def test_given_fixed_dim_sizes(self, data, dim_sizes): + var = data.draw(variables(dims=st.just(dim_sizes))) + + assert var.dims == tuple(dim_sizes.keys()) + assert var.shape == tuple(dim_sizes.values()) + + @given(st.data(), supported_dtypes()) + def test_given_fixed_dtype(self, data, dtype): + var = data.draw(variables(dtype=st.just(dtype))) + + assert var.dtype == dtype + + @given(st.data(), npst.arrays(shape=npst.array_shapes(), dtype=supported_dtypes())) + def test_given_fixed_data_dims_and_dtype(self, data, arr): + def fixed_array_strategy_fn(*, shape=None, dtype=None): + """The fact this ignores shape and dtype is only okay because compatible shape & dtype will be passed separately.""" + return st.just(arr) + + dim_names = data.draw(dimension_names(min_dims=arr.ndim, max_dims=arr.ndim)) + dim_sizes = {name: size for name, size in zip(dim_names, arr.shape)} + + var = data.draw( + variables( + array_strategy_fn=fixed_array_strategy_fn, + dims=st.just(dim_sizes), + dtype=st.just(arr.dtype), + ) + ) + + npt.assert_equal(var.data, arr) + assert var.dtype == arr.dtype + + @given(st.data(), st.integers(0, 3)) + def test_given_array_strat_arbitrary_size_and_arbitrary_data(self, data, ndims): + dim_names = data.draw(dimension_names(min_dims=ndims, max_dims=ndims)) + + def array_strategy_fn(*, shape=None, dtype=None): + return npst.arrays(shape=shape, dtype=dtype) + + var = data.draw( + variables( + array_strategy_fn=array_strategy_fn, + dims=st.just(dim_names), + dtype=supported_dtypes(), + ) + ) + + assert var.ndim == ndims + + @given(st.data()) + def test_catch_unruly_dtype_from_custom_array_strategy_fn(self, data): + def dodgy_array_strategy_fn(*, shape=None, dtype=None): + """Dodgy function which ignores the dtype it was passed""" + return npst.arrays(shape=shape, dtype=npst.floating_dtypes()) + + with pytest.raises( + ValueError, match="returned an array object with a different dtype" + ): + data.draw( + variables( + array_strategy_fn=dodgy_array_strategy_fn, + dtype=st.just(np.dtype("int32")), + ) + ) + + @given(st.data()) + def test_catch_unruly_shape_from_custom_array_strategy_fn(self, data): + def dodgy_array_strategy_fn(*, shape=None, dtype=None): + """Dodgy function which ignores the shape it was passed""" + return npst.arrays(shape=(3, 2), dtype=dtype) + + with pytest.raises( + ValueError, match="returned an array object with a different shape" + ): + data.draw( + variables( + array_strategy_fn=dodgy_array_strategy_fn, + dims=st.just({"a": 2, "b": 1}), + dtype=supported_dtypes(), + ) + ) + + @given(st.data()) + def test_make_strategies_namespace(self, data): + """ + Test not causing a hypothesis.InvalidArgument by generating a dtype that's not in the array API. + + We still want to generate dtypes not in the array API by default, but this checks we don't accidentally override + the user's choice of dtypes with non-API-compliant ones. + """ + if Version(np.__version__) >= Version("2.0.0.dev0"): + nxp = np + else: + # requires numpy>=1.26.0, and we expect a UserWarning to be raised + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=UserWarning, message=".+See NEP 47." + ) + from numpy import ( # type: ignore[no-redef,unused-ignore] + array_api as nxp, + ) + + nxp_st = make_strategies_namespace(nxp) + + data.draw( + variables( + array_strategy_fn=nxp_st.arrays, + dtype=nxp_st.scalar_dtypes(), + ) + ) + + +class TestUniqueSubsetOf: + @given(st.data()) + def test_invalid(self, data): + with pytest.raises(TypeError, match="must be an Iterable or a Mapping"): + data.draw(unique_subset_of(0)) # type: ignore[call-overload] + + with pytest.raises(ValueError, match="length-zero object"): + data.draw(unique_subset_of({})) + + @given(st.data(), dimension_sizes(min_dims=1)) + def test_mapping(self, data, dim_sizes): + subset_of_dim_sizes = data.draw(unique_subset_of(dim_sizes)) + + for dim, length in subset_of_dim_sizes.items(): + assert dim in dim_sizes + assert dim_sizes[dim] == length + + @given(st.data(), dimension_names(min_dims=1)) + def test_iterable(self, data, dim_names): + subset_of_dim_names = data.draw(unique_subset_of(dim_names)) + + for dim in subset_of_dim_names: + assert dim in dim_names + + +class TestReduction: + """ + These tests are for checking that the examples given in the docs page on testing actually work. + """ + + @given(st.data(), variables(dims=dimension_names(min_dims=1))) + def test_mean(self, data, var): + """ + Test that given a Variable of at least one dimension, + the mean of the Variable is always equal to the mean of the underlying array. + """ + + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) + + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_treenode.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_treenode.py new file mode 100644 index 0000000..a7de2e3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_treenode.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import cast + +import pytest + +from xarray.core.iterators import LevelOrderIter +from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode + + +class TestFamilyTree: + def test_lonely(self): + root: TreeNode = TreeNode() + assert root.parent is None + assert root.children == {} + + def test_parenting(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + mary._set_parent(john, "Mary") + + assert mary.parent == john + assert john.children["Mary"] is mary + + def test_no_time_traveller_loops(self): + john: TreeNode = TreeNode() + + with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): + john._set_parent(john, "John") + + with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): + john.children = {"John": john} + + mary: TreeNode = TreeNode() + rose: TreeNode = TreeNode() + mary._set_parent(john, "Mary") + rose._set_parent(mary, "Rose") + + with pytest.raises(InvalidTreeError, match="is already a descendant"): + john._set_parent(rose, "John") + + with pytest.raises(InvalidTreeError, match="is already a descendant"): + rose.children = {"John": john} + + def test_parent_swap(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + mary._set_parent(john, "Mary") + + steve: TreeNode = TreeNode() + mary._set_parent(steve, "Mary") + + assert mary.parent == steve + assert steve.children["Mary"] is mary + assert "Mary" not in john.children + + def test_multi_child_family(self): + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate}) + assert john.children["Mary"] is mary + assert john.children["Kate"] is kate + assert mary.parent is john + assert kate.parent is john + + def test_disown_child(self): + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) + mary.orphan() + assert mary.parent is None + assert "Mary" not in john.children + + def test_doppelganger_child(self): + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode() + + with pytest.raises(TypeError): + john.children = {"Kate": 666} + + with pytest.raises(InvalidTreeError, match="Cannot add same node"): + john.children = {"Kate": kate, "Evil_Kate": kate} + + john = TreeNode(children={"Kate": kate}) + evil_kate: TreeNode = TreeNode() + evil_kate._set_parent(john, "Kate") + assert john.children["Kate"] is evil_kate + + def test_sibling_relationships(self): + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + ashley: TreeNode = TreeNode() + TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley}) + assert kate.siblings["Mary"] is mary + assert kate.siblings["Ashley"] is ashley + assert "Kate" not in kate.siblings + + def test_ancestors(self): + tony: TreeNode = TreeNode() + michael: TreeNode = TreeNode(children={"Tony": tony}) + vito = TreeNode(children={"Michael": michael}) + assert tony.root is vito + assert tony.parents == (michael, vito) + assert tony.ancestors == (vito, michael, tony) + + +class TestGetNodes: + def test_get_child(self): + steven: TreeNode = TreeNode() + sue = TreeNode(children={"Steven": steven}) + mary = TreeNode(children={"Sue": sue}) + john = TreeNode(children={"Mary": mary}) + + # get child + assert john._get_item("Mary") is mary + assert mary._get_item("Sue") is sue + + # no child exists + with pytest.raises(KeyError): + john._get_item("Kate") + + # get grandchild + assert john._get_item("Mary/Sue") is sue + + # get great-grandchild + assert john._get_item("Mary/Sue/Steven") is steven + + # get from middle of tree + assert mary._get_item("Sue/Steven") is steven + + def test_get_upwards(self): + sue: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + mary = TreeNode(children={"Sue": sue, "Kate": kate}) + john = TreeNode(children={"Mary": mary}) + + assert sue._get_item("../") is mary + assert sue._get_item("../../") is john + + # relative path + assert sue._get_item("../Kate") is kate + + def test_get_from_root(self): + sue: TreeNode = TreeNode() + mary = TreeNode(children={"Sue": sue}) + john = TreeNode(children={"Mary": mary}) # noqa + + assert sue._get_item("/Mary") is mary + + +class TestSetNodes: + def test_set_child_node(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john._set_item("Mary", mary) + + assert john.children["Mary"] is mary + assert isinstance(mary, TreeNode) + assert mary.children == {} + assert mary.parent is john + + def test_child_already_exists(self): + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) + mary_2: TreeNode = TreeNode() + with pytest.raises(KeyError): + john._set_item("Mary", mary_2, allow_overwrite=False) + + def test_set_grandchild(self): + rose: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode() + + john._set_item("Mary", mary) + john._set_item("Mary/Rose", rose) + + assert john.children["Mary"] is mary + assert isinstance(mary, TreeNode) + assert "Rose" in mary.children + assert rose.parent is mary + + def test_create_intermediate_child(self): + john: TreeNode = TreeNode() + rose: TreeNode = TreeNode() + + # test intermediate children not allowed + with pytest.raises(KeyError, match="Could not reach"): + john._set_item(path="Mary/Rose", item=rose, new_nodes_along_path=False) + + # test intermediate children allowed + john._set_item("Mary/Rose", rose, new_nodes_along_path=True) + assert "Mary" in john.children + mary = john.children["Mary"] + assert isinstance(mary, TreeNode) + assert mary.children == {"Rose": rose} + assert rose.parent == mary + assert rose.parent == mary + + def test_overwrite_child(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john._set_item("Mary", mary) + + # test overwriting not allowed + marys_evil_twin: TreeNode = TreeNode() + with pytest.raises(KeyError, match="Already a node object"): + john._set_item("Mary", marys_evil_twin, allow_overwrite=False) + assert john.children["Mary"] is mary + assert marys_evil_twin.parent is None + + # test overwriting allowed + marys_evil_twin = TreeNode() + john._set_item("Mary", marys_evil_twin, allow_overwrite=True) + assert john.children["Mary"] is marys_evil_twin + assert marys_evil_twin.parent is john + + +class TestPruning: + def test_del_child(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john._set_item("Mary", mary) + + del john["Mary"] + assert "Mary" not in john.children + assert mary.parent is None + + with pytest.raises(KeyError): + del john["Mary"] + + +def create_test_tree() -> tuple[NamedNode, NamedNode]: + # a + # ├── b + # │ ├── d + # │ └── e + # │ ├── f + # │ └── g + # └── c + # └── h + # └── i + a: NamedNode = NamedNode(name="a") + b: NamedNode = NamedNode() + c: NamedNode = NamedNode() + d: NamedNode = NamedNode() + e: NamedNode = NamedNode() + f: NamedNode = NamedNode() + g: NamedNode = NamedNode() + h: NamedNode = NamedNode() + i: NamedNode = NamedNode() + + a.children = {"b": b, "c": c} + b.children = {"d": d, "e": e} + e.children = {"f": f, "g": g} + c.children = {"h": h} + h.children = {"i": i} + + return a, f + + +class TestIterators: + + def test_levelorderiter(self): + root, _ = create_test_tree() + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root)) + ] + expected = [ + "a", # root Node is unnamed + "b", + "c", + "d", + "e", + "h", + "f", + "g", + "i", + ] + assert result == expected + + +class TestAncestry: + + def test_parents(self): + _, leaf_f = create_test_tree() + expected = ["e", "b", "a"] + assert [node.name for node in leaf_f.parents] == expected + + def test_lineage(self): + _, leaf_f = create_test_tree() + expected = ["f", "e", "b", "a"] + assert [node.name for node in leaf_f.lineage] == expected + + def test_ancestors(self): + _, leaf_f = create_test_tree() + ancestors = leaf_f.ancestors + expected = ["a", "b", "e", "f"] + for node, expected_name in zip(ancestors, expected): + assert node.name == expected_name + + def test_subtree(self): + root, _ = create_test_tree() + subtree = root.subtree + expected = [ + "a", + "b", + "c", + "d", + "e", + "h", + "f", + "g", + "i", + ] + for node, expected_name in zip(subtree, expected): + assert node.name == expected_name + + def test_descendants(self): + root, _ = create_test_tree() + descendants = root.descendants + expected = [ + "b", + "c", + "d", + "e", + "h", + "f", + "g", + "i", + ] + for node, expected_name in zip(descendants, expected): + assert node.name == expected_name + + def test_leaves(self): + tree, _ = create_test_tree() + leaves = tree.leaves + expected = [ + "d", + "f", + "g", + "i", + ] + for node, expected_name in zip(leaves, expected): + assert node.name == expected_name + + def test_levels(self): + a, f = create_test_tree() + + assert a.level == 0 + assert f.level == 3 + + assert a.depth == 3 + assert f.depth == 3 + + assert a.width == 1 + assert f.width == 3 + + +class TestRenderTree: + def test_render_nodetree(self): + sam: NamedNode = NamedNode() + ben: NamedNode = NamedNode() + mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben}) + kate: NamedNode = NamedNode() + john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate}) + expected_nodes = [ + "NamedNode()", + "\tNamedNode('Mary')", + "\t\tNamedNode('Sam')", + "\t\tNamedNode('Ben')", + "\tNamedNode('Kate')", + ] + expected_str = "NamedNode('Mary')" + john_repr = john.__repr__() + mary_str = mary.__str__() + + assert mary_str == expected_str + + john_nodes = john_repr.splitlines() + assert len(john_nodes) == len(expected_nodes) + for expected_node, repr_node in zip(expected_nodes, john_nodes): + assert expected_node == repr_node + + +def test_nodepath(): + path = NodePath("/Mary") + assert path.root == "/" + assert path.stem == "Mary" diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_tutorial.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_tutorial.py new file mode 100644 index 0000000..9d59219 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_tutorial.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import pytest + +from xarray import DataArray, tutorial +from xarray.tests import assert_identical, network + + +@network +class TestLoadDataset: + @pytest.fixture(autouse=True) + def setUp(self): + self.testfile = "tiny" + + def test_download_from_github(self, tmp_path) -> None: + cache_dir = tmp_path / tutorial._default_cache_dir_name + ds = tutorial.open_dataset(self.testfile, cache_dir=cache_dir).load() + tiny = DataArray(range(5), name="tiny").to_dataset() + assert_identical(ds, tiny) + + def test_download_from_github_load_without_cache( + self, tmp_path, monkeypatch + ) -> None: + cache_dir = tmp_path / tutorial._default_cache_dir_name + + ds_nocache = tutorial.open_dataset( + self.testfile, cache=False, cache_dir=cache_dir + ).load() + ds_cache = tutorial.open_dataset(self.testfile, cache_dir=cache_dir).load() + assert_identical(ds_cache, ds_nocache) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_typed_ops.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_typed_ops.py new file mode 100644 index 0000000..1d4ef89 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_typed_ops.py @@ -0,0 +1,246 @@ +import numpy as np + +from xarray import DataArray, Dataset, Variable + + +def test_variable_typed_ops() -> None: + """Tests for type checking of typed_ops on Variable""" + + var = Variable(dims=["t"], data=[1, 2, 3]) + + def _test(var: Variable) -> None: + # mypy checks the input type + assert isinstance(var, Variable) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + + # __add__ as an example of binary ops + _test(var + _int) + _test(var + _list) + _test(var + _ndarray) + _test(var + var) + + # __radd__ as an example of reflexive binary ops + _test(_int + var) + _test(_list + var) + _test(_ndarray + var) # type: ignore[arg-type] # numpy problem + + # __eq__ as an example of cmp ops + _test(var == _int) + _test(var == _list) + _test(var == _ndarray) + _test(_int == var) # type: ignore[arg-type] # typeshed problem + _test(_list == var) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == var) + + # __lt__ as another example of cmp ops + _test(var < _int) + _test(var < _list) + _test(var < _ndarray) + _test(_int > var) + _test(_list > var) + _test(_ndarray > var) # type: ignore[arg-type] # numpy problem + + # __iadd__ as an example of inplace binary ops + var += _int + var += _list + var += _ndarray + + # __neg__ as an example of unary ops + _test(-var) + + +def test_dataarray_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArray""" + + da = DataArray([1, 2, 3], dims=["t"]) + + def _test(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + + # __add__ as an example of binary ops + _test(da + _int) + _test(da + _list) + _test(da + _ndarray) + _test(da + _var) + _test(da + da) + + # __radd__ as an example of reflexive binary ops + _test(_int + da) + _test(_list + da) + _test(_ndarray + da) # type: ignore[arg-type] # numpy problem + _test(_var + da) + + # __eq__ as an example of cmp ops + _test(da == _int) + _test(da == _list) + _test(da == _ndarray) + _test(da == _var) + _test(_int == da) # type: ignore[arg-type] # typeshed problem + _test(_list == da) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == da) + _test(_var == da) + + # __lt__ as another example of cmp ops + _test(da < _int) + _test(da < _list) + _test(da < _ndarray) + _test(da < _var) + _test(_int > da) + _test(_list > da) + _test(_ndarray > da) # type: ignore[arg-type] # numpy problem + _test(_var > da) + + # __iadd__ as an example of inplace binary ops + da += _int + da += _list + da += _ndarray + da += _var + + # __neg__ as an example of unary ops + _test(-da) + + +def test_dataset_typed_ops() -> None: + """Tests for type checking of typed_ops on Dataset""" + + ds = Dataset({"a": ("t", [1, 2, 3])}) + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + _da = DataArray([1, 2, 3], dims=["t"]) + + # __add__ as an example of binary ops + _test(ds + _int) + _test(ds + _list) + _test(ds + _ndarray) + _test(ds + _var) + _test(ds + _da) + _test(ds + ds) + + # __radd__ as an example of reflexive binary ops + _test(_int + ds) + _test(_list + ds) + _test(_ndarray + ds) + _test(_var + ds) + _test(_da + ds) + + # __eq__ as an example of cmp ops + _test(ds == _int) + _test(ds == _list) + _test(ds == _ndarray) + _test(ds == _var) + _test(ds == _da) + _test(_int == ds) # type: ignore[arg-type] # typeshed problem + _test(_list == ds) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == ds) + _test(_var == ds) + _test(_da == ds) + + # __lt__ as another example of cmp ops + _test(ds < _int) + _test(ds < _list) + _test(ds < _ndarray) + _test(ds < _var) + _test(ds < _da) + _test(_int > ds) + _test(_list > ds) + _test(_ndarray > ds) # type: ignore[arg-type] # numpy problem + _test(_var > ds) + _test(_da > ds) + + # __iadd__ as an example of inplace binary ops + ds += _int + ds += _list + ds += _ndarray + ds += _var + ds += _da + + # __neg__ as an example of unary ops + _test(-ds) + + +def test_dataarray_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArrayGroupBy""" + + da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"]) + grp = da.groupby("x") + + def _testda(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + def _testds(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _testda(grp + _da) + _testds(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _testda(_da + grp) + _testds(_ds + grp) + + # __eq__ as an example of cmp ops + _testda(grp == _da) + _testda(_da == grp) + _testds(grp == _ds) + _testds(_ds == grp) + + # __lt__ as another example of cmp ops + _testda(grp < _da) + _testda(_da > grp) + _testds(grp < _ds) + _testds(_ds > grp) + + +def test_dataset_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DatasetGroupBy""" + + ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])}) + grp = ds.groupby("x") + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _test(grp + _da) + _test(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _test(_da + grp) + _test(_ds + grp) + + # __eq__ as an example of cmp ops + _test(grp == _da) + _test(_da == grp) + _test(grp == _ds) + _test(_ds == grp) + + # __lt__ as another example of cmp ops + _test(grp < _da) + _test(_da > grp) + _test(grp < _ds) + _test(_ds > grp) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_ufuncs.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_ufuncs.py new file mode 100644 index 0000000..6b4c3f3 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_ufuncs.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import xarray as xr +from xarray.tests import assert_allclose, assert_array_equal, mock +from xarray.tests import assert_identical as assert_identical_ + + +def assert_identical(a, b): + assert type(a) is type(b) or float(a) == float(b) + if isinstance(a, (xr.DataArray, xr.Dataset, xr.Variable)): + assert_identical_(a, b) + else: + assert_array_equal(a, b) + + +@pytest.mark.parametrize( + "a", + [ + xr.Variable(["x"], [0, 0]), + xr.DataArray([0, 0], dims="x"), + xr.Dataset({"y": ("x", [0, 0])}), + ], +) +def test_unary(a): + assert_allclose(a + 1, np.cos(a)) + + +def test_binary(): + args = [ + 0, + np.zeros(2), + xr.Variable(["x"], [0, 0]), + xr.DataArray([0, 0], dims="x"), + xr.Dataset({"y": ("x", [0, 0])}), + ] + for n, t1 in enumerate(args): + for t2 in args[n:]: + assert_identical(t2 + 1, np.maximum(t1, t2 + 1)) + assert_identical(t2 + 1, np.maximum(t2, t1 + 1)) + assert_identical(t2 + 1, np.maximum(t1 + 1, t2)) + assert_identical(t2 + 1, np.maximum(t2 + 1, t1)) + + +def test_binary_out(): + args = [ + 1, + np.ones(2), + xr.Variable(["x"], [1, 1]), + xr.DataArray([1, 1], dims="x"), + xr.Dataset({"y": ("x", [1, 1])}), + ] + for arg in args: + actual_mantissa, actual_exponent = np.frexp(arg) + assert_identical(actual_mantissa, 0.5 * arg) + assert_identical(actual_exponent, arg) + + +def test_groupby(): + ds = xr.Dataset({"a": ("x", [0, 0, 0])}, {"c": ("x", [0, 0, 1])}) + ds_grouped = ds.groupby("c") + group_mean = ds_grouped.mean("x") + arr_grouped = ds["a"].groupby("c") + + assert_identical(ds, np.maximum(ds_grouped, group_mean)) + assert_identical(ds, np.maximum(group_mean, ds_grouped)) + + assert_identical(ds, np.maximum(arr_grouped, group_mean)) + assert_identical(ds, np.maximum(group_mean, arr_grouped)) + + assert_identical(ds, np.maximum(ds_grouped, group_mean["a"])) + assert_identical(ds, np.maximum(group_mean["a"], ds_grouped)) + + assert_identical(ds.a, np.maximum(arr_grouped, group_mean.a)) + assert_identical(ds.a, np.maximum(group_mean.a, arr_grouped)) + + with pytest.raises(ValueError, match=r"mismatched lengths for dimension"): + np.maximum(ds.a.variable, ds_grouped) + + +def test_alignment(): + ds1 = xr.Dataset({"a": ("x", [1, 2])}, {"x": [0, 1]}) + ds2 = xr.Dataset({"a": ("x", [2, 3]), "b": 4}, {"x": [1, 2]}) + + actual = np.add(ds1, ds2) + expected = xr.Dataset({"a": ("x", [4])}, {"x": [1]}) + assert_identical_(actual, expected) + + with xr.set_options(arithmetic_join="outer"): + actual = np.add(ds1, ds2) + expected = xr.Dataset( + {"a": ("x", [np.nan, 4, np.nan]), "b": np.nan}, coords={"x": [0, 1, 2]} + ) + assert_identical_(actual, expected) + + +def test_kwargs(): + x = xr.DataArray(0) + result = np.add(x, 1, dtype=np.float64) + assert result.dtype == np.float64 + + +def test_xarray_defers_to_unrecognized_type(): + class Other: + def __array_ufunc__(self, *args, **kwargs): + return "other" + + xarray_obj = xr.DataArray([1, 2, 3]) + other = Other() + assert np.maximum(xarray_obj, other) == "other" + assert np.sin(xarray_obj, out=other) == "other" + + +def test_xarray_handles_dask(): + da = pytest.importorskip("dask.array") + x = xr.DataArray(np.ones((2, 2)), dims=["x", "y"]) + y = da.ones((2, 2), chunks=(2, 2)) + result = np.add(x, y) + assert result.chunks == ((2,), (2,)) + assert isinstance(result, xr.DataArray) + + +def test_dask_defers_to_xarray(): + da = pytest.importorskip("dask.array") + x = xr.DataArray(np.ones((2, 2)), dims=["x", "y"]) + y = da.ones((2, 2), chunks=(2, 2)) + result = np.add(y, x) + assert result.chunks == ((2,), (2,)) + assert isinstance(result, xr.DataArray) + + +def test_gufunc_methods(): + xarray_obj = xr.DataArray([1, 2, 3]) + with pytest.raises(NotImplementedError, match=r"reduce method"): + np.add.reduce(xarray_obj, 1) + + +def test_out(): + xarray_obj = xr.DataArray([1, 2, 3]) + + # xarray out arguments should raise + with pytest.raises(NotImplementedError, match=r"`out` argument"): + np.add(xarray_obj, 1, out=xarray_obj) + + # but non-xarray should be OK + other = np.zeros((3,)) + np.add(other, xarray_obj, out=other) + assert_identical(other, np.array([1, 2, 3])) + + +def test_gufuncs(): + xarray_obj = xr.DataArray([1, 2, 3]) + fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) + with pytest.raises(NotImplementedError, match=r"generalized ufuncs"): + xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_units.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_units.py new file mode 100644 index 0000000..0e8fbe9 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_units.py @@ -0,0 +1,5746 @@ +from __future__ import annotations + +import functools +import operator + +import numpy as np +import pytest + +import xarray as xr +from xarray.core import dtypes, duck_array_ops +from xarray.tests import ( + assert_allclose, + assert_duckarray_allclose, + assert_equal, + assert_identical, + requires_dask, + requires_matplotlib, + requires_numbagg, +) +from xarray.tests.test_plot import PlotTestCase +from xarray.tests.test_variable import _PAD_XR_NP_ARGS + +try: + import matplotlib.pyplot as plt +except ImportError: + pass + + +pint = pytest.importorskip("pint") +DimensionalityError = pint.errors.DimensionalityError + + +# make sure scalars are converted to 0d arrays so quantities can +# always be treated like ndarrays +unit_registry = pint.UnitRegistry(force_ndarray_like=True) +Quantity = unit_registry.Quantity +no_unit_values = ("none", None) + + +pytestmark = [ + pytest.mark.filterwarnings("error::pint.UnitStrippedWarning"), +] + + +def is_compatible(unit1, unit2): + def dimensionality(obj): + if isinstance(obj, (unit_registry.Quantity, unit_registry.Unit)): + unit_like = obj + else: + unit_like = unit_registry.dimensionless + + return unit_like.dimensionality + + return dimensionality(unit1) == dimensionality(unit2) + + +def compatible_mappings(first, second): + return { + key: is_compatible(unit1, unit2) + for key, (unit1, unit2) in zip_mappings(first, second) + } + + +def merge_mappings(base, *mappings): + result = base.copy() + for m in mappings: + result.update(m) + + return result + + +def zip_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + +def array_extract_units(obj): + if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): + obj = obj.data + + try: + return obj.units + except AttributeError: + return None + + +def array_strip_units(array): + try: + return array.magnitude + except AttributeError: + return array + + +def array_attach_units(data, unit): + if isinstance(data, Quantity) and data.units != unit: + raise ValueError(f"cannot attach unit {unit} to quantity {data}") + + if unit in no_unit_values or (isinstance(unit, int) and unit == 1): + return data + + quantity = unit_registry.Quantity(data, unit) + return quantity + + +def extract_units(obj): + if isinstance(obj, xr.Dataset): + vars_units = { + name: array_extract_units(value) for name, value in obj.data_vars.items() + } + coords_units = { + name: array_extract_units(value) for name, value in obj.coords.items() + } + + units = {**vars_units, **coords_units} + elif isinstance(obj, xr.DataArray): + vars_units = {obj.name: array_extract_units(obj)} + coords_units = { + name: array_extract_units(value) for name, value in obj.coords.items() + } + + units = {**vars_units, **coords_units} + elif isinstance(obj, xr.Variable): + vars_units = {None: array_extract_units(obj.data)} + + units = {**vars_units} + elif isinstance(obj, Quantity): + vars_units = {None: array_extract_units(obj)} + + units = {**vars_units} + else: + units = {} + + return units + + +def strip_units(obj): + if isinstance(obj, xr.Dataset): + data_vars = { + strip_units(name): strip_units(value) + for name, value in obj.data_vars.items() + } + coords = { + strip_units(name): strip_units(value) for name, value in obj.coords.items() + } + + new_obj = xr.Dataset(data_vars=data_vars, coords=coords) + elif isinstance(obj, xr.DataArray): + data = array_strip_units(obj.variable._data) + coords = { + strip_units(name): ( + (value.dims, array_strip_units(value.variable._data)) + if isinstance(value.data, Quantity) + else value # to preserve multiindexes + ) + for name, value in obj.coords.items() + } + + new_obj = xr.DataArray( + name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims + ) + elif isinstance(obj, xr.Variable): + data = array_strip_units(obj.data) + new_obj = obj.copy(data=data) + elif isinstance(obj, unit_registry.Quantity): + new_obj = obj.magnitude + elif isinstance(obj, (list, tuple)): + return type(obj)(strip_units(elem) for elem in obj) + else: + new_obj = obj + + return new_obj + + +def attach_units(obj, units): + if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)): + units = units.get("data", None) or units.get(None, None) or 1 + return array_attach_units(obj, units) + + if isinstance(obj, xr.Dataset): + data_vars = { + name: attach_units(value, units) for name, value in obj.data_vars.items() + } + + coords = { + name: attach_units(value, units) for name, value in obj.coords.items() + } + + new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) + elif isinstance(obj, xr.DataArray): + # try the array name, "data" and None, then fall back to dimensionless + data_units = units.get(obj.name, None) or units.get(None, None) or 1 + + data = array_attach_units(obj.data, data_units) + + coords = { + name: ( + (value.dims, array_attach_units(value.data, units.get(name) or 1)) + if name in units + else (value.dims, value.data) + ) + for name, value in obj.coords.items() + } + dims = obj.dims + attrs = obj.attrs + + new_obj = xr.DataArray( + name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims + ) + else: + data_units = units.get("data", None) or units.get(None, None) or 1 + + data = array_attach_units(obj.data, data_units) + new_obj = obj.copy(data=data) + + return new_obj + + +def convert_units(obj, to): + # preprocess + to = { + key: None if not isinstance(value, unit_registry.Unit) else value + for key, value in to.items() + } + if isinstance(obj, xr.Dataset): + data_vars = { + name: convert_units(array.variable, {None: to.get(name)}) + for name, array in obj.data_vars.items() + } + coords = { + name: convert_units(array.variable, {None: to.get(name)}) + for name, array in obj.coords.items() + } + + new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) + elif isinstance(obj, xr.DataArray): + name = obj.name + + new_units = ( + to.get(name, None) or to.get("data", None) or to.get(None, None) or None + ) + data = convert_units(obj.variable, {None: new_units}) + + coords = { + name: (array.dims, convert_units(array.variable, {None: to.get(name)})) + for name, array in obj.coords.items() + if name != obj.name + } + + new_obj = xr.DataArray( + name=name, data=data, coords=coords, attrs=obj.attrs, dims=obj.dims + ) + elif isinstance(obj, xr.Variable): + new_data = convert_units(obj.data, to) + new_obj = obj.copy(data=new_data) + elif isinstance(obj, unit_registry.Quantity): + units = to.get(None) + new_obj = obj.to(units) if units is not None else obj + else: + new_obj = obj + + return new_obj + + +def assert_units_equal(a, b): + __tracebackhide__ = True + assert extract_units(a) == extract_units(b) + + +@pytest.fixture(params=[np.dtype(float), np.dtype(int)], ids=str) +def dtype(request): + return request.param + + +def merge_args(default_args, new_args): + from itertools import zip_longest + + fill_value = object() + return [ + second if second is not fill_value else first + for first, second in zip_longest(default_args, new_args, fillvalue=fill_value) + ] + + +class method: + """wrapper class to help with passing methods via parametrize + + This is works a bit similar to using `partial(Class.method, arg, kwarg)` + """ + + def __init__(self, name, *args, fallback_func=None, **kwargs): + self.name = name + self.fallback = fallback_func + self.args = args + self.kwargs = kwargs + + def __call__(self, obj, *args, **kwargs): + from functools import partial + + all_args = merge_args(self.args, args) + all_kwargs = {**self.kwargs, **kwargs} + + from xarray.core.groupby import GroupBy + + xarray_classes = ( + xr.Variable, + xr.DataArray, + xr.Dataset, + GroupBy, + ) + + if not isinstance(obj, xarray_classes): + # remove typical xarray args like "dim" + exclude_kwargs = ("dim", "dims") + # TODO: figure out a way to replace dim / dims with axis + all_kwargs = { + key: value + for key, value in all_kwargs.items() + if key not in exclude_kwargs + } + if self.fallback is not None: + func = partial(self.fallback, obj) + else: + func = getattr(obj, self.name, None) + + if func is None or not callable(func): + # fall back to module level numpy functions + numpy_func = getattr(np, self.name) + func = partial(numpy_func, obj) + else: + func = getattr(obj, self.name) + + return func(*all_args, **all_kwargs) + + def __repr__(self): + return f"method_{self.name}" + + +class function: + """wrapper class for numpy functions + + Same as method, but the name is used for referencing numpy functions + """ + + def __init__(self, name_or_function, *args, function_label=None, **kwargs): + if callable(name_or_function): + self.name = ( + function_label + if function_label is not None + else name_or_function.__name__ + ) + self.func = name_or_function + else: + self.name = name_or_function if function_label is None else function_label + self.func = getattr(np, name_or_function) + if self.func is None: + raise AttributeError( + f"module 'numpy' has no attribute named '{self.name}'" + ) + + self.args = args + self.kwargs = kwargs + + def __call__(self, *args, **kwargs): + all_args = merge_args(self.args, args) + all_kwargs = {**self.kwargs, **kwargs} + + return self.func(*all_args, **all_kwargs) + + def __repr__(self): + return f"function_{self.name}" + + +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_apply_ufunc_dataarray(variant, dtype): + variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + func = functools.partial( + xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} + ) + + array = np.linspace(0, 10, 20).astype(dtype) * data_unit + x = np.arange(20) * dim_unit + u = np.linspace(-1, 1, 20) * coord_unit + data_array = xr.DataArray(data=array, dims="x", coords={"x": x, "u": ("x", u)}) + + expected = attach_units(func(strip_units(data_array)), extract_units(data_array)) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_apply_ufunc_dataset(variant, dtype): + variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.s), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + func = functools.partial( + xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} + ) + + array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 10, 5).astype(dtype) * data_unit + + x = np.arange(5) * dim_unit + y = np.arange(10) * dim_unit + + u = np.linspace(-1, 1, 10) * coord_unit + + ds = xr.Dataset( + data_vars={"a": (("x", "y"), array1), "b": ("x", array2)}, + coords={"x": x, "y": y, "u": ("y", u)}, + ) + + expected = attach_units(func(strip_units(ds)), extract_units(ds)) + actual = func(ds) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +@pytest.mark.parametrize("value", (10, dtypes.NA)) +def test_align_dataarray(value, variant, unit, error, dtype): + if variant == "coords" and ( + value != dtypes.NA or isinstance(unit, unit_registry.Unit) + ): + pytest.xfail( + reason=( + "fill_value is used for both data variables and coords. " + "See https://github.com/pydata/xarray/issues/4165" + ) + ) + + fill_value = dtypes.get_fill_value(dtype) if value == dtypes.NA else value + + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit1 + array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit2 + + x = np.arange(2) * dim_unit1 + y1 = np.arange(5) * dim_unit1 + y2 = np.arange(2, 7) * dim_unit2 + + u1 = np.array([3, 5, 7, 8, 9]) * coord_unit1 + u2 = np.array([7, 8, 9, 11, 13]) * coord_unit2 + + coords1 = {"x": x, "y": y1} + coords2 = {"x": x, "y": y2} + if variant == "coords": + coords1["y_a"] = ("y", u1) + coords2["y_a"] = ("y", u2) + + data_array1 = xr.DataArray(data=array1, coords=coords1, dims=("x", "y")) + data_array2 = xr.DataArray(data=array2, coords=coords2, dims=("x", "y")) + + fill_value = fill_value * data_unit2 + func = function(xr.align, join="outer", fill_value=fill_value) + if error is not None and (value != dtypes.NA or isinstance(fill_value, Quantity)): + with pytest.raises(error): + func(data_array1, data_array2) + + return + + stripped_kwargs = { + key: strip_units( + convert_units(value, {None: data_unit1 if data_unit2 != 1 else None}) + ) + for key, value in func.kwargs.items() + } + units_a = extract_units(data_array1) + units_b = extract_units(data_array2) + expected_a, expected_b = func( + strip_units(data_array1), + strip_units(convert_units(data_array2, units_a)), + **stripped_kwargs, + ) + expected_a = attach_units(expected_a, units_a) + if isinstance(array2, Quantity): + expected_b = convert_units(attach_units(expected_b, units_a), units_b) + else: + expected_b = attach_units(expected_b, units_b) + + actual_a, actual_b = func(data_array1, data_array2) + + assert_units_equal(expected_a, actual_a) + assert_allclose(expected_a, actual_a) + assert_units_equal(expected_b, actual_b) + assert_allclose(expected_b, actual_b) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +@pytest.mark.parametrize("value", (10, dtypes.NA)) +def test_align_dataset(value, unit, variant, error, dtype): + if variant == "coords" and ( + value != dtypes.NA or isinstance(unit, unit_registry.Unit) + ): + pytest.xfail( + reason=( + "fill_value is used for both data variables and coords. " + "See https://github.com/pydata/xarray/issues/4165" + ) + ) + + fill_value = dtypes.get_fill_value(dtype) if value == dtypes.NA else value + + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit1 + array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit2 + + x = np.arange(2) * dim_unit1 + y1 = np.arange(5) * dim_unit1 + y2 = np.arange(2, 7) * dim_unit2 + + u1 = np.array([3, 5, 7, 8, 9]) * coord_unit1 + u2 = np.array([7, 8, 9, 11, 13]) * coord_unit2 + + coords1 = {"x": x, "y": y1} + coords2 = {"x": x, "y": y2} + if variant == "coords": + coords1["u"] = ("y", u1) + coords2["u"] = ("y", u2) + + ds1 = xr.Dataset(data_vars={"a": (("x", "y"), array1)}, coords=coords1) + ds2 = xr.Dataset(data_vars={"a": (("x", "y"), array2)}, coords=coords2) + + fill_value = fill_value * data_unit2 + func = function(xr.align, join="outer", fill_value=fill_value) + if error is not None and (value != dtypes.NA or isinstance(fill_value, Quantity)): + with pytest.raises(error): + func(ds1, ds2) + + return + + stripped_kwargs = { + key: strip_units( + convert_units(value, {None: data_unit1 if data_unit2 != 1 else None}) + ) + for key, value in func.kwargs.items() + } + units_a = extract_units(ds1) + units_b = extract_units(ds2) + expected_a, expected_b = func( + strip_units(ds1), + strip_units(convert_units(ds2, units_a)), + **stripped_kwargs, + ) + expected_a = attach_units(expected_a, units_a) + if isinstance(array2, Quantity): + expected_b = convert_units(attach_units(expected_b, units_a), units_b) + else: + expected_b = attach_units(expected_b, units_b) + + actual_a, actual_b = func(ds1, ds2) + + assert_units_equal(expected_a, actual_a) + assert_allclose(expected_a, actual_a) + assert_units_equal(expected_b, actual_b) + assert_allclose(expected_b, actual_b) + + +def test_broadcast_dataarray(dtype): + # uses align internally so more thorough tests are not needed + array1 = np.linspace(0, 10, 2) * unit_registry.Pa + array2 = np.linspace(0, 10, 3) * unit_registry.Pa + + a = xr.DataArray(data=array1, dims="x") + b = xr.DataArray(data=array2, dims="y") + + units_a = extract_units(a) + units_b = extract_units(b) + expected_a, expected_b = xr.broadcast(strip_units(a), strip_units(b)) + expected_a = attach_units(expected_a, units_a) + expected_b = convert_units(attach_units(expected_b, units_a), units_b) + + actual_a, actual_b = xr.broadcast(a, b) + + assert_units_equal(expected_a, actual_a) + assert_identical(expected_a, actual_a) + assert_units_equal(expected_b, actual_b) + assert_identical(expected_b, actual_b) + + +def test_broadcast_dataset(dtype): + # uses align internally so more thorough tests are not needed + array1 = np.linspace(0, 10, 2) * unit_registry.Pa + array2 = np.linspace(0, 10, 3) * unit_registry.Pa + + x1 = np.arange(2) + y1 = np.arange(3) + + x2 = np.arange(2, 4) + y2 = np.arange(3, 6) + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, coords={"x": x1, "y": y1} + ) + other = xr.Dataset( + data_vars={ + "a": ("x", array1.to(unit_registry.hPa)), + "b": ("y", array2.to(unit_registry.hPa)), + }, + coords={"x": x2, "y": y2}, + ) + + units_a = extract_units(ds) + units_b = extract_units(other) + expected_a, expected_b = xr.broadcast(strip_units(ds), strip_units(other)) + expected_a = attach_units(expected_a, units_a) + expected_b = attach_units(expected_b, units_b) + + actual_a, actual_b = xr.broadcast(ds, other) + + assert_units_equal(expected_a, actual_a) + assert_identical(expected_a, actual_a) + assert_units_equal(expected_b, actual_b) + assert_identical(expected_b, actual_b) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_combine_by_coords(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1 + array2 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1 + x = np.arange(1, 4) * 10 * dim_unit1 + y = np.arange(2) * dim_unit1 + u = np.arange(3) * coord_unit1 + + other_array1 = np.ones_like(array1) * data_unit2 + other_array2 = np.ones_like(array2) * data_unit2 + other_x = np.arange(1, 4) * 10 * dim_unit2 + other_y = np.arange(2, 4) * dim_unit2 + other_u = np.arange(3, 6) * coord_unit2 + + ds = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "u": ("x", u)}, + ) + other = xr.Dataset( + data_vars={"a": (("y", "x"), other_array1), "b": (("y", "x"), other_array2)}, + coords={"x": other_x, "y": other_y, "u": ("x", other_u)}, + ) + + if error is not None: + with pytest.raises(error): + xr.combine_by_coords([ds, other]) + + return + + units = extract_units(ds) + expected = attach_units( + xr.combine_by_coords( + [strip_units(ds), strip_units(convert_units(other, units))] + ), + units, + ) + actual = xr.combine_by_coords([ds, other]) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_combine_nested(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1 + array2 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1 + + x = np.arange(1, 4) * 10 * dim_unit1 + y = np.arange(2) * dim_unit1 + z = np.arange(3) * coord_unit1 + + ds1 = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + ds2 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.ones_like(array1) * data_unit2), + "b": (("y", "x"), np.ones_like(array2) * data_unit2), + }, + coords={ + "x": np.arange(3) * dim_unit2, + "y": np.arange(2, 4) * dim_unit2, + "z": ("x", np.arange(-3, 0) * coord_unit2), + }, + ) + ds3 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.full_like(array1, fill_value=np.nan) * data_unit2), + "b": (("y", "x"), np.full_like(array2, fill_value=np.nan) * data_unit2), + }, + coords={ + "x": np.arange(3, 6) * dim_unit2, + "y": np.arange(4, 6) * dim_unit2, + "z": ("x", np.arange(3, 6) * coord_unit2), + }, + ) + ds4 = xr.Dataset( + data_vars={ + "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit2), + "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit2), + }, + coords={ + "x": np.arange(6, 9) * dim_unit2, + "y": np.arange(6, 8) * dim_unit2, + "z": ("x", np.arange(6, 9) * coord_unit2), + }, + ) + + func = function(xr.combine_nested, concat_dim=["x", "y"]) + if error is not None: + with pytest.raises(error): + func([[ds1, ds2], [ds3, ds4]]) + + return + + units = extract_units(ds1) + convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) + expected = attach_units( + func( + [ + [strip_units(ds1), convert_and_strip(ds2)], + [convert_and_strip(ds3), convert_and_strip(ds4)], + ] + ), + units, + ) + actual = func([[ds1, ds2], [ds3, ds4]]) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_concat_dataarray(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.linspace(0, 5, 10).astype(dtype) * data_unit1 + array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit2 + + x1 = np.arange(5, 15) * dim_unit1 + x2 = np.arange(5) * dim_unit2 + + u1 = np.linspace(1, 2, 10).astype(dtype) * coord_unit1 + u2 = np.linspace(0, 1, 5).astype(dtype) * coord_unit2 + + arr1 = xr.DataArray(data=array1, coords={"x": x1, "u": ("x", u1)}, dims="x") + arr2 = xr.DataArray(data=array2, coords={"x": x2, "u": ("x", u2)}, dims="x") + + if error is not None: + with pytest.raises(error): + xr.concat([arr1, arr2], dim="x") + + return + + units = extract_units(arr1) + expected = attach_units( + xr.concat( + [strip_units(arr1), strip_units(convert_units(arr2, units))], dim="x" + ), + units, + ) + actual = xr.concat([arr1, arr2], dim="x") + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_concat_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.linspace(0, 5, 10).astype(dtype) * data_unit1 + array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit2 + + x1 = np.arange(5, 15) * dim_unit1 + x2 = np.arange(5) * dim_unit2 + + u1 = np.linspace(1, 2, 10).astype(dtype) * coord_unit1 + u2 = np.linspace(0, 1, 5).astype(dtype) * coord_unit2 + + ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1, "u": ("x", u1)}) + ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2, "u": ("x", u2)}) + + if error is not None: + with pytest.raises(error): + xr.concat([ds1, ds2], dim="x") + + return + + units = extract_units(ds1) + expected = attach_units( + xr.concat([strip_units(ds1), strip_units(convert_units(ds2, units))], dim="x"), + units, + ) + actual = xr.concat([ds1, ds2], dim="x") + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_merge_dataarray(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit1 + x1 = np.arange(2) * dim_unit1 + y1 = np.arange(3) * dim_unit1 + u1 = np.linspace(10, 20, 2) * coord_unit1 + v1 = np.linspace(10, 20, 3) * coord_unit1 + + array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit2 + x2 = np.arange(2, 4) * dim_unit2 + z2 = np.arange(4) * dim_unit1 + u2 = np.linspace(20, 30, 2) * coord_unit2 + w2 = np.linspace(10, 20, 4) * coord_unit1 + + array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit2 + y3 = np.arange(3, 6) * dim_unit2 + z3 = np.arange(4, 8) * dim_unit2 + v3 = np.linspace(10, 20, 3) * coord_unit2 + w3 = np.linspace(10, 20, 4) * coord_unit2 + + arr1 = xr.DataArray( + name="a", + data=array1, + coords={"x": x1, "y": y1, "u": ("x", u1), "v": ("y", v1)}, + dims=("x", "y"), + ) + arr2 = xr.DataArray( + name="a", + data=array2, + coords={"x": x2, "z": z2, "u": ("x", u2), "w": ("z", w2)}, + dims=("x", "z"), + ) + arr3 = xr.DataArray( + name="a", + data=array3, + coords={"y": y3, "z": z3, "v": ("y", v3), "w": ("z", w3)}, + dims=("y", "z"), + ) + + if error is not None: + with pytest.raises(error): + xr.merge([arr1, arr2, arr3]) + + return + + units = { + "a": data_unit1, + "u": coord_unit1, + "v": coord_unit1, + "w": coord_unit1, + "x": dim_unit1, + "y": dim_unit1, + "z": dim_unit1, + } + convert_and_strip = lambda arr: strip_units(convert_units(arr, units)) + + expected = attach_units( + xr.merge( + [convert_and_strip(arr1), convert_and_strip(arr2), convert_and_strip(arr3)] + ), + units, + ) + + actual = xr.merge([arr1, arr2, arr3]) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +def test_merge_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1 + array2 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1 + + x = np.arange(11, 14) * dim_unit1 + y = np.arange(2) * dim_unit1 + u = np.arange(3) * coord_unit1 + + ds1 = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "u": ("x", u)}, + ) + ds2 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.ones_like(array1) * data_unit2), + "b": (("y", "x"), np.ones_like(array2) * data_unit2), + }, + coords={ + "x": np.arange(3) * dim_unit2, + "y": np.arange(2, 4) * dim_unit2, + "u": ("x", np.arange(-3, 0) * coord_unit2), + }, + ) + ds3 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.full_like(array1, np.nan) * data_unit2), + "b": (("y", "x"), np.full_like(array2, np.nan) * data_unit2), + }, + coords={ + "x": np.arange(3, 6) * dim_unit2, + "y": np.arange(4, 6) * dim_unit2, + "u": ("x", np.arange(3, 6) * coord_unit2), + }, + ) + + func = function(xr.merge) + if error is not None: + with pytest.raises(error): + func([ds1, ds2, ds3]) + + return + + units = extract_units(ds1) + convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) + expected = attach_units( + func([convert_and_strip(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]), + units, + ) + actual = func([ds1, ds2, ds3]) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) +def test_replication_dataarray(func, variant, dtype): + unit = unit_registry.m + + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array = np.linspace(0, 10, 20).astype(dtype) * data_unit + x = np.arange(20) * dim_unit + u = np.linspace(0, 1, 20) * coord_unit + + data_array = xr.DataArray(data=array, dims="x", coords={"x": x, "u": ("x", u)}) + units = extract_units(data_array) + units.pop(data_array.name) + + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), +) +@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) +def test_replication_dataset(func, variant, dtype): + unit = unit_registry.m + + variants = { + "data": ((unit_registry.m, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit, 1), + "coords": ((1, 1), 1, unit), + } + (data_unit1, data_unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 20).astype(dtype) * data_unit1 + array2 = np.linspace(5, 10, 10).astype(dtype) * data_unit2 + x = np.arange(20).astype(dtype) * dim_unit + y = np.arange(10).astype(dtype) * dim_unit + u = np.linspace(0, 1, 10) * coord_unit + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, + coords={"x": x, "y": y, "u": ("y", u)}, + ) + units = { + name: unit + for name, unit in extract_units(ds).items() + if name not in ds.data_vars + } + + expected = attach_units(func(strip_units(ds)), units) + + actual = func(ds) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + pytest.param( + "coords", + marks=pytest.mark.xfail(reason="can't copy quantity into non-quantity"), + ), + ), +) +def test_replication_full_like_dataarray(variant, dtype): + # since full_like will strip units and then use the units of the + # fill value, we don't need to try multiple units + unit = unit_registry.m + + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array = np.linspace(0, 5, 10) * data_unit + x = np.arange(10) * dim_unit + u = np.linspace(0, 1, 10) * coord_unit + data_array = xr.DataArray(data=array, dims="x", coords={"x": x, "u": ("x", u)}) + + fill_value = -1 * unit_registry.degK + + units = extract_units(data_array) + units[data_array.name] = fill_value.units + expected = attach_units( + xr.full_like(strip_units(data_array), fill_value=strip_units(fill_value)), units + ) + actual = xr.full_like(data_array, fill_value=fill_value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + pytest.param( + "coords", + marks=pytest.mark.xfail(reason="can't copy quantity into non-quantity"), + ), + ), +) +def test_replication_full_like_dataset(variant, dtype): + unit = unit_registry.m + + variants = { + "data": ((unit_registry.s, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit, 1), + "coords": ((1, 1), 1, unit), + } + (data_unit1, data_unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 20).astype(dtype) * data_unit1 + array2 = np.linspace(5, 10, 10).astype(dtype) * data_unit2 + x = np.arange(20).astype(dtype) * dim_unit + y = np.arange(10).astype(dtype) * dim_unit + + u = np.linspace(0, 1, 10) * coord_unit + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, + coords={"x": x, "y": y, "u": ("y", u)}, + ) + + fill_value = -1 * unit_registry.degK + + units = { + **extract_units(ds), + **{name: unit_registry.degK for name in ds.data_vars}, + } + expected = attach_units( + xr.full_like(strip_units(ds), fill_value=strip_units(fill_value)), units + ) + actual = xr.full_like(ds, fill_value=fill_value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.nan, 10.2)) +def test_where_dataarray(fill_value, unit, error, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + + x = xr.DataArray(data=array, dims="x") + cond = x < 5 * unit_registry.m + fill_value = fill_value * unit + + if error is not None and not ( + np.isnan(fill_value) and not isinstance(fill_value, Quantity) + ): + with pytest.raises(error): + xr.where(cond, x, fill_value) + + return + + expected = attach_units( + xr.where( + cond, + strip_units(x), + strip_units(convert_units(fill_value, {None: unit_registry.m})), + ), + extract_units(x), + ) + actual = xr.where(cond, x, fill_value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.nan, 10.2)) +def test_where_dataset(fill_value, unit, error, dtype): + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.m + + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}) + cond = array1 < 2 * unit_registry.m + fill_value = fill_value * unit + + if error is not None and not ( + np.isnan(fill_value) and not isinstance(fill_value, Quantity) + ): + with pytest.raises(error): + xr.where(cond, ds, fill_value) + + return + + expected = attach_units( + xr.where( + cond, + strip_units(ds), + strip_units(convert_units(fill_value, {None: unit_registry.m})), + ), + extract_units(ds), + ) + actual = xr.where(cond, ds, fill_value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +def test_dot_dataarray(dtype): + array1 = ( + np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) + * unit_registry.m + / unit_registry.s + ) + array2 = ( + np.linspace(10, 20, 10 * 20).reshape(10, 20).astype(dtype) * unit_registry.s + ) + + data_array = xr.DataArray(data=array1, dims=("x", "y")) + other = xr.DataArray(data=array2, dims=("y", "z")) + + with xr.set_options(use_opt_einsum=False): + expected = attach_units( + xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} + ) + actual = xr.dot(data_array, other) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +class TestVariable: + @pytest.mark.parametrize( + "func", + ( + method("all"), + method("any"), + method("argmax", dim="x"), + method("argmin", dim="x"), + method("argsort"), + method("cumprod"), + method("cumsum"), + method("max"), + method("mean"), + method("median"), + method("min"), + method("prod"), + method("std"), + method("sum"), + method("var"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + variable = xr.Variable("x", array) + + numpy_kwargs = func.kwargs.copy() + if "dim" in func.kwargs: + numpy_kwargs["axis"] = variable.get_axis_num(numpy_kwargs.pop("dim")) + + units = extract_units(func(array, **numpy_kwargs)) + expected = attach_units(func(strip_units(variable)), units) + actual = func(variable) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + def test_aggregate_complex(self): + variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m) + expected = xr.Variable((), (0.5 + 1j) * unit_registry.m) + actual = variable.mean() + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("astype", np.float32), + method("conj"), + method("conjugate"), + method("clip", min=2, max=7), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit if isinstance(item, (int, float, list)) else item + for item in func.args + ] + kwargs = { + key: value * unit if isinstance(value, (int, float, list)) else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name in ("searchsorted", "clip"): + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("item", 5), method("searchsorted", 5)), ids=repr + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_raw_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + ( + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + ) + for item in func.args + ] + kwargs = { + key: ( + value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + ) + for key, value in func.kwargs.items() + } + + if error is not None and func.name != "item": + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + ( + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + ) + for item in args + ] + converted_kwargs = { + key: ( + strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + ) + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + assert_duckarray_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + variable = xr.Variable(("x", "y"), array) + + expected = func(strip_units(variable)) + actual = func(variable) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_missing_value_fillna(self, unit, error): + value = 10 + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.m + ) + variable = xr.Variable(("x", "y"), array) + + fill_value = value * unit + + if error is not None: + with pytest.raises(error): + variable.fillna(value=fill_value) + + return + + expected = attach_units( + strip_units(variable).fillna( + value=fill_value.to(unit_registry.m).magnitude + ), + extract_units(variable), + ) + actual = variable.fillna(value=fill_value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + id="compatible_unit", + ), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "convert_data", + ( + pytest.param(False, id="no_conversion"), + pytest.param(True, id="with_conversion"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("equals"), + pytest.param( + method("identical"), + marks=pytest.mark.skip(reason="behavior of identical is undecided"), + ), + ), + ids=repr, + ) + def test_comparisons(self, func, unit, convert_data, dtype): + array = np.linspace(0, 1, 9).astype(dtype) + quantity1 = array * unit_registry.m + variable = xr.Variable("x", quantity1) + + if convert_data and is_compatible(unit_registry.m, unit): + quantity2 = convert_units(array * unit_registry.m, {None: unit}) + else: + quantity2 = array * unit + other = xr.Variable("x", quantity2) + + expected = func( + strip_units(variable), + strip_units( + convert_units(other, extract_units(variable)) + if is_compatible(unit_registry.m, unit) + else other + ), + ) + if func.name == "identical": + expected &= extract_units(variable) == extract_units(other) + else: + expected &= all( + compatible_mappings( + extract_units(variable), extract_units(other) + ).values() + ) + + actual = func(variable, other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + base_unit = unit_registry.m + left_array = np.ones(shape=(2, 2), dtype=dtype) * base_unit + value = ( + (1 * base_unit).to(unit).magnitude if is_compatible(unit, base_unit) else 1 + ) + right_array = np.full(shape=(2,), fill_value=value, dtype=dtype) * unit + + left = xr.Variable(("x", "y"), left_array) + right = xr.Variable("x", right_array) + + units = { + **extract_units(left), + **({} if is_compatible(unit, base_unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & is_compatible(unit, base_unit) + actual = left.broadcast_equals(right) + + assert expected == actual + + @pytest.mark.parametrize("dask", [False, pytest.param(True, marks=[requires_dask])]) + @pytest.mark.parametrize( + ["variable", "indexers"], + ( + pytest.param( + xr.Variable("x", np.linspace(0, 5, 10)), + {"x": 4}, + id="single value-single indexer", + ), + pytest.param( + xr.Variable("x", np.linspace(0, 5, 10)), + {"x": [5, 2, 9, 1]}, + id="multiple values-single indexer", + ), + pytest.param( + xr.Variable(("x", "y"), np.linspace(0, 5, 20).reshape(4, 5)), + {"x": 1, "y": 4}, + id="single value-multiple indexers", + ), + pytest.param( + xr.Variable(("x", "y"), np.linspace(0, 5, 20).reshape(4, 5)), + {"x": [0, 1, 2], "y": [0, 2, 4]}, + id="multiple values-multiple indexers", + ), + ), + ) + def test_isel(self, variable, indexers, dask, dtype): + if dask: + variable = variable.chunk({dim: 2 for dim in variable.dims}) + quantified = xr.Variable( + variable.dims, variable.data.astype(dtype) * unit_registry.s + ) + + expected = attach_units( + strip_units(quantified).isel(indexers), extract_units(quantified) + ) + actual = quantified.isel(indexers) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + function(lambda x, *_: +x, function_label="unary_plus"), + function(lambda x, *_: -x, function_label="unary_minus"), + function(lambda x, *_: abs(x), function_label="absolute"), + function(lambda x, y: x + y, function_label="sum"), + function(lambda x, y: y + x, function_label="commutative_sum"), + function(lambda x, y: x * y, function_label="product"), + function(lambda x, y: y * x, function_label="commutative_product"), + ), + ids=repr, + ) + def test_1d_math(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.arange(5).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + values = np.ones(5) + y = values * unit + + if error is not None and func.name in ("sum", "commutative_sum"): + with pytest.raises(error): + func(variable, y) + + return + + units = extract_units(func(array, y)) + if all(compatible_mappings(units, extract_units(y)).values()): + converted_y = convert_units(y, units) + else: + converted_y = y + + if all(compatible_mappings(units, extract_units(variable)).values()): + converted_variable = convert_units(variable, units) + else: + converted_variable = variable + + expected = attach_units( + func(strip_units(converted_variable), strip_units(converted_y)), units + ) + actual = func(variable, y) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("where"), method("_getitem_with_mask")), ids=repr + ) + def test_masking(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + cond = np.array([True, False] * 5) + + other = -1 * unit + + if error is not None: + with pytest.raises(error): + func(variable, cond, other) + + return + + expected = attach_units( + func( + strip_units(variable), + cond, + strip_units( + convert_units( + other, + ( + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None} + ), + ) + ), + ), + extract_units(variable), + ) + actual = func(variable, cond, other) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all")) + def test_squeeze(self, dim, dtype): + shape = (2, 1, 3, 1, 1, 2) + names = list("abcdef") + dim_lengths = dict(zip(names, shape)) + array = np.ones(shape=shape) * unit_registry.m + variable = xr.Variable(names, array) + + kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {} + expected = attach_units( + strip_units(variable).squeeze(**kwargs), extract_units(variable) + ) + actual = variable.squeeze(**kwargs) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize( + "func", + ( + method("coarsen", windows={"y": 2}, func=np.mean), + method("quantile", q=[0.25, 0.75]), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.skip(reason="rank not implemented for non-ndarray"), + ), + method("roll", {"x": 2}), + pytest.param( + method("rolling_window", "x", 3, "window"), + marks=pytest.mark.xfail(reason="converts to ndarray"), + ), + method("reduce", np.std, "x"), + method("round", 2), + method("shift", {"x": -2}), + method("transpose", "y", "x"), + ), + ids=repr, + ) + def test_computation(self, func, dtype, compute_backend): + base_unit = unit_registry.m + array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit + variable = xr.Variable(("x", "y"), array) + + expected = attach_units(func(strip_units(variable)), extract_units(variable)) + + actual = func(variable) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_searchsorted(self, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + value = 0 * unit + + if error is not None: + with pytest.raises(error): + variable.searchsorted(value) + + return + + expected = strip_units(variable).searchsorted( + strip_units(convert_units(value, {None: base_unit})) + ) + + actual = variable.searchsorted(value) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + def test_stack(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + expected = attach_units( + strip_units(variable).stack(z=("x", "y")), extract_units(variable) + ) + actual = variable.stack(z=("x", "y")) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + def test_unstack(self, dtype): + array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m + variable = xr.Variable("z", array) + + expected = attach_units( + strip_units(variable).unstack(z={"x": 3, "y": 10}), extract_units(variable) + ) + actual = variable.unstack(z={"x": 3, "y": 10}) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_concat(self, unit, error, dtype): + array1 = ( + np.linspace(0, 5, 9 * 10).reshape(3, 6, 5).astype(dtype) * unit_registry.m + ) + array2 = np.linspace(5, 10, 10 * 3).reshape(3, 2, 5).astype(dtype) * unit + + variable = xr.Variable(("x", "y", "z"), array1) + other = xr.Variable(("x", "y", "z"), array2) + + if error is not None: + with pytest.raises(error): + xr.Variable.concat([variable, other], dim="y") + + return + + units = extract_units(variable) + expected = attach_units( + xr.Variable.concat( + [strip_units(variable), strip_units(convert_units(other, units))], + dim="y", + ), + units, + ) + actual = xr.Variable.concat([variable, other], dim="y") + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + def test_set_dims(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + dims = {"z": 6, "x": 3, "a": 1, "b": 4, "y": 10} + expected = attach_units( + strip_units(variable).set_dims(dims), extract_units(variable) + ) + actual = variable.set_dims(dims) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + def test_copy(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + other = np.arange(10).astype(dtype) * unit_registry.s + + variable = xr.Variable("x", array) + expected = attach_units( + strip_units(variable).copy(data=strip_units(other)), extract_units(other) + ) + actual = variable.copy(data=other) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_no_conflicts(self, unit, dtype): + base_unit = unit_registry.m + array1 = ( + np.array( + [ + [6.3, 0.3, 0.45], + [np.nan, 0.3, 0.3], + [3.7, np.nan, 0.2], + [9.43, 0.3, 0.7], + ] + ) + * base_unit + ) + array2 = np.array([np.nan, 0.3, np.nan]) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable("y", array2) + + expected = strip_units(variable).no_conflicts( + strip_units( + convert_units( + other, {None: base_unit if is_compatible(base_unit, unit) else None} + ) + ) + ) & is_compatible(base_unit, unit) + actual = variable.no_conflicts(other) + + assert expected == actual + + @pytest.mark.parametrize( + "mode", + [ + "constant", + "mean", + "median", + "reflect", + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + def test_pad(self, mode, xr_arg, np_arg): + data = np.arange(4 * 3 * 2).reshape(4, 3, 2) * unit_registry.m + v = xr.Variable(["x", "y", "z"], data) + + expected = attach_units( + strip_units(v).pad(mode=mode, **xr_arg), + extract_units(v), + ) + actual = v.pad(mode=mode, **xr_arg) + + assert_units_equal(expected, actual) + assert_equal(actual, expected) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_pad_unit_constant_value(self, unit, error, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + fill_value = -100 * unit + + func = method("pad", mode="constant", x=(2, 3), y=(1, 4)) + if error is not None: + with pytest.raises(error): + func(variable, constant_values=fill_value) + + return + + units = extract_units(variable) + expected = attach_units( + func( + strip_units(variable), + constant_values=strip_units(convert_units(fill_value, units)), + ), + units, + ) + actual = func(variable, constant_values=fill_value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +class TestDataArray: + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "with_dims", + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + "with_coords", + "without_coords", + ), + ) + def test_init(self, variant, dtype): + array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m + + x = np.arange(len(array)) * unit_registry.s + y = x.to(unit_registry.ms) + + variants = { + "with_dims": {"x": x}, + "with_coords": {"y": ("x", y)}, + "without_coords": {}, + } + + kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)} + data_array = xr.DataArray(**kwargs) + + assert isinstance(data_array.data, Quantity) + assert all( + { + name: isinstance(coord.data, Quantity) + for name, coord in data_array.coords.items() + }.values() + ) + + @pytest.mark.parametrize( + "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) + ) + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "with_dims", + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + pytest.param("with_coords"), + pytest.param("without_coords"), + ), + ) + def test_repr(self, func, variant, dtype): + array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m + x = np.arange(len(array)) * unit_registry.s + y = x.to(unit_registry.ms) + + variants = { + "with_dims": {"x": x}, + "with_coords": {"y": ("x", y)}, + "without_coords": {}, + } + + kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)} + data_array = xr.DataArray(**kwargs) + + # FIXME: this just checks that the repr does not raise + # warnings or errors, but does not check the result + func(data_array) + + @pytest.mark.parametrize( + "func", + ( + function("all"), + function("any"), + pytest.param( + function("argmax"), + marks=pytest.mark.skip( + reason="calling np.argmax as a function on xarray objects is not " + "supported" + ), + ), + pytest.param( + function("argmin"), + marks=pytest.mark.skip( + reason="calling np.argmin as a function on xarray objects is not " + "supported" + ), + ), + function("max"), + function("mean"), + pytest.param( + function("median"), + marks=pytest.mark.skip( + reason="median does not work with dataarrays yet" + ), + ), + function("min"), + function("prod"), + function("sum"), + function("std"), + function("var"), + function("cumsum"), + function("cumprod"), + method("all"), + method("any"), + method("argmax", dim="x"), + method("argmin", dim="x"), + method("max"), + method("mean"), + method("median"), + method("min"), + method("prod"), + method("sum"), + method("std"), + method("var"), + method("cumsum"), + method("cumprod"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + array = np.arange(10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + data_array = xr.DataArray(data=array, dims="x") + + numpy_kwargs = func.kwargs.copy() + if "dim" in numpy_kwargs: + numpy_kwargs["axis"] = data_array.get_axis_num(numpy_kwargs.pop("dim")) + + # units differ based on the applied function, so we need to + # first compute the units + units = extract_units(func(array)) + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + pytest.param(operator.neg, id="negate"), + pytest.param(abs, id="absolute"), + pytest.param(np.round, id="round"), + ), + ) + def test_unary_operations(self, func, dtype): + array = np.arange(10).astype(dtype) * unit_registry.m + data_array = xr.DataArray(data=array) + + units = extract_units(func(array)) + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + pytest.param(lambda x: 2 * x, id="multiply"), + pytest.param(lambda x: x + x, id="add"), + pytest.param(lambda x: x[0] + x, id="add scalar"), + pytest.param(lambda x: x.T @ x, id="matrix multiply"), + ), + ) + def test_binary_operations(self, func, dtype): + array = np.arange(10).astype(dtype) * unit_registry.m + data_array = xr.DataArray(data=array) + + units = extract_units(func(array)) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "comparison", + ( + pytest.param(operator.lt, id="less_than"), + pytest.param(operator.ge, id="greater_equal"), + pytest.param(operator.eq, id="equal"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, ValueError, id="without_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_comparison_operations(self, comparison, unit, error, dtype): + array = ( + np.array([10.1, 5.2, 6.5, 8.0, 21.3, 7.1, 1.3]).astype(dtype) + * unit_registry.m + ) + data_array = xr.DataArray(data=array) + + value = 8 + to_compare_with = value * unit + + # incompatible units are all not equal + if error is not None and comparison is not operator.eq: + with pytest.raises(error): + comparison(array, to_compare_with) + + with pytest.raises(error): + comparison(data_array, to_compare_with) + + return + + actual = comparison(data_array, to_compare_with) + + expected_units = {None: unit_registry.m if array.check(unit) else None} + expected = array.check(unit) & comparison( + strip_units(data_array), + strip_units(convert_units(to_compare_with, expected_units)), + ) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "units,error", + ( + pytest.param(unit_registry.dimensionless, None, id="dimensionless"), + pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.degree, None, id="compatible_unit"), + ), + ) + def test_univariate_ufunc(self, units, error, dtype): + array = np.arange(10).astype(dtype) * units + data_array = xr.DataArray(data=array) + + func = function("sin") + + if error is not None: + with pytest.raises(error): + np.sin(data_array) + + return + + expected = attach_units( + func(strip_units(convert_units(data_array, {None: unit_registry.radians}))), + {None: unit_registry.dimensionless}, + ) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="without_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.mm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="pint converts to the wrong units"), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_bivariate_ufunc(self, unit, error, dtype): + original_unit = unit_registry.m + array = np.arange(10).astype(dtype) * original_unit + data_array = xr.DataArray(data=array) + + if error is not None: + with pytest.raises(error): + np.maximum(data_array, 1 * unit) + + return + + expected_units = {None: original_unit} + expected = attach_units( + np.maximum( + strip_units(data_array), + strip_units(convert_units(1 * unit, expected_units)), + ), + expected_units, + ) + + actual = np.maximum(data_array, 1 * unit) + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + actual = np.maximum(1 * unit, data_array) + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("property", ("T", "imag", "real")) + def test_numpy_properties(self, property, dtype): + array = ( + np.arange(5 * 10).astype(dtype) + + 1j * np.linspace(-1, 0, 5 * 10).astype(dtype) + ).reshape(5, 10) * unit_registry.s + + data_array = xr.DataArray(data=array, dims=("x", "y")) + + expected = attach_units( + getattr(strip_units(data_array), property), extract_units(data_array) + ) + actual = getattr(data_array, property) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + (method("conj"), method("argsort"), method("conjugate"), method("round")), + ids=repr, + ) + def test_numpy_methods(self, func, dtype): + array = np.arange(10).astype(dtype) * unit_registry.m + data_array = xr.DataArray(data=array, dims="x") + + units = extract_units(func(array)) + expected = attach_units(strip_units(data_array), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + def test_item(self, dtype): + array = np.arange(10).astype(dtype) * unit_registry.m + data_array = xr.DataArray(data=array) + + func = method("item", 2) + + expected = func(strip_units(data_array)) * unit_registry.m + actual = func(data_array) + + assert_duckarray_allclose(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("searchsorted", 5), + pytest.param( + function("searchsorted", 5), + marks=pytest.mark.xfail( + reason="xarray does not implement __array_function__" + ), + ), + ), + ids=repr, + ) + def test_searchsorted(self, func, unit, error, dtype): + array = np.arange(10).astype(dtype) * unit_registry.m + data_array = xr.DataArray(data=array) + + scalar_types = (int, float) + args = list(value * unit for value in func.args) + kwargs = { + key: (value * unit if isinstance(value, scalar_types) else value) + for key, value in func.kwargs.items() + } + + if error is not None: + with pytest.raises(error): + func(data_array, *args, **kwargs) + + return + + units = extract_units(data_array) + expected_units = extract_units(func(array, *args, **kwargs)) + stripped_args = [strip_units(convert_units(value, units)) for value in args] + stripped_kwargs = { + key: strip_units(convert_units(value, units)) + for key, value in kwargs.items() + } + expected = attach_units( + func(strip_units(data_array), *stripped_args, **stripped_kwargs), + expected_units, + ) + actual = func(data_array, *args, **kwargs) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("clip", min=3, max=8), + pytest.param( + function("clip", a_min=3, a_max=8), + marks=pytest.mark.xfail( + reason="xarray does not implement __array_function__" + ), + ), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods_with_args(self, func, unit, error, dtype): + array = np.arange(10).astype(dtype) * unit_registry.m + data_array = xr.DataArray(data=array) + + scalar_types = (int, float) + args = list(value * unit for value in func.args) + kwargs = { + key: (value * unit if isinstance(value, scalar_types) else value) + for key, value in func.kwargs.items() + } + if error is not None: + with pytest.raises(error): + func(data_array, *args, **kwargs) + + return + + units = extract_units(data_array) + expected_units = extract_units(func(array, *args, **kwargs)) + stripped_args = [strip_units(convert_units(value, units)) for value in args] + stripped_kwargs = { + key: strip_units(convert_units(value, units)) + for key, value in kwargs.items() + } + expected = attach_units( + func(strip_units(data_array), *stripped_args, **stripped_kwargs), + expected_units, + ) + actual = func(data_array, *args, **kwargs) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func, dtype): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + data_array = xr.DataArray(data=array) + + expected = func(strip_units(data_array)) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.xfail(reason="ffill and bfill lose units in data") + @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr) + def test_missing_value_filling(self, func, dtype): + array = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + x = np.arange(len(array)) + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") + + expected = attach_units( + func(strip_units(data_array), dim="x"), extract_units(data_array) + ) + actual = func(data_array, dim="x") + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "fill_value", + ( + pytest.param(-1, id="python_scalar"), + pytest.param(np.array(-1), id="numpy_scalar"), + pytest.param(np.array([-1]), id="numpy_array"), + ), + ) + def test_fillna(self, fill_value, unit, error, dtype): + original_unit = unit_registry.m + array = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * original_unit + ) + data_array = xr.DataArray(data=array) + + func = method("fillna") + + value = fill_value * unit + if error is not None: + with pytest.raises(error): + func(data_array, value=value) + + return + + units = extract_units(data_array) + expected = attach_units( + func( + strip_units(data_array), value=strip_units(convert_units(value, units)) + ), + units, + ) + actual = func(data_array, value=value) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + def test_dropna(self, dtype): + array = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.m + ) + x = np.arange(len(array)) + data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"]) + + units = extract_units(data_array) + expected = attach_units(strip_units(data_array).dropna(dim="x"), units) + actual = data_array.dropna(dim="x") + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_isin(self, unit, dtype): + array = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.m + ) + data_array = xr.DataArray(data=array, dims="x") + + raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype) + values = raw_values * unit + + units = {None: unit_registry.m if array.check(unit) else None} + expected = strip_units(data_array).isin( + strip_units(convert_units(values, units)) + ) & array.check(unit) + actual = data_array.isin(values) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "variant", ("masking", "replacing_scalar", "replacing_array", "dropping") + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_where(self, variant, unit, error, dtype): + original_unit = unit_registry.m + array = np.linspace(0, 1, 10).astype(dtype) * original_unit + + data_array = xr.DataArray(data=array) + + condition = data_array < 0.5 * original_unit + other = np.linspace(-2, -1, 10).astype(dtype) * unit + variant_kwargs = { + "masking": {"cond": condition}, + "replacing_scalar": {"cond": condition, "other": -1 * unit}, + "replacing_array": {"cond": condition, "other": other}, + "dropping": {"cond": condition, "drop": True}, + } + kwargs = variant_kwargs.get(variant) + kwargs_without_units = { + key: strip_units( + convert_units( + value, {None: original_unit if array.check(unit) else None} + ) + ) + for key, value in kwargs.items() + } + + if variant not in ("masking", "dropping") and error is not None: + with pytest.raises(error): + data_array.where(**kwargs) + + return + + expected = attach_units( + strip_units(data_array).where(**kwargs_without_units), + extract_units(data_array), + ) + actual = data_array.where(**kwargs) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.xfail(reason="uses numpy.vectorize") + def test_interpolate_na(self): + array = ( + np.array([-1.03, 0.1, 1.4, np.nan, 2.3, np.nan, np.nan, 9.1]) + * unit_registry.m + ) + x = np.arange(len(array)) + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") + + units = extract_units(data_array) + expected = attach_units(strip_units(data_array).interpolate_na(dim="x"), units) + actual = data_array.interpolate_na(dim="x") + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + ), + pytest.param( + unit_registry.m, + None, + id="identical_unit", + ), + ), + ) + def test_combine_first(self, unit, error, dtype): + array = np.zeros(shape=(2, 2), dtype=dtype) * unit_registry.m + other_array = np.ones_like(array) * unit + + data_array = xr.DataArray( + data=array, coords={"x": ["a", "b"], "y": [-1, 0]}, dims=["x", "y"] + ) + other = xr.DataArray( + data=other_array, coords={"x": ["b", "c"], "y": [0, 1]}, dims=["x", "y"] + ) + + if error is not None: + with pytest.raises(error): + data_array.combine_first(other) + + return + + units = extract_units(data_array) + expected = attach_units( + strip_units(data_array).combine_first( + strip_units(convert_units(other, units)) + ), + units, + ) + actual = data_array.combine_first(other) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "variation", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("equals"), + pytest.param( + method("identical"), + marks=pytest.mark.skip(reason="the behavior of identical is undecided"), + ), + ), + ids=repr, + ) + def test_comparisons(self, func, variation, unit, dtype): + def is_compatible(a, b): + a = a if a is not None else 1 + b = b if b is not None else 1 + quantity = np.arange(5) * a + + return a == b or quantity.check(b) + + data = np.linspace(0, 5, 10).astype(dtype) + coord = np.arange(len(data)).astype(dtype) + + base_unit = unit_registry.m + array = data * (base_unit if variation == "data" else 1) + x = coord * (base_unit if variation == "dims" else 1) + y = coord * (base_unit if variation == "coords" else 1) + + variations = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variations.get(variation) + + data_array = xr.DataArray(data=array, coords={"x": x, "y": ("x", y)}, dims="x") + + other = attach_units( + strip_units(data_array), {None: data_unit, "x": dim_unit, "y": coord_unit} + ) + + units = extract_units(data_array) + other_units = extract_units(other) + + equal_arrays = all( + is_compatible(units[name], other_units[name]) for name in units.keys() + ) and ( + strip_units(data_array).equals( + strip_units(convert_units(other, extract_units(data_array))) + ) + ) + equal_units = units == other_units + expected = equal_arrays and (func.name != "identical" or equal_units) + + actual = func(data_array, other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_broadcast_like(self, variant, unit, dtype): + original_unit = unit_registry.m + + variants = { + "data": ((original_unit, unit), (1, 1), (1, 1)), + "dims": ((1, 1), (original_unit, unit), (1, 1)), + "coords": ((1, 1), (1, 1), (original_unit, unit)), + } + ( + (data_unit1, data_unit2), + (dim_unit1, dim_unit2), + (coord_unit1, coord_unit2), + ) = variants.get(variant) + + array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * data_unit1 + array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit2 + + x1 = np.arange(2) * dim_unit1 + x2 = np.arange(2) * dim_unit2 + y1 = np.array([0]) * dim_unit1 + y2 = np.arange(3) * dim_unit2 + + u1 = np.linspace(0, 1, 2) * coord_unit1 + u2 = np.linspace(0, 1, 2) * coord_unit2 + + arr1 = xr.DataArray( + data=array1, coords={"x": x1, "y": y1, "u": ("x", u1)}, dims=("x", "y") + ) + arr2 = xr.DataArray( + data=array2, coords={"x": x2, "y": y2, "u": ("x", u2)}, dims=("x", "y") + ) + + expected = attach_units( + strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1) + ) + actual = arr1.broadcast_like(arr2) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + left_array = np.ones(shape=(2, 2), dtype=dtype) * unit_registry.m + right_array = np.ones(shape=(2,), dtype=dtype) * unit + + left = xr.DataArray(data=left_array, dims=("x", "y")) + right = xr.DataArray(data=right_array, dims="x") + + units = { + **extract_units(left), + **({} if left_array.check(unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & left_array.check(unit) + actual = left.broadcast_equals(right) + + assert expected == actual + + def test_pad(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + + data_array = xr.DataArray(data=array, dims="x") + units = extract_units(data_array) + + expected = attach_units(strip_units(data_array).pad(x=(2, 3)), units) + actual = data_array.pad(x=(2, 3)) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("pipe", lambda da: da * 10), + method("assign_coords", w=("y", np.arange(10) * unit_registry.mm)), + method("assign_attrs", attr1="value"), + method("rename", u="v"), + pytest.param( + method("swap_dims", {"x": "u"}), + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + pytest.param( + method( + "expand_dims", + dim={"z": np.linspace(10, 20, 12) * unit_registry.s}, + axis=1, + ), + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + method("drop_vars", "x"), + method("reset_coords", names="u"), + method("copy"), + method("astype", np.float32), + ), + ids=repr, + ) + def test_content_manipulation(self, func, variant, dtype): + unit = unit_registry.m + + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + quantity = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + x = np.arange(quantity.shape[0]) * dim_unit + y = np.arange(quantity.shape[1]) * dim_unit + u = np.linspace(0, 1, quantity.shape[0]) * coord_unit + + data_array = xr.DataArray( + name="a", + data=quantity, + coords={"x": x, "u": ("x", u), "y": y}, + dims=("x", "y"), + ) + + stripped_kwargs = { + key: array_strip_units(value) for key, value in func.kwargs.items() + } + units = extract_units(data_array) + units["u"] = getattr(u, "units", None) + units["v"] = getattr(u, "units", None) + + expected = attach_units(func(strip_units(data_array), **stripped_kwargs), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.degK, id="with_unit"), + ), + ) + def test_copy(self, unit, dtype): + quantity = np.linspace(0, 10, 20, dtype=dtype) * unit_registry.pascal + new_data = np.arange(20) + + data_array = xr.DataArray(data=quantity, dims="x") + + expected = attach_units( + strip_units(data_array).copy(data=new_data), {None: unit} + ) + + actual = data_array.copy(data=new_data * unit) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + # TODO: maybe test for units in indexes? + array = np.arange(10).astype(dtype) * unit_registry.s + + data_array = xr.DataArray(data=array, dims="x") + + expected = attach_units( + strip_units(data_array).isel(x=indices), extract_units(data_array) + ) + actual = data_array.isel(x=indices) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_sel(self, raw_values, unit, error, dtype): + array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m + x = np.arange(len(array)) * unit_registry.m + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") + + values = raw_values * unit + + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): + with pytest.raises(error): + data_array.sel(x=values) + + return + + expected = attach_units( + strip_units(data_array).sel( + x=strip_units(convert_units(values, {None: array.units})) + ), + extract_units(data_array), + ) + actual = data_array.sel(x=values) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_loc(self, raw_values, unit, error, dtype): + array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m + x = np.arange(len(array)) * unit_registry.m + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") + + values = raw_values * unit + + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): + with pytest.raises(error): + data_array.loc[{"x": values}] + + return + + expected = attach_units( + strip_units(data_array).loc[ + {"x": strip_units(convert_units(values, {None: array.units}))} + ], + extract_units(data_array), + ) + actual = data_array.loc[{"x": values}] + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_drop_sel(self, raw_values, unit, error, dtype): + array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m + x = np.arange(len(array)) * unit_registry.m + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") + + values = raw_values * unit + + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): + with pytest.raises(error): + data_array.drop_sel(x=values) + + return + + expected = attach_units( + strip_units(data_array).drop_sel( + x=strip_units(convert_units(values, {None: x.units})) + ), + extract_units(data_array), + ) + actual = data_array.drop_sel(x=values) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all")) + @pytest.mark.parametrize( + "shape", + ( + pytest.param((10, 20), id="nothing_squeezable"), + pytest.param((10, 20, 1), id="last_dimension_squeezable"), + pytest.param((10, 1, 20), id="middle_dimension_squeezable"), + pytest.param((1, 10, 20), id="first_dimension_squeezable"), + pytest.param((1, 10, 1, 20), id="first_and_last_dimension_squeezable"), + ), + ) + def test_squeeze(self, shape, dim, dtype): + names = "xyzt" + dim_lengths = dict(zip(names, shape)) + names = "xyzt" + array = np.arange(10 * 20).astype(dtype).reshape(shape) * unit_registry.J + data_array = xr.DataArray(data=array, dims=tuple(names[: len(shape)])) + + kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {} + + expected = attach_units( + strip_units(data_array).squeeze(**kwargs), extract_units(data_array) + ) + actual = data_array.squeeze(**kwargs) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + (method("head", x=7, y=3), method("tail", x=7, y=3), method("thin", x=7, y=3)), + ids=repr, + ) + def test_head_tail_thin(self, func, dtype): + # TODO: works like isel. Maybe also test units in indexes? + array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + + data_array = xr.DataArray(data=array, dims=("x", "y")) + + expected = attach_units( + func(strip_units(data_array)), extract_units(data_array) + ) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("variant", ("data", "coords")) + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("interp"), marks=pytest.mark.xfail(reason="uses scipy") + ), + method("reindex"), + ), + ids=repr, + ) + def test_interp_reindex(self, variant, func, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } + data_unit, coord_unit = variants.get(variant) + + array = np.linspace(1, 2, 10).astype(dtype) * data_unit + y = np.arange(10) * coord_unit + + x = np.arange(10) + new_x = np.arange(10) + 0.5 + data_array = xr.DataArray(array, coords={"x": x, "y": ("x", y)}, dims="x") + + units = extract_units(data_array) + expected = attach_units(func(strip_units(data_array), x=new_x), units) + actual = func(data_array, x=new_x) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + (method("interp"), method("reindex")), + ids=repr, + ) + def test_interp_reindex_indexing(self, func, unit, error, dtype): + array = np.linspace(1, 2, 10).astype(dtype) + x = np.arange(10) * unit_registry.m + new_x = (np.arange(10) + 0.5) * unit + data_array = xr.DataArray(array, coords={"x": x}, dims="x") + + if error is not None: + with pytest.raises(error): + func(data_array, x=new_x) + + return + + units = extract_units(data_array) + expected = attach_units( + func( + strip_units(data_array), + x=strip_units(convert_units(new_x, {None: unit_registry.m})), + ), + units, + ) + actual = func(data_array, x=new_x) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("variant", ("data", "coords")) + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("interp_like"), marks=pytest.mark.xfail(reason="uses scipy") + ), + method("reindex_like"), + ), + ids=repr, + ) + def test_interp_reindex_like(self, variant, func, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } + data_unit, coord_unit = variants.get(variant) + + array = np.linspace(1, 2, 10).astype(dtype) * data_unit + coord = np.arange(10) * coord_unit + + x = np.arange(10) + new_x = np.arange(-2, 2) + 0.5 + data_array = xr.DataArray(array, coords={"x": x, "y": ("x", coord)}, dims="x") + other = xr.DataArray(np.empty_like(new_x), coords={"x": new_x}, dims="x") + + units = extract_units(data_array) + expected = attach_units(func(strip_units(data_array), other), units) + actual = func(data_array, other) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + (method("interp_like"), method("reindex_like")), + ids=repr, + ) + def test_interp_reindex_like_indexing(self, func, unit, error, dtype): + array = np.linspace(1, 2, 10).astype(dtype) + x = np.arange(10) * unit_registry.m + new_x = (np.arange(-2, 2) + 0.5) * unit + + data_array = xr.DataArray(array, coords={"x": x}, dims="x") + other = xr.DataArray(np.empty_like(new_x), {"x": new_x}, dims="x") + + if error is not None: + with pytest.raises(error): + func(data_array, other) + + return + + units = extract_units(data_array) + expected = attach_units( + func( + strip_units(data_array), + strip_units(convert_units(other, {None: unit_registry.m})), + ), + units, + ) + actual = func(data_array, other) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + (method("unstack"), method("reset_index", "z"), method("reorder_levels")), + ids=repr, + ) + def test_stacking_stacked(self, func, dtype): + array = ( + np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m + ) + x = np.arange(array.shape[0]) + y = np.arange(array.shape[1]) + + data_array = xr.DataArray( + name="data", data=array, coords={"x": x, "y": y}, dims=("x", "y") + ) + stacked = data_array.stack(z=("x", "y")) + + expected = attach_units(func(strip_units(stacked)), {"data": unit_registry.m}) + actual = func(stacked) + + assert_units_equal(expected, actual) + if func.name == "reset_index": + assert_identical(expected, actual, check_default_indexes=False) + else: + assert_identical(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + def test_to_unstacked_dataset(self, dtype): + array = ( + np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) + * unit_registry.pascal + ) + x = np.arange(array.shape[0]) * unit_registry.m + y = np.arange(array.shape[1]) * unit_registry.s + + data_array = xr.DataArray( + data=array, coords={"x": x, "y": y}, dims=("x", "y") + ).stack(z=("x", "y")) + + func = method("to_unstacked_dataset", dim="z") + + expected = attach_units( + func(strip_units(data_array)), + {"y": y.units, **dict(zip(x.magnitude, [array.units] * len(y)))}, + ).rename({elem.magnitude: elem for elem in x}) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("transpose", "y", "x", "z"), + method("stack", a=("x", "y")), + method("set_index", x="x2"), + method("shift", x=2), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.skip(reason="rank not implemented for non-ndarray"), + ), + method("roll", x=2, roll_coords=False), + method("sortby", "x2"), + ), + ids=repr, + ) + def test_stacking_reordering(self, func, dtype): + array = ( + np.linspace(0, 10, 2 * 5 * 10).reshape(2, 5, 10).astype(dtype) + * unit_registry.m + ) + x = np.arange(array.shape[0]) + y = np.arange(array.shape[1]) + z = np.arange(array.shape[2]) + x2 = np.linspace(0, 1, array.shape[0])[::-1] + + data_array = xr.DataArray( + name="data", + data=array, + coords={"x": x, "y": y, "z": z, "x2": ("x", x2)}, + dims=("x", "y", "z"), + ) + + expected = attach_units(func(strip_units(data_array)), {None: unit_registry.m}) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("differentiate", fallback_func=np.gradient), + method("integrate", fallback_func=duck_array_ops.cumulative_trapezoid), + method("cumulative_integrate", fallback_func=duck_array_ops.trapz), + ), + ids=repr, + ) + def test_differentiate_integrate(self, func, variant, dtype): + data_unit = unit_registry.m + unit = unit_registry.s + + variants = { + "dims": ("x", unit, 1), + "coords": ("u", 1, unit), + } + coord, dim_unit, coord_unit = variants.get(variant) + + array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + + x = np.arange(array.shape[0]) * dim_unit + y = np.arange(array.shape[1]) * dim_unit + + u = np.linspace(0, 1, array.shape[0]) * coord_unit + + data_array = xr.DataArray( + data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y") + ) + # we want to make sure the output unit is correct + units = extract_units(data_array) + units.update( + extract_units( + func( + data_array.data, + getattr(data_array, coord).data, + axis=0, + ) + ) + ) + + expected = attach_units( + func(strip_units(data_array), coord=strip_units(coord)), + units, + ) + actual = func(data_array, coord=coord) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("diff", dim="x"), + method("quantile", q=[0.25, 0.75]), + method("reduce", func=np.sum, dim="x"), + pytest.param(lambda x: x.dot(x), id="method_dot"), + ), + ids=repr, + ) + def test_computation(self, func, variant, dtype, compute_backend): + unit = unit_registry.m + + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + + x = np.arange(array.shape[0]) * dim_unit + y = np.arange(array.shape[1]) * dim_unit + + u = np.linspace(0, 1, array.shape[0]) * coord_unit + + data_array = xr.DataArray( + data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y") + ) + + # we want to make sure the output unit is correct + units = extract_units(data_array) + if not isinstance(func, (function, method)): + units.update(extract_units(func(array.reshape(-1)))) + + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("groupby", "x"), + method("groupby_bins", "y", bins=4), + method("coarsen", y=2), + method("rolling", y=3), + pytest.param(method("rolling_exp", y=3), marks=requires_numbagg), + method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")), + ), + ids=repr, + ) + def test_computation_objects(self, func, variant, dtype): + if variant == "data": + if func.name == "rolling_exp": + pytest.xfail(reason="numbagg functions are not supported by pint") + elif func.name == "rolling": + pytest.xfail( + reason="numpy.lib.stride_tricks.as_strided converts to ndarray" + ) + + unit = unit_registry.m + + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + + x = np.array([0, 0, 1, 2, 2]) * dim_unit + y = np.arange(array.shape[1]) * 3 * dim_unit + + u = np.linspace(0, 1, 5) * coord_unit + + data_array = xr.DataArray( + data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y") + ) + units = extract_units(data_array) + + expected = attach_units(func(strip_units(data_array)).mean(), units) + actual = func(data_array).mean() + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + def test_resample(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + + time = xr.date_range("10-09-2010", periods=len(array), freq="YE") + data_array = xr.DataArray(data=array, coords={"time": time}, dims="time") + units = extract_units(data_array) + + func = method("resample", time="6ME") + + expected = attach_units(func(strip_units(data_array)).mean(), units) + actual = func(data_array).mean() + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("assign_coords", z=("x", np.arange(5) * unit_registry.s)), + method("first"), + method("last"), + method("quantile", q=[0.25, 0.5, 0.75], dim="x"), + ), + ids=repr, + ) + def test_grouped_operations(self, func, variant, dtype, compute_backend): + unit = unit_registry.m + + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + + x = np.arange(array.shape[0]) * dim_unit + y = np.arange(array.shape[1]) * 3 * dim_unit + + u = np.linspace(0, 1, array.shape[0]) * coord_unit + + data_array = xr.DataArray( + data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y") + ) + units = {**extract_units(data_array), **{"z": unit_registry.s, "q": None}} + + stripped_kwargs = { + key: ( + strip_units(value) + if not isinstance(value, tuple) + else tuple(strip_units(elem) for elem in value) + ) + for key, value in func.kwargs.items() + } + expected = attach_units( + func( + strip_units(data_array).groupby("y", squeeze=False), **stripped_kwargs + ), + units, + ) + actual = func(data_array.groupby("y", squeeze=False)) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + + +class TestDataset: + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, xr.MergeError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, xr.MergeError, id="dimensionless" + ), + pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="same_unit"), + ), + ) + @pytest.mark.parametrize( + "shared", + ( + "nothing", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_init(self, shared, unit, error, dtype): + original_unit = unit_registry.m + scaled_unit = unit_registry.mm + + a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa + b = np.linspace(-1, 0, 10).astype(dtype) * unit_registry.degK + + values_a = np.arange(a.shape[0]) + dim_a = values_a * original_unit + coord_a = dim_a.to(scaled_unit) + + values_b = np.arange(b.shape[0]) + dim_b = values_b * unit + coord_b = ( + dim_b.to(scaled_unit) + if unit_registry.is_compatible_with(dim_b, scaled_unit) + and unit != scaled_unit + else dim_b * 1000 + ) + + variants = { + "nothing": ({}, {}), + "dims": ({"x": dim_a}, {"x": dim_b}), + "coords": ( + {"x": values_a, "y": ("x", coord_a)}, + {"x": values_b, "y": ("x", coord_b)}, + ), + } + coords_a, coords_b = variants.get(shared) + + dims_a, dims_b = ("x", "y") if shared == "nothing" else ("x", "x") + + a = xr.DataArray(data=a, coords=coords_a, dims=dims_a) + b = xr.DataArray(data=b, coords=coords_b, dims=dims_b) + + if error is not None and shared != "nothing": + with pytest.raises(error): + xr.Dataset(data_vars={"a": a, "b": b}) + + return + + actual = xr.Dataset(data_vars={"a": a, "b": b}) + + units = merge_mappings( + extract_units(a.rename("a")), extract_units(b.rename("b")) + ) + expected = attach_units( + xr.Dataset(data_vars={"a": strip_units(a), "b": strip_units(b)}), units + ) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + "coords", + ), + ) + def test_repr(self, func, variant, dtype): + unit1, unit2 = ( + (unit_registry.Pa, unit_registry.degK) if variant == "data" else (1, 1) + ) + + array1 = np.linspace(1, 2, 10, dtype=dtype) * unit1 + array2 = np.linspace(0, 1, 10, dtype=dtype) * unit2 + + x = np.arange(len(array1)) * unit_registry.s + y = x.to(unit_registry.ms) + + variants = { + "dims": {"x": x}, + "coords": {"y": ("x", y)}, + "data": {}, + } + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("x", array2)}, + coords=variants.get(variant), + ) + + # FIXME: this just checks that the repr does not raise + # warnings or errors, but does not check the result + func(ds) + + @pytest.mark.parametrize( + "func", + ( + method("all"), + method("any"), + method("argmax", dim="x"), + method("argmin", dim="x"), + method("max"), + method("min"), + method("mean"), + method("median"), + method("sum"), + method("prod"), + method("std"), + method("var"), + method("cumsum"), + method("cumprod"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + unit_a, unit_b = ( + (unit_registry.Pa, unit_registry.degK) + if func.name != "cumprod" + else (unit_registry.dimensionless, unit_registry.dimensionless) + ) + + a = np.linspace(0, 1, 10).astype(dtype) * unit_a + b = np.linspace(-1, 0, 10).astype(dtype) * unit_b + + ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + + if "dim" in func.kwargs: + numpy_kwargs = func.kwargs.copy() + dim = numpy_kwargs.pop("dim") + + axis_a = ds.a.get_axis_num(dim) + axis_b = ds.b.get_axis_num(dim) + + numpy_kwargs_a = numpy_kwargs.copy() + numpy_kwargs_a["axis"] = axis_a + numpy_kwargs_b = numpy_kwargs.copy() + numpy_kwargs_b["axis"] = axis_b + else: + numpy_kwargs_a = {} + numpy_kwargs_b = {} + + units_a = array_extract_units(func(a, **numpy_kwargs_a)) + units_b = array_extract_units(func(b, **numpy_kwargs_b)) + units = {"a": units_a, "b": units_b} + + actual = func(ds) + expected = attach_units(func(strip_units(ds)), units) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.parametrize("property", ("imag", "real")) + def test_numpy_properties(self, property, dtype): + a = np.linspace(0, 1, 10) * unit_registry.Pa + b = np.linspace(-1, 0, 15) * unit_registry.degK + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) + units = extract_units(ds) + + actual = getattr(ds, property) + expected = attach_units(getattr(strip_units(ds), property), units) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("astype", float), + method("conj"), + method("argsort"), + method("conjugate"), + method("round"), + ), + ids=repr, + ) + def test_numpy_methods(self, func, dtype): + a = np.linspace(1, -1, 10) * unit_registry.Pa + b = np.linspace(-1, 1, 15) * unit_registry.degK + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) + + units_a = array_extract_units(func(a)) + units_b = array_extract_units(func(b)) + units = {"a": units_a, "b": units_b} + + actual = func(ds) + expected = attach_units(func(strip_units(ds)), units) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods_with_args(self, func, unit, error, dtype): + data_unit = unit_registry.m + a = np.linspace(0, 10, 15) * unit_registry.m + b = np.linspace(-2, 12, 20) * unit_registry.m + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) + units = extract_units(ds) + + kwargs = { + key: array_attach_units(value, unit) for key, value in func.kwargs.items() + } + + if error is not None: + with pytest.raises(error): + func(ds, **kwargs) + + return + + stripped_kwargs = { + key: strip_units(convert_units(value, {None: data_unit})) + for key, value in kwargs.items() + } + + actual = func(ds, **kwargs) + expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func, dtype): + array1 = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + array2 = ( + np.array( + [ + [np.nan, 5.7, 12.0, 7.2], + [np.nan, 12.4, np.nan, 4.2], + [9.8, np.nan, 4.6, 1.4], + [7.2, np.nan, 6.3, np.nan], + [8.4, 3.9, np.nan, np.nan], + ] + ) + * unit_registry.Pa + ) + + ds = xr.Dataset({"a": (("x", "y"), array1), "b": (("z", "x"), array2)}) + + expected = func(strip_units(ds)) + actual = func(ds) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.xfail(reason="ffill and bfill lose the unit") + @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr) + def test_missing_value_filling(self, func, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + + ds = xr.Dataset({"a": ("x", array1), "b": ("y", array2)}) + units = extract_units(ds) + + expected = attach_units(func(strip_units(ds), dim="x"), units) + actual = func(ds, dim="x") + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "fill_value", + ( + pytest.param(-1, id="python_scalar"), + pytest.param(np.array(-1), id="numpy_scalar"), + pytest.param(np.array([-1]), id="numpy_array"), + ), + ) + def test_fillna(self, fill_value, unit, error, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.m + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.m + ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + value = fill_value * unit + units = extract_units(ds) + + if error is not None: + with pytest.raises(error): + ds.fillna(value=value) + + return + + actual = ds.fillna(value=value) + expected = attach_units( + strip_units(ds).fillna( + value=strip_units(convert_units(value, {None: unit_registry.m})) + ), + units, + ) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + def test_dropna(self, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) + + expected = attach_units(strip_units(ds).dropna(dim="x"), units) + actual = ds.dropna(dim="x") + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="same_unit"), + ), + ) + def test_isin(self, unit, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.m + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.m + ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + + raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype) + values = raw_values * unit + + converted_values = ( + convert_units(values, {None: unit_registry.m}) + if is_compatible(unit, unit_registry.m) + else values + ) + + expected = strip_units(ds).isin(strip_units(converted_values)) + # TODO: use `unit_registry.is_compatible_with(unit, unit_registry.m)` instead. + # Needs `pint>=0.12.1`, though, so we probably should wait until that is released. + if not is_compatible(unit, unit_registry.m): + expected.a[:] = False + expected.b[:] = False + + actual = ds.isin(values) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "variant", ("masking", "replacing_scalar", "replacing_array", "dropping") + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="same_unit"), + ), + ) + def test_where(self, variant, unit, error, dtype): + original_unit = unit_registry.m + array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit + array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit + + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) + + condition = ds < 0.5 * original_unit + other = np.linspace(-2, -1, 10).astype(dtype) * unit + variant_kwargs = { + "masking": {"cond": condition}, + "replacing_scalar": {"cond": condition, "other": -1 * unit}, + "replacing_array": {"cond": condition, "other": other}, + "dropping": {"cond": condition, "drop": True}, + } + kwargs = variant_kwargs.get(variant) + if variant not in ("masking", "dropping") and error is not None: + with pytest.raises(error): + ds.where(**kwargs) + + return + + kwargs_without_units = { + key: strip_units(convert_units(value, {None: original_unit})) + for key, value in kwargs.items() + } + + expected = attach_units( + strip_units(ds).where(**kwargs_without_units), + units, + ) + actual = ds.where(**kwargs) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.xfail(reason="interpolate_na uses numpy.vectorize") + def test_interpolate_na(self, dtype): + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * unit_registry.degK + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) + * unit_registry.Pa + ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) + + expected = attach_units( + strip_units(ds).interpolate_na(dim="x"), + units, + ) + actual = ds.interpolate_na(dim="x") + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="same_unit"), + ), + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + ), + ) + def test_combine_first(self, variant, unit, error, dtype): + variants = { + "data": (unit_registry.m, unit, 1, 1), + "dims": (1, 1, unit_registry.m, unit), + } + data_unit, other_data_unit, dims_unit, other_dims_unit = variants.get(variant) + + array1 = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) * data_unit + ) + array2 = ( + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * data_unit + ) + x = np.arange(len(array1)) * dims_unit + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("x", array2)}, + coords={"x": x}, + ) + units = extract_units(ds) + + other_array1 = np.ones_like(array1) * other_data_unit + other_array2 = np.full_like(array2, fill_value=-1) * other_data_unit + other_x = (np.arange(array1.shape[0]) + 5) * other_dims_unit + other = xr.Dataset( + data_vars={"a": ("x", other_array1), "b": ("x", other_array2)}, + coords={"x": other_x}, + ) + + if error is not None: + with pytest.raises(error): + ds.combine_first(other) + + return + + expected = attach_units( + strip_units(ds).combine_first(strip_units(convert_units(other, units))), + units, + ) + actual = ds.combine_first(other) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("equals"), + pytest.param( + method("identical"), + marks=pytest.mark.skip("behaviour of identical is unclear"), + ), + ), + ids=repr, + ) + def test_comparisons(self, func, variant, unit, dtype): + array1 = np.linspace(0, 5, 10).astype(dtype) + array2 = np.linspace(-5, 0, 10).astype(dtype) + + coord = np.arange(len(array1)).astype(dtype) + + variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + a = array1 * data_unit + b = array2 * data_unit + x = coord * dim_unit + y = coord * coord_unit + + ds = xr.Dataset( + data_vars={"a": ("x", a), "b": ("x", b)}, + coords={"x": x, "y": ("x", y)}, + ) + units = extract_units(ds) + + other_variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + other_data_unit, other_dim_unit, other_coord_unit = other_variants.get(variant) + + other_units = { + "a": other_data_unit, + "b": other_data_unit, + "x": other_dim_unit, + "y": other_coord_unit, + } + + to_convert = { + key: unit if is_compatible(unit, reference) else None + for key, (unit, reference) in zip_mappings(units, other_units) + } + # convert units where possible, then attach all units to the converted dataset + other = attach_units(strip_units(convert_units(ds, to_convert)), other_units) + other_units = extract_units(other) + + # make sure all units are compatible and only then try to + # convert and compare values + equal_ds = all( + is_compatible(unit, other_unit) + for _, (unit, other_unit) in zip_mappings(units, other_units) + ) and (strip_units(ds).equals(strip_units(convert_units(other, units)))) + equal_units = units == other_units + expected = equal_ds and (func.name != "identical" or equal_units) + + actual = func(ds, other) + + assert expected == actual + + # TODO: eventually use another decorator / wrapper function that + # applies a filter to the parametrize combinations: + # we only need a single test for data + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + ), + ) + def test_broadcast_like(self, variant, unit, dtype): + variants = { + "data": ((unit_registry.m, unit), (1, 1)), + "dims": ((1, 1), (unit_registry.m, unit)), + } + (data_unit1, data_unit2), (dim_unit1, dim_unit2) = variants.get(variant) + + array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * data_unit1 + array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit2 + + x1 = np.arange(2) * dim_unit1 + x2 = np.arange(2) * dim_unit2 + y1 = np.array([0]) * dim_unit1 + y2 = np.arange(3) * dim_unit2 + + ds1 = xr.Dataset( + data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1} + ) + ds2 = xr.Dataset( + data_vars={"a": (("x", "y"), array2)}, coords={"x": x2, "y": y2} + ) + + expected = attach_units( + strip_units(ds1).broadcast_like(strip_units(ds2)), extract_units(ds1) + ) + actual = ds1.broadcast_like(ds2) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + # TODO: does this use indexes? + left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m + left_array2 = np.zeros(shape=(3, 6), dtype=dtype) * unit_registry.m + + right_array1 = np.ones(shape=(2,)) * unit + right_array2 = np.zeros(shape=(3,)) * unit + + left = xr.Dataset( + {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)}, + ) + right = xr.Dataset({"a": ("x", right_array1), "b": ("y", right_array2)}) + + units = merge_mappings( + extract_units(left), + {} if is_compatible(left_array1, unit) else {"a": None, "b": None}, + ) + expected = is_compatible(left_array1, unit) and strip_units( + left + ).broadcast_equals(strip_units(convert_units(right, units))) + actual = left.broadcast_equals(right) + + assert expected == actual + + def test_pad(self, dtype): + a = np.linspace(0, 5, 10).astype(dtype) * unit_registry.Pa + b = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.degK + + ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + units = extract_units(ds) + + expected = attach_units(strip_units(ds).pad(x=(2, 3)), units) + actual = ds.pad(x=(2, 3)) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", + (method("unstack"), method("reset_index", "v"), method("reorder_levels")), + ids=repr, + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + ), + ) + def test_stacking_stacked(self, variant, func, dtype): + variants = { + "data": (unit_registry.m, 1), + "dims": (1, unit_registry.m), + } + data_unit, dim_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + array2 = ( + np.linspace(-10, 0, 5 * 10 * 15).reshape(5, 10, 15).astype(dtype) + * data_unit + ) + + x = np.arange(array1.shape[0]) * dim_unit + y = np.arange(array1.shape[1]) * dim_unit + z = np.arange(array2.shape[2]) * dim_unit + + ds = xr.Dataset( + data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)}, + coords={"x": x, "y": y, "z": z}, + ) + units = extract_units(ds) + + stacked = ds.stack(v=("x", "y")) + + expected = attach_units(func(strip_units(stacked)), units) + actual = func(stacked) + + assert_units_equal(expected, actual) + if func.name == "reset_index": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) + + @pytest.mark.xfail( + reason="stacked dimension's labels have to be hashable, but is a numpy.array" + ) + def test_to_stacked_array(self, dtype): + labels = range(5) * unit_registry.s + arrays = { + name: np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + for name in labels + } + + ds = xr.Dataset({name: ("x", array) for name, array in arrays.items()}) + units = {None: unit_registry.m, "y": unit_registry.s} + + func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) + + actual = func(ds).rename(None) + expected = attach_units( + func(strip_units(ds)).rename(None), + units, + ) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("transpose", "y", "x", "z1", "z2"), + method("stack", u=("x", "y")), + method("set_index", x="x2"), + method("shift", x=2), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.skip(reason="rank not implemented for non-ndarray"), + ), + method("roll", x=2, roll_coords=False), + method("sortby", "x2"), + ), + ids=repr, + ) + def test_stacking_reordering(self, func, dtype): + array1 = ( + np.linspace(0, 10, 2 * 5 * 10).reshape(2, 5, 10).astype(dtype) + * unit_registry.Pa + ) + array2 = ( + np.linspace(0, 10, 2 * 5 * 15).reshape(2, 5, 15).astype(dtype) + * unit_registry.degK + ) + + x = np.arange(array1.shape[0]) + y = np.arange(array1.shape[1]) + z1 = np.arange(array1.shape[2]) + z2 = np.arange(array2.shape[2]) + + x2 = np.linspace(0, 1, array1.shape[0])[::-1] + + ds = xr.Dataset( + data_vars={ + "a": (("x", "y", "z1"), array1), + "b": (("x", "y", "z2"), array2), + }, + coords={"x": x, "y": y, "z1": z1, "z2": z2, "x2": ("x", x2)}, + ) + units = extract_units(ds) + + expected = attach_units(func(strip_units(ds)), units) + actual = func(ds) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array1 = np.arange(10).astype(dtype) * unit_registry.s + array2 = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa + + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) + + expected = attach_units(strip_units(ds).isel(x=indices), units) + actual = ds.isel(x=indices) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_sel(self, raw_values, unit, error, dtype): + array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK + array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa + x = np.arange(len(array1)) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + values = raw_values * unit + + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: + with pytest.raises(error): + ds.sel(x=values) + + return + + expected = attach_units( + strip_units(ds).sel( + x=strip_units(convert_units(values, {None: unit_registry.m})) + ), + extract_units(ds), + ) + actual = ds.sel(x=values) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_drop_sel(self, raw_values, unit, error, dtype): + array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK + array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa + x = np.arange(len(array1)) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + values = raw_values * unit + + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: + with pytest.raises(error): + ds.drop_sel(x=values) + + return + + expected = attach_units( + strip_units(ds).drop_sel( + x=strip_units(convert_units(values, {None: unit_registry.m})) + ), + extract_units(ds), + ) + actual = ds.drop_sel(x=values) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_loc(self, raw_values, unit, error, dtype): + array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK + array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa + x = np.arange(len(array1)) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + values = raw_values * unit + + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: + with pytest.raises(error): + ds.loc[{"x": values}] + + return + + expected = attach_units( + strip_units(ds).loc[ + {"x": strip_units(convert_units(values, {None: unit_registry.m}))} + ], + extract_units(ds), + ) + actual = ds.loc[{"x": values}] + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("head", x=7, y=3, z=6), + method("tail", x=7, y=3, z=6), + method("thin", x=7, y=3, z=6), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_head_tail_thin(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit_a, unit_b), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_a + array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_b + + coords = { + "x": np.arange(10) * dim_unit, + "y": np.arange(5) * dim_unit, + "z": np.arange(8) * dim_unit, + "u": ("x", np.linspace(0, 1, 10) * coord_unit), + "v": ("y", np.linspace(1, 2, 5) * coord_unit), + "w": ("z", np.linspace(-1, 0, 8) * coord_unit), + } + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords=coords, + ) + + expected = attach_units(func(strip_units(ds)), extract_units(ds)) + actual = func(ds) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all")) + @pytest.mark.parametrize( + "shape", + ( + pytest.param((10, 20), id="nothing squeezable"), + pytest.param((10, 20, 1), id="last dimension squeezable"), + pytest.param((10, 1, 20), id="middle dimension squeezable"), + pytest.param((1, 10, 20), id="first dimension squeezable"), + pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"), + ), + ) + def test_squeeze(self, shape, dim, dtype): + names = "xyzt" + dim_lengths = dict(zip(names, shape)) + array1 = ( + np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK + ) + array2 = ( + np.linspace(1, 2, 10 * 20).astype(dtype).reshape(shape) * unit_registry.Pa + ) + + ds = xr.Dataset( + data_vars={ + "a": (tuple(names[: len(shape)]), array1), + "b": (tuple(names[: len(shape)]), array2), + }, + ) + units = extract_units(ds) + + kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {} + + expected = attach_units(strip_units(ds).squeeze(**kwargs), units) + + actual = ds.squeeze(**kwargs) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("variant", ("data", "coords")) + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("interp"), marks=pytest.mark.xfail(reason="uses scipy") + ), + method("reindex"), + ), + ids=repr, + ) + def test_interp_reindex(self, func, variant, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } + data_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit + + y = np.arange(10) * coord_unit + + x = np.arange(10) + new_x = np.arange(8) + 0.5 + + ds = xr.Dataset( + {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)} + ) + units = extract_units(ds) + + expected = attach_units(func(strip_units(ds), x=new_x), units) + actual = func(ds, x=new_x) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize("func", (method("interp"), method("reindex")), ids=repr) + def test_interp_reindex_indexing(self, func, unit, error, dtype): + array1 = np.linspace(-1, 0, 10).astype(dtype) + array2 = np.linspace(0, 1, 10).astype(dtype) + + x = np.arange(10) * unit_registry.m + new_x = (np.arange(8) + 0.5) * unit + + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + units = extract_units(ds) + + if error is not None: + with pytest.raises(error): + func(ds, x=new_x) + + return + + expected = attach_units(func(strip_units(ds), x=new_x), units) + actual = func(ds, x=new_x) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("variant", ("data", "coords")) + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("interp_like"), marks=pytest.mark.xfail(reason="uses scipy") + ), + method("reindex_like"), + ), + ids=repr, + ) + def test_interp_reindex_like(self, func, variant, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } + data_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit + + y = np.arange(10) * coord_unit + + x = np.arange(10) + new_x = np.arange(8) + 0.5 + + ds = xr.Dataset( + {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)} + ) + units = extract_units(ds) + + other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x}) + + expected = attach_units(func(strip_units(ds), other), units) + actual = func(ds, other) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.skip(reason="indexes don't support units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("interp_like"), method("reindex_like")), ids=repr + ) + def test_interp_reindex_like_indexing(self, func, unit, error, dtype): + array1 = np.linspace(-1, 0, 10).astype(dtype) + array2 = np.linspace(0, 1, 10).astype(dtype) + + x = np.arange(10) * unit_registry.m + new_x = (np.arange(8) + 0.5) * unit + + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + units = extract_units(ds) + + other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x}) + + if error is not None: + with pytest.raises(error): + func(ds, other) + + return + + expected = attach_units(func(strip_units(ds), other), units) + actual = func(ds, other) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize( + "func", + ( + method("diff", dim="x"), + method("differentiate", coord="x"), + method("integrate", coord="x"), + method("quantile", q=[0.25, 0.75]), + method("reduce", func=np.sum, dim="x"), + method("map", np.fabs), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_computation(self, func, variant, dtype, compute_backend): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2 + x = np.arange(4) * dim_unit + y = np.arange(5) * dim_unit + z = np.arange(3) * dim_unit + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims=("x", "y")), + "b": xr.DataArray(data=array2, dims=("x", "z")), + }, + coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)}, + ) + + units = extract_units(ds) + + expected = attach_units(func(strip_units(ds)), units) + actual = func(ds) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("groupby", "x"), + method("groupby_bins", "x", bins=2), + method("coarsen", x=2), + pytest.param( + method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units") + ), + pytest.param( + method("rolling_exp", x=3), + marks=pytest.mark.xfail( + reason="numbagg functions are not supported by pint" + ), + ), + method("weighted", xr.DataArray(data=np.linspace(0, 1, 5), dims="y")), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_computation_objects(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2 + x = np.arange(4) * dim_unit + y = np.arange(5) * dim_unit + z = np.arange(3) * dim_unit + + ds = xr.Dataset( + data_vars={"a": (("x", "y"), array1), "b": (("x", "z"), array2)}, + coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)}, + ) + units = extract_units(ds) + + args = [] if func.name != "groupby" else ["y"] + # Doesn't work with flox because pint doesn't implement + # ufunc.reduceat or np.bincount + # kwargs = {"engine": "numpy"} if "groupby" in func.name else {} + kwargs = {} + expected = attach_units(func(strip_units(ds)).mean(*args, **kwargs), units) + actual = func(ds).mean(*args, **kwargs) + + assert_units_equal(expected, actual) + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_resample(self, variant, dtype): + # TODO: move this to test_computation_objects + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit2 + + t = xr.date_range("10-09-2010", periods=array1.shape[0], freq="YE") + y = np.arange(5) * dim_unit + z = np.arange(8) * dim_unit + + u = np.linspace(-1, 0, 5) * coord_unit + + ds = xr.Dataset( + data_vars={"a": (("time", "y"), array1), "b": (("time", "z"), array2)}, + coords={"time": t, "y": y, "z": z, "u": ("y", u)}, + ) + units = extract_units(ds) + + func = method("resample", time="6ME") + + expected = attach_units(func(strip_units(ds)).mean(), units) + actual = func(ds).mean() + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize( + "func", + ( + method("assign", c=lambda ds: 10 * ds.b), + method("assign_coords", v=("x", np.arange(5) * unit_registry.s)), + method("first"), + method("last"), + method("quantile", q=[0.25, 0.5, 0.75], dim="x"), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_grouped_operations(self, func, variant, dtype, compute_backend): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2 + x = np.arange(5) * dim_unit + y = np.arange(4) * dim_unit + z = np.arange(3) * dim_unit + + u = np.linspace(-1, 0, 4) * coord_unit + + ds = xr.Dataset( + data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)}, + coords={"x": x, "y": y, "z": z, "u": ("y", u)}, + ) + + assigned_units = {"c": unit2, "v": unit_registry.s} + units = merge_mappings(extract_units(ds), assigned_units) + + stripped_kwargs = { + name: strip_units(value) for name, value in func.kwargs.items() + } + expected = attach_units( + func(strip_units(ds).groupby("y", squeeze=False), **stripped_kwargs), units + ) + actual = func(ds.groupby("y", squeeze=False)) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("pipe", lambda ds: ds * 10), + method("assign", d=lambda ds: ds.b * 10), + method("assign_coords", y2=("y", np.arange(4) * unit_registry.mm)), + method("assign_attrs", attr1="value"), + method("rename", x2="x_mm"), + method("rename_vars", c="temperature"), + method("rename_dims", x="offset_x"), + method("swap_dims", {"x": "u"}), + pytest.param( + method( + "expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1 + ), + marks=pytest.mark.skip(reason="indexes don't support units"), + ), + method("drop_vars", "x"), + method("drop_dims", "z"), + method("set_coords", names="c"), + method("reset_coords", names="x2"), + method("copy"), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_content_manipulation(self, func, variant, dtype): + variants = { + "data": ( + (unit_registry.m**3, unit_registry.Pa, unit_registry.degK), + 1, + 1, + ), + "dims": ((1, 1, 1), unit_registry.m, 1), + "coords": ((1, 1, 1), 1, unit_registry.m), + } + (unit1, unit2, unit3), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2 + array3 = np.linspace(0, 10, 5).astype(dtype) * unit3 + + x = np.arange(5) * dim_unit + y = np.arange(4) * dim_unit + z = np.arange(3) * dim_unit + + x2 = np.linspace(-1, 0, 5) * coord_unit + + ds = xr.Dataset( + data_vars={ + "a": (("x", "y"), array1), + "b": (("x", "y", "z"), array2), + "c": ("x", array3), + }, + coords={"x": x, "y": y, "z": z, "x2": ("x", x2)}, + ) + + new_units = { + "y2": unit_registry.mm, + "x_mm": coord_unit, + "offset_x": unit_registry.m, + "d": unit2, + "temperature": unit3, + } + units = merge_mappings(extract_units(ds), new_units) + + stripped_kwargs = { + key: strip_units(value) for key, value in func.kwargs.items() + } + expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) + actual = func(ds) + + assert_units_equal(expected, actual) + if func.name == "rename_dims": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, xr.MergeError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, xr.MergeError, id="dimensionless" + ), + pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"), + pytest.param(unit_registry.cm, xr.MergeError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_merge(self, variant, unit, error, dtype): + left_variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } + + left_data_unit, left_dim_unit, left_coord_unit = left_variants.get(variant) + + right_variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + right_data_unit, right_dim_unit, right_coord_unit = right_variants.get(variant) + + left_array = np.arange(10).astype(dtype) * left_data_unit + right_array = np.arange(-5, 5).astype(dtype) * right_data_unit + + left_dim = np.arange(10, 20) * left_dim_unit + right_dim = np.arange(5, 15) * right_dim_unit + + left_coord = np.arange(-10, 0) * left_coord_unit + right_coord = np.arange(-15, -5) * right_coord_unit + + left = xr.Dataset( + data_vars={"a": ("x", left_array)}, + coords={"x": left_dim, "y": ("x", left_coord)}, + ) + right = xr.Dataset( + data_vars={"a": ("x", right_array)}, + coords={"x": right_dim, "y": ("x", right_coord)}, + ) + + units = extract_units(left) + + if error is not None: + with pytest.raises(error): + left.merge(right) + + return + + converted = convert_units(right, units) + expected = attach_units(strip_units(left).merge(strip_units(converted)), units) + actual = left.merge(right) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + +@requires_dask +class TestPintWrappingDask: + def test_duck_array_ops(self): + import dask.array + + d = dask.array.array([1, 2, 3]) + q = unit_registry.Quantity(d, units="m") + da = xr.DataArray(q, dims="x") + + actual = da.mean().compute() + actual.name = None + expected = xr.DataArray(unit_registry.Quantity(np.array(2.0), units="m")) + + assert_units_equal(expected, actual) + # Don't use isinstance b/c we don't want to allow subclasses through + assert type(expected.data) == type(actual.data) # noqa + + +@requires_matplotlib +class TestPlots(PlotTestCase): + @pytest.mark.parametrize( + "coord_unit, coord_attrs", + [ + (1, {"units": "meter"}), + pytest.param( + unit_registry.m, + {}, + marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ], + ) + def test_units_in_line_plot_labels(self, coord_unit, coord_attrs): + arr = np.linspace(1, 10, 3) * unit_registry.Pa + coord_arr = np.linspace(1, 3, 3) * coord_unit + x_coord = xr.DataArray(coord_arr, dims="x", attrs=coord_attrs) + da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure") + + da.plot.line() + + ax = plt.gca() + assert ax.get_ylabel() == "pressure [pascal]" + assert ax.get_xlabel() == "x [meter]" + + @pytest.mark.parametrize( + "coord_unit, coord_attrs", + [ + (1, {"units": "meter"}), + pytest.param( + unit_registry.m, + {}, + marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ], + ) + def test_units_in_slice_line_plot_labels_sel(self, coord_unit, coord_attrs): + arr = xr.DataArray( + name="var_a", + data=np.array([[1, 2], [3, 4]]), + coords=dict( + a=("a", np.array([5, 6]) * coord_unit, coord_attrs), + b=("b", np.array([7, 8]) * coord_unit, coord_attrs), + ), + dims=("a", "b"), + ) + arr.sel(a=5).plot(marker="o") + + assert plt.gca().get_title() == "a = 5 [meter]" + + @pytest.mark.parametrize( + "coord_unit, coord_attrs", + [ + (1, {"units": "meter"}), + pytest.param( + unit_registry.m, + {}, + marks=pytest.mark.xfail(reason="pint.errors.UnitStrippedWarning"), + ), + ], + ) + def test_units_in_slice_line_plot_labels_isel(self, coord_unit, coord_attrs): + arr = xr.DataArray( + name="var_a", + data=np.array([[1, 2], [3, 4]]), + coords=dict( + a=("x", np.array([5, 6]) * coord_unit, coord_attrs), + b=("y", np.array([7, 8])), + ), + dims=("x", "y"), + ) + arr.isel(x=0).plot(marker="o") + assert plt.gca().get_title() == "a = 5 [meter]" + + def test_units_in_2d_plot_colorbar_label(self): + arr = np.ones((2, 3)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure") + + fig, (ax, cax) = plt.subplots(1, 2) + ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True) + + assert cax.get_ylabel() == "pressure [pascal]" + + def test_units_facetgrid_plot_labels(self): + arr = np.ones((2, 3)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure") + + fig, (ax, cax) = plt.subplots(1, 2) + fgrid = da.plot.line(x="x", col="y") + + assert fgrid.axs[0, 0].get_ylabel() == "pressure [pascal]" + + def test_units_facetgrid_2d_imshow_plot_colorbar_labels(self): + arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure") + + da.plot.imshow(x="x", y="y", col="w") # no colorbar to check labels of + + def test_units_facetgrid_2d_contourf_plot_colorbar_labels(self): + arr = np.ones((2, 3, 4)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y", "z"], name="pressure") + + fig, (ax1, ax2, ax3, cax) = plt.subplots(1, 4) + fgrid = da.plot.contourf(x="x", y="y", col="z") + + assert fgrid.cbar.ax.get_ylabel() == "pressure [pascal]" diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_utils.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_utils.py new file mode 100644 index 0000000..50061c7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_utils.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +from collections.abc import Hashable + +import numpy as np +import pandas as pd +import pytest + +from xarray.core import duck_array_ops, utils +from xarray.core.utils import either_dict_or_kwargs, infix_dims, iterate_nested +from xarray.tests import assert_array_equal, requires_dask + + +class TestAlias: + def test(self): + def new_method(): + pass + + old_method = utils.alias(new_method, "old_method") + assert "deprecated" in old_method.__doc__ + with pytest.warns(Warning, match="deprecated"): + old_method() + + +@pytest.mark.parametrize( + ["a", "b", "expected"], + [ + [np.array(["a"]), np.array(["b"]), np.array(["a", "b"])], + [np.array([1], dtype="int64"), np.array([2], dtype="int64"), pd.Index([1, 2])], + ], +) +def test_maybe_coerce_to_str(a, b, expected): + index = pd.Index(a).append(pd.Index(b)) + + actual = utils.maybe_coerce_to_str(index, [a, b]) + + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +def test_maybe_coerce_to_str_minimal_str_dtype(): + a = np.array(["a", "a_long_string"]) + index = pd.Index(["a"]) + + actual = utils.maybe_coerce_to_str(index, [a]) + expected = np.array("a") + + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +class TestArrayEquiv: + def test_0d(self): + # verify our work around for pd.isnull not working for 0-dimensional + # object arrays + assert duck_array_ops.array_equiv(0, np.array(0, dtype=object)) + assert duck_array_ops.array_equiv(np.nan, np.array(np.nan, dtype=object)) + assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object)) + + +class TestDictionaries: + @pytest.fixture(autouse=True) + def setup(self): + self.x = {"a": "A", "b": "B"} + self.y = {"c": "C", "b": "B"} + self.z = {"a": "Z"} + + def test_equivalent(self): + assert utils.equivalent(0, 0) + assert utils.equivalent(np.nan, np.nan) + assert utils.equivalent(0, np.array(0.0)) + assert utils.equivalent([0], np.array([0])) + assert utils.equivalent(np.array([0]), [0]) + assert utils.equivalent(np.arange(3), 1.0 * np.arange(3)) + assert not utils.equivalent(0, np.zeros(3)) + + def test_safe(self): + # should not raise exception: + utils.update_safety_check(self.x, self.y) + + def test_unsafe(self): + with pytest.raises(ValueError): + utils.update_safety_check(self.x, self.z) + + def test_compat_dict_intersection(self): + assert {"b": "B"} == utils.compat_dict_intersection(self.x, self.y) + assert {} == utils.compat_dict_intersection(self.x, self.z) + + def test_compat_dict_union(self): + assert {"a": "A", "b": "B", "c": "C"} == utils.compat_dict_union(self.x, self.y) + with pytest.raises( + ValueError, + match=r"unsafe to merge dictionaries without " + "overriding values; conflicting key", + ): + utils.compat_dict_union(self.x, self.z) + + def test_dict_equiv(self): + x = {} + x["a"] = 3 + x["b"] = np.array([1, 2, 3]) + y = {} + y["b"] = np.array([1.0, 2.0, 3.0]) + y["a"] = 3 + assert utils.dict_equiv(x, y) # two nparrays are equal + y["b"] = [1, 2, 3] # np.array not the same as a list + assert utils.dict_equiv(x, y) # nparray == list + x["b"] = [1.0, 2.0, 3.0] + assert utils.dict_equiv(x, y) # list vs. list + x["c"] = None + assert not utils.dict_equiv(x, y) # new key in x + x["c"] = np.nan + y["c"] = np.nan + assert utils.dict_equiv(x, y) # as intended, nan is nan + x["c"] = np.inf + y["c"] = np.inf + assert utils.dict_equiv(x, y) # inf == inf + y = dict(y) + assert utils.dict_equiv(x, y) # different dictionary types are fine + y["b"] = 3 * np.arange(3) + assert not utils.dict_equiv(x, y) # not equal when arrays differ + + def test_frozen(self): + x = utils.Frozen(self.x) + with pytest.raises(TypeError): + x["foo"] = "bar" + with pytest.raises(TypeError): + del x["a"] + with pytest.raises(AttributeError): + x.update(self.y) + assert x.mapping == self.x + assert repr(x) in ( + "Frozen({'a': 'A', 'b': 'B'})", + "Frozen({'b': 'B', 'a': 'A'})", + ) + + +def test_repr_object(): + obj = utils.ReprObject("foo") + assert repr(obj) == "foo" + assert isinstance(obj, Hashable) + assert not isinstance(obj, str) + + +def test_repr_object_magic_methods(): + o1 = utils.ReprObject("foo") + o2 = utils.ReprObject("foo") + o3 = utils.ReprObject("bar") + o4 = "foo" + assert o1 == o2 + assert o1 != o3 + assert o1 != o4 + assert hash(o1) == hash(o2) + assert hash(o1) != hash(o3) + assert hash(o1) != hash(o4) + + +def test_is_remote_uri(): + assert utils.is_remote_uri("http://example.com") + assert utils.is_remote_uri("https://example.com") + assert not utils.is_remote_uri(" http://example.com") + assert not utils.is_remote_uri("example.nc") + + +class Test_is_uniform_and_sorted: + def test_sorted_uniform(self): + assert utils.is_uniform_spaced(np.arange(5)) + + def test_sorted_not_uniform(self): + assert not utils.is_uniform_spaced([-2, 1, 89]) + + def test_not_sorted_uniform(self): + assert not utils.is_uniform_spaced([1, -1, 3]) + + def test_not_sorted_not_uniform(self): + assert not utils.is_uniform_spaced([4, 1, 89]) + + def test_two_numbers(self): + assert utils.is_uniform_spaced([0, 1.7]) + + def test_relative_tolerance(self): + assert utils.is_uniform_spaced([0, 0.97, 2], rtol=0.1) + + +class Test_hashable: + def test_hashable(self): + for v in [False, 1, (2,), (3, 4), "four"]: + assert utils.hashable(v) + for v in [[5, 6], ["seven", "8"], {9: "ten"}]: + assert not utils.hashable(v) + + +@requires_dask +def test_dask_array_is_scalar(): + # regression test for GH1684 + import dask.array as da + + y = da.arange(8, chunks=4) + assert not utils.is_scalar(y) + + +def test_hidden_key_dict(): + hidden_key = "_hidden_key" + data = {"a": 1, "b": 2, hidden_key: 3} + data_expected = {"a": 1, "b": 2} + hkd = utils.HiddenKeyDict(data, [hidden_key]) + assert len(hkd) == 2 + assert hidden_key not in hkd + for k, v in data_expected.items(): + assert hkd[k] == v + with pytest.raises(KeyError): + hkd[hidden_key] + with pytest.raises(KeyError): + del hkd[hidden_key] + + +def test_either_dict_or_kwargs(): + result = either_dict_or_kwargs(dict(a=1), None, "foo") + expected = dict(a=1) + assert result == expected + + result = either_dict_or_kwargs(None, dict(a=1), "foo") + expected = dict(a=1) + assert result == expected + + with pytest.raises(ValueError, match=r"foo"): + result = either_dict_or_kwargs(dict(a=1), dict(a=1), "foo") + + +@pytest.mark.parametrize( + ["supplied", "all_", "expected"], + [ + (list("abc"), list("abc"), list("abc")), + (["a", ..., "c"], list("abc"), list("abc")), + (["a", ...], list("abc"), list("abc")), + (["c", ...], list("abc"), list("cab")), + ([..., "b"], list("abc"), list("acb")), + ([...], list("abc"), list("abc")), + ], +) +def test_infix_dims(supplied, all_, expected): + result = list(infix_dims(supplied, all_)) + assert result == expected + + +@pytest.mark.parametrize( + ["supplied", "all_"], [([..., ...], list("abc")), ([...], list("aac"))] +) +def test_infix_dims_errors(supplied, all_): + with pytest.raises(ValueError): + list(infix_dims(supplied, all_)) + + +@pytest.mark.parametrize( + ["dim", "expected"], + [ + pytest.param("a", ("a",), id="str"), + pytest.param(["a", "b"], ("a", "b"), id="list_of_str"), + pytest.param(["a", 1], ("a", 1), id="list_mixed"), + pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"), + pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"), + pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"), + pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"), + pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"), + pytest.param((), (), id="empty_tuple"), + pytest.param(set(), (), id="empty_collection"), + pytest.param(None, None, id="None"), + pytest.param(..., ..., id="ellipsis"), + ], +) +def test_parse_dims(dim, expected) -> None: + all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables + actual = utils.parse_dims(dim, all_dims, replace_none=False) + assert actual == expected + + +def test_parse_dims_set() -> None: + all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables + dim = {"a", 1} + actual = utils.parse_dims(dim, all_dims) + assert set(actual) == dim + + +@pytest.mark.parametrize( + "dim", [pytest.param(None, id="None"), pytest.param(..., id="ellipsis")] +) +def test_parse_dims_replace_none(dim: None | ellipsis) -> None: + all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables + actual = utils.parse_dims(dim, all_dims, replace_none=True) + assert actual == all_dims + + +@pytest.mark.parametrize( + "dim", + [ + pytest.param("x", id="str_missing"), + pytest.param(["a", "x"], id="list_missing_one"), + pytest.param(["x", 2], id="list_missing_all"), + ], +) +def test_parse_dims_raises(dim) -> None: + all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables + with pytest.raises(ValueError, match="'x'"): + utils.parse_dims(dim, all_dims, check_exists=True) + + +@pytest.mark.parametrize( + ["dim", "expected"], + [ + pytest.param("a", ("a",), id="str"), + pytest.param(["a", "b"], ("a", "b"), id="list"), + pytest.param([...], ("a", "b", "c"), id="list_only_ellipsis"), + pytest.param(["a", ...], ("a", "b", "c"), id="list_with_ellipsis"), + pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"), + ], +) +def test_parse_ordered_dims(dim, expected) -> None: + all_dims = ("a", "b", "c") + actual = utils.parse_ordered_dims(dim, all_dims) + assert actual == expected + + +def test_parse_ordered_dims_raises() -> None: + all_dims = ("a", "b", "c") + + with pytest.raises(ValueError, match="'x' do not exist"): + utils.parse_ordered_dims("x", all_dims, check_exists=True) + + with pytest.raises(ValueError, match="repeated dims"): + utils.parse_ordered_dims(["a", ...], all_dims + ("a",)) + + with pytest.raises(ValueError, match="More than one ellipsis"): + utils.parse_ordered_dims(["a", ..., "b", ...], all_dims) + + +@pytest.mark.parametrize( + "nested_list, expected", + [ + ([], []), + ([1], [1]), + ([1, 2, 3], [1, 2, 3]), + ([[1]], [1]), + ([[1, 2], [3, 4]], [1, 2, 3, 4]), + ([[[1, 2, 3], [4]], [5, 6]], [1, 2, 3, 4, 5, 6]), + ], +) +def test_iterate_nested(nested_list, expected): + assert list(iterate_nested(nested_list)) == expected + + +def test_find_stack_level(): + assert utils.find_stack_level() == 1 + assert utils.find_stack_level(test_mode=True) == 2 + + def f(): + return utils.find_stack_level(test_mode=True) + + assert f() == 3 diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_variable.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_variable.py new file mode 100644 index 0000000..081bf09 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_variable.py @@ -0,0 +1,3062 @@ +from __future__ import annotations + +import warnings +from abc import ABC +from copy import copy, deepcopy +from datetime import datetime, timedelta +from textwrap import dedent +from typing import Generic + +import numpy as np +import pandas as pd +import pytest +import pytz + +from xarray import DataArray, Dataset, IndexVariable, Variable, set_options +from xarray.core import dtypes, duck_array_ops, indexing +from xarray.core.common import full_like, ones_like, zeros_like +from xarray.core.indexing import ( + BasicIndexer, + CopyOnWriteArray, + DaskIndexingAdapter, + LazilyIndexedArray, + MemoryCachedArray, + NumpyIndexingAdapter, + OuterIndexer, + PandasIndexingAdapter, + VectorizedIndexer, +) +from xarray.core.types import T_DuckArray +from xarray.core.utils import NDArrayMixin +from xarray.core.variable import as_compatible_data, as_variable +from xarray.namedarray.pycompat import array_type +from xarray.tests import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + assert_no_warnings, + has_pandas_3, + raise_if_dask_computes, + requires_bottleneck, + requires_cupy, + requires_dask, + requires_pint, + requires_sparse, + source_ndarray, +) +from xarray.tests.test_namedarray import NamedArraySubclassobjects + +dask_array_type = array_type("dask") + +_PAD_XR_NP_ARGS = [ + [{"x": (2, 1)}, ((2, 1), (0, 0), (0, 0))], + [{"x": 1}, ((1, 1), (0, 0), (0, 0))], + [{"y": (0, 3)}, ((0, 0), (0, 3), (0, 0))], + [{"x": (3, 1), "z": (2, 0)}, ((3, 1), (0, 0), (2, 0))], + [{"x": (3, 1), "z": 2}, ((3, 1), (0, 0), (2, 2))], +] + + +@pytest.fixture +def var(): + return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5)) + + +@pytest.mark.parametrize( + "data", + [ + np.array(["a", "bc", "def"], dtype=object), + np.array(["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[ns]"), + ], +) +def test_as_compatible_data_writeable(data): + pd.set_option("mode.copy_on_write", True) + # GH8843, ensure writeable arrays for data_vars even with + # pandas copy-on-write mode + assert as_compatible_data(data).flags.writeable + pd.reset_option("mode.copy_on_write") + + +class VariableSubclassobjects(NamedArraySubclassobjects, ABC): + @pytest.fixture + def target(self, data): + data = 0.5 * np.arange(10).reshape(2, 5) + return Variable(["x", "y"], data) + + def test_getitem_dict(self): + v = self.cls(["x"], np.random.randn(5)) + actual = v[{"x": 0}] + expected = v[0] + assert_identical(expected, actual) + + def test_getitem_1d(self): + data = np.array([0, 1, 2]) + v = self.cls(["x"], data) + + v_new = v[dict(x=[0, 1])] + assert v_new.dims == ("x",) + assert_array_equal(v_new, data[[0, 1]]) + + v_new = v[dict(x=slice(None))] + assert v_new.dims == ("x",) + assert_array_equal(v_new, data) + + v_new = v[dict(x=Variable("a", [0, 1]))] + assert v_new.dims == ("a",) + assert_array_equal(v_new, data[[0, 1]]) + + v_new = v[dict(x=1)] + assert v_new.dims == () + assert_array_equal(v_new, data[1]) + + # tuple argument + v_new = v[slice(None)] + assert v_new.dims == ("x",) + assert_array_equal(v_new, data) + + def test_getitem_1d_fancy(self): + v = self.cls(["x"], [0, 1, 2]) + # 1d-variable should be indexable by multi-dimensional Variable + ind = Variable(("a", "b"), [[0, 1], [0, 1]]) + v_new = v[ind] + assert v_new.dims == ("a", "b") + expected = np.array(v._data)[([0, 1], [0, 1]), ...] + assert_array_equal(v_new, expected) + + # boolean indexing + ind = Variable(("x",), [True, False, True]) + v_new = v[ind] + assert_identical(v[[0, 2]], v_new) + v_new = v[[True, False, True]] + assert_identical(v[[0, 2]], v_new) + + with pytest.raises(IndexError, match=r"Boolean indexer should"): + ind = Variable(("a",), [True, False, True]) + v[ind] + + def test_getitem_with_mask(self): + v = self.cls(["x"], [0, 1, 2]) + assert_identical(v._getitem_with_mask(-1), Variable((), np.nan)) + assert_identical( + v._getitem_with_mask([0, -1, 1]), self.cls(["x"], [0, np.nan, 1]) + ) + assert_identical(v._getitem_with_mask(slice(2)), self.cls(["x"], [0, 1])) + assert_identical( + v._getitem_with_mask([0, -1, 1], fill_value=-99), + self.cls(["x"], [0, -99, 1]), + ) + + def test_getitem_with_mask_size_zero(self): + v = self.cls(["x"], []) + assert_identical(v._getitem_with_mask(-1), Variable((), np.nan)) + assert_identical( + v._getitem_with_mask([-1, -1, -1]), + self.cls(["x"], [np.nan, np.nan, np.nan]), + ) + + def test_getitem_with_mask_nd_indexer(self): + v = self.cls(["x"], [0, 1, 2]) + indexer = Variable(("x", "y"), [[0, -1], [-1, 2]]) + assert_identical(v._getitem_with_mask(indexer, fill_value=-1), indexer) + + def _assertIndexedLikeNDArray(self, variable, expected_value0, expected_dtype=None): + """Given a 1-dimensional variable, verify that the variable is indexed + like a numpy.ndarray. + """ + assert variable[0].shape == () + assert variable[0].ndim == 0 + assert variable[0].size == 1 + # test identity + assert variable.equals(variable.copy()) + assert variable.identical(variable.copy()) + # check value is equal for both ndarray and Variable + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + np.testing.assert_equal(variable.values[0], expected_value0) + np.testing.assert_equal(variable[0].values, expected_value0) + # check type or dtype is consistent for both ndarray and Variable + if expected_dtype is None: + # check output type instead of array dtype + assert type(variable.values[0]) == type(expected_value0) + assert type(variable[0].values) == type(expected_value0) + elif expected_dtype is not False: + assert variable.values[0].dtype == expected_dtype + assert variable[0].values.dtype == expected_dtype + + def test_index_0d_int(self): + for value, dtype in [(0, np.int_), (np.int32(0), np.int32)]: + x = self.cls(["x"], [value]) + self._assertIndexedLikeNDArray(x, value, dtype) + + def test_index_0d_float(self): + for value, dtype in [(0.5, float), (np.float32(0.5), np.float32)]: + x = self.cls(["x"], [value]) + self._assertIndexedLikeNDArray(x, value, dtype) + + def test_index_0d_string(self): + value = "foo" + dtype = np.dtype("U3") + x = self.cls(["x"], [value]) + self._assertIndexedLikeNDArray(x, value, dtype) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_index_0d_datetime(self): + d = datetime(2000, 1, 1) + x = self.cls(["x"], [d]) + self._assertIndexedLikeNDArray(x, np.datetime64(d)) + + x = self.cls(["x"], [np.datetime64(d)]) + self._assertIndexedLikeNDArray(x, np.datetime64(d), "datetime64[ns]") + + x = self.cls(["x"], pd.DatetimeIndex([d])) + self._assertIndexedLikeNDArray(x, np.datetime64(d), "datetime64[ns]") + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_index_0d_timedelta64(self): + td = timedelta(hours=1) + + x = self.cls(["x"], [np.timedelta64(td)]) + self._assertIndexedLikeNDArray(x, np.timedelta64(td), "timedelta64[ns]") + + x = self.cls(["x"], pd.to_timedelta([td])) + self._assertIndexedLikeNDArray(x, np.timedelta64(td), "timedelta64[ns]") + + def test_index_0d_not_a_time(self): + d = np.datetime64("NaT", "ns") + x = self.cls(["x"], [d]) + self._assertIndexedLikeNDArray(x, d) + + def test_index_0d_object(self): + class HashableItemWrapper: + def __init__(self, item): + self.item = item + + def __eq__(self, other): + return self.item == other.item + + def __hash__(self): + return hash(self.item) + + def __repr__(self): + return f"{type(self).__name__}(item={self.item!r})" + + item = HashableItemWrapper((1, 2, 3)) + x = self.cls("x", [item]) + self._assertIndexedLikeNDArray(x, item, expected_dtype=False) + + def test_0d_object_array_with_list(self): + listarray = np.empty((1,), dtype=object) + listarray[0] = [1, 2, 3] + x = self.cls("x", listarray) + assert_array_equal(x.data, listarray) + assert_array_equal(x[0].data, listarray.squeeze()) + assert_array_equal(x.squeeze().data, listarray.squeeze()) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_index_and_concat_datetime(self): + # regression test for #125 + date_range = pd.date_range("2011-09-01", periods=10) + for dates in [date_range, date_range.values, date_range.to_pydatetime()]: + expected = self.cls("t", dates) + for times in [ + [expected[i] for i in range(10)], + [expected[i : (i + 1)] for i in range(10)], + [expected[[i]] for i in range(10)], + ]: + actual = Variable.concat(times, "t") + assert expected.dtype == actual.dtype + assert_array_equal(expected, actual) + + def test_0d_time_data(self): + # regression test for #105 + x = self.cls("time", pd.date_range("2000-01-01", periods=5)) + expected = np.datetime64("2000-01-01", "ns") + assert x[0].values == expected + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_datetime64_conversion(self): + times = pd.date_range("2000-01-01", periods=3) + for values, preserve_source in [ + (times, True), + (times.values, True), + (times.values.astype("datetime64[s]"), False), + (times.to_pydatetime(), False), + ]: + v = self.cls(["t"], values) + assert v.dtype == np.dtype("datetime64[ns]") + assert_array_equal(v.values, times.values) + assert v.values.dtype == np.dtype("datetime64[ns]") + same_source = source_ndarray(v.values) is source_ndarray(values) + assert preserve_source == same_source + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_timedelta64_conversion(self): + times = pd.timedelta_range(start=0, periods=3) + for values, preserve_source in [ + (times, True), + (times.values, True), + (times.values.astype("timedelta64[s]"), False), + (times.to_pytimedelta(), False), + ]: + v = self.cls(["t"], values) + assert v.dtype == np.dtype("timedelta64[ns]") + assert_array_equal(v.values, times.values) + assert v.values.dtype == np.dtype("timedelta64[ns]") + same_source = source_ndarray(v.values) is source_ndarray(values) + assert preserve_source == same_source + + def test_object_conversion(self): + data = np.arange(5).astype(str).astype(object) + actual = self.cls("x", data) + assert actual.dtype == data.dtype + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_datetime64_valid_range(self): + data = np.datetime64("1250-01-01", "us") + pderror = pd.errors.OutOfBoundsDatetime + with pytest.raises(pderror, match=r"Out of bounds nanosecond"): + self.cls(["t"], [data]) + + @pytest.mark.xfail(reason="pandas issue 36615") + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_timedelta64_valid_range(self): + data = np.timedelta64("200000", "D") + pderror = pd.errors.OutOfBoundsTimedelta + with pytest.raises(pderror, match=r"Out of bounds nanosecond"): + self.cls(["t"], [data]) + + def test_pandas_data(self): + v = self.cls(["x"], pd.Series([0, 1, 2], index=[3, 2, 1])) + assert_identical(v, v[[0, 1, 2]]) + v = self.cls(["x"], pd.Index([0, 1, 2])) + assert v[0].values == v.values[0] + + def test_pandas_period_index(self): + v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="D")) + v = v.load() # for dask-based Variable + assert v[0] == pd.Period("2000", freq="D") + assert "Period('2000-01-01', 'D')" in repr(v) + + @pytest.mark.parametrize("dtype", [float, int]) + def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: + x = np.arange(5, dtype=dtype) + y = np.ones(5, dtype=dtype) + + # should we need `.to_base_variable()`? + # probably a break that `+v` changes type? + v = self.cls(["x"], x) + base_v = v.to_base_variable() + # unary ops + assert_identical(base_v, +v) + assert_identical(base_v, abs(v)) + assert_array_equal((-v).values, -x) + # binary ops with numbers + assert_identical(base_v, v + 0) + assert_identical(base_v, 0 + v) + assert_identical(base_v, v * 1) + if dtype is int: + assert_identical(base_v, v << 0) + assert_array_equal(v << 3, x << 3) + assert_array_equal(v >> 2, x >> 2) + # binary ops with numpy arrays + assert_array_equal((v * x).values, x**2) + assert_array_equal((x * v).values, x**2) + assert_array_equal(v - y, v - 1) + assert_array_equal(y - v, 1 - v) + if dtype is int: + assert_array_equal(v << x, x << x) + assert_array_equal(v >> x, x >> x) + # verify attributes are dropped + v2 = self.cls(["x"], x, {"units": "meters"}) + with set_options(keep_attrs=False): + assert_identical(base_v, +v2) + # binary ops with all variables + assert_array_equal(v + v, 2 * v) + w = self.cls(["x"], y, {"foo": "bar"}) + assert_identical(v + w, self.cls(["x"], x + y).to_base_variable()) + assert_array_equal((v * w).values, x * y) + + # something complicated + assert_array_equal((v**2 * w - 1 + x).values, x**2 * y - 1 + x) + # make sure dtype is preserved (for Index objects) + assert dtype == (+v).dtype + assert dtype == (+v).values.dtype + assert dtype == (0 + v).dtype + assert dtype == (0 + v).values.dtype + # check types of returned data + assert isinstance(+v, Variable) + assert not isinstance(+v, IndexVariable) + assert isinstance(0 + v, Variable) + assert not isinstance(0 + v, IndexVariable) + + def test_1d_reduce(self): + x = np.arange(5) + v = self.cls(["x"], x) + actual = v.sum() + expected = Variable((), 10) + assert_identical(expected, actual) + assert type(actual) is Variable + + def test_array_interface(self): + x = np.arange(5) + v = self.cls(["x"], x) + assert_array_equal(np.asarray(v), x) + # test patched in methods + assert_array_equal(v.astype(float), x.astype(float)) + # think this is a break, that argsort changes the type + assert_identical(v.argsort(), v.to_base_variable()) + assert_identical(v.clip(2, 3), self.cls("x", x.clip(2, 3)).to_base_variable()) + # test ufuncs + assert_identical(np.sin(v), self.cls(["x"], np.sin(x)).to_base_variable()) + assert isinstance(np.sin(v), Variable) + assert not isinstance(np.sin(v), IndexVariable) + + def example_1d_objects(self): + for data in [ + range(3), + 0.5 * np.arange(3), + 0.5 * np.arange(3, dtype=np.float32), + pd.date_range("2000-01-01", periods=3), + np.array(["a", "b", "c"], dtype=object), + ]: + yield (self.cls("x", data), data) + + def test___array__(self): + for v, data in self.example_1d_objects(): + assert_array_equal(v.values, np.asarray(data)) + assert_array_equal(np.asarray(v), np.asarray(data)) + assert v[0].values == np.asarray(data)[0] + assert np.asarray(v[0]) == np.asarray(data)[0] + + def test_equals_all_dtypes(self): + for v, _ in self.example_1d_objects(): + v2 = v.copy() + assert v.equals(v2) + assert v.identical(v2) + assert v.no_conflicts(v2) + assert v[0].equals(v2[0]) + assert v[0].identical(v2[0]) + assert v[0].no_conflicts(v2[0]) + assert v[:2].equals(v2[:2]) + assert v[:2].identical(v2[:2]) + assert v[:2].no_conflicts(v2[:2]) + + def test_eq_all_dtypes(self): + # ensure that we don't choke on comparisons for which numpy returns + # scalars + expected = Variable("x", 3 * [False]) + for v, _ in self.example_1d_objects(): + actual = "z" == v + assert_identical(expected, actual) + actual = ~("z" != v) + assert_identical(expected, actual) + + def test_encoding_preserved(self): + expected = self.cls("x", range(3), {"foo": 1}, {"bar": 2}) + for actual in [ + expected.T, + expected[...], + expected.squeeze(), + expected.isel(x=slice(None)), + expected.set_dims({"x": 3}), + expected.copy(deep=True), + expected.copy(deep=False), + ]: + assert_identical(expected.to_base_variable(), actual.to_base_variable()) + assert expected.encoding == actual.encoding + + def test_drop_encoding(self) -> None: + encoding1 = {"scale_factor": 1} + # encoding set via cls constructor + v1 = self.cls(["a"], [0, 1, 2], encoding=encoding1) + assert v1.encoding == encoding1 + v2 = v1.drop_encoding() + assert v1.encoding == encoding1 + assert v2.encoding == {} + + # encoding set via setter + encoding3 = {"scale_factor": 10} + v3 = self.cls(["a"], [0, 1, 2], encoding=encoding3) + assert v3.encoding == encoding3 + v4 = v3.drop_encoding() + assert v3.encoding == encoding3 + assert v4.encoding == {} + + def test_concat(self): + x = np.arange(5) + y = np.arange(5, 10) + v = self.cls(["a"], x) + w = self.cls(["a"], y) + assert_identical( + Variable(["b", "a"], np.array([x, y])), Variable.concat([v, w], "b") + ) + assert_identical( + Variable(["b", "a"], np.array([x, y])), Variable.concat((v, w), "b") + ) + assert_identical( + Variable(["b", "a"], np.array([x, y])), Variable.concat((v, w), "b") + ) + with pytest.raises(ValueError, match=r"Variable has dimensions"): + Variable.concat([v, Variable(["c"], y)], "b") + # test indexers + actual = Variable.concat( + [v, w], positions=[np.arange(0, 10, 2), np.arange(1, 10, 2)], dim="a" + ) + expected = Variable("a", np.array([x, y]).ravel(order="F")) + assert_identical(expected, actual) + # test concatenating along a dimension + v = Variable(["time", "x"], np.random.random((10, 8))) + assert_identical(v, Variable.concat([v[:5], v[5:]], "time")) + assert_identical(v, Variable.concat([v[:5], v[5:6], v[6:]], "time")) + assert_identical(v, Variable.concat([v[:1], v[1:]], "time")) + # test dimension order + assert_identical(v, Variable.concat([v[:, :5], v[:, 5:]], "x")) + with pytest.raises(ValueError, match=r"all input arrays must have"): + Variable.concat([v[:, 0], v[:, 1:]], "x") + + def test_concat_attrs(self): + # always keep attrs from first variable + v = self.cls("a", np.arange(5), {"foo": "bar"}) + w = self.cls("a", np.ones(5)) + expected = self.cls( + "a", np.concatenate([np.arange(5), np.ones(5)]) + ).to_base_variable() + expected.attrs["foo"] = "bar" + assert_identical(expected, Variable.concat([v, w], "a")) + + def test_concat_fixed_len_str(self): + # regression test for #217 + for kind in ["S", "U"]: + x = self.cls("animal", np.array(["horse"], dtype=kind)) + y = self.cls("animal", np.array(["aardvark"], dtype=kind)) + actual = Variable.concat([x, y], "animal") + expected = Variable("animal", np.array(["horse", "aardvark"], dtype=kind)) + assert_equal(expected, actual) + + def test_concat_number_strings(self): + # regression test for #305 + a = self.cls("x", ["0", "1", "2"]) + b = self.cls("x", ["3", "4"]) + actual = Variable.concat([a, b], dim="x") + expected = Variable("x", np.arange(5).astype(str)) + assert_identical(expected, actual) + assert actual.dtype.kind == expected.dtype.kind + + def test_concat_mixed_dtypes(self): + a = self.cls("x", [0, 1]) + b = self.cls("x", ["two"]) + actual = Variable.concat([a, b], dim="x") + expected = Variable("x", np.array([0, 1, "two"], dtype=object)) + assert_identical(expected, actual) + assert actual.dtype == object + + @pytest.mark.parametrize("deep", [True, False]) + @pytest.mark.parametrize("astype", [float, int, str]) + def test_copy(self, deep: bool, astype: type[object]) -> None: + v = self.cls("x", (0.5 * np.arange(10)).astype(astype), {"foo": "bar"}) + w = v.copy(deep=deep) + assert type(v) is type(w) + assert_identical(v, w) + assert v.dtype == w.dtype + if self.cls is Variable: + if deep: + assert source_ndarray(v.values) is not source_ndarray(w.values) + else: + assert source_ndarray(v.values) is source_ndarray(w.values) + assert_identical(v, copy(v)) + + def test_copy_deep_recursive(self) -> None: + # GH:issue:7111 + + # direct recursion + v = self.cls("x", [0, 1]) + v.attrs["other"] = v + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + v.copy(deep=True) + + # indirect recusrion + v2 = self.cls("y", [2, 3]) + v.attrs["other"] = v2 + v2.attrs["other"] = v + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + v.copy(deep=True) + v2.copy(deep=True) + + def test_copy_index(self): + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") + ) + v = self.cls("x", midx) + for deep in [True, False]: + w = v.copy(deep=deep) + assert isinstance(w._data, PandasIndexingAdapter) + assert isinstance(w.to_index(), pd.MultiIndex) + assert_array_equal(v._data.array, w._data.array) + + def test_copy_with_data(self) -> None: + orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) + new_data = np.array([[2.5, 5.0], [7.1, 43]]) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_with_data_errors(self) -> None: + orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) + new_data = [2.5, 5.0] + with pytest.raises(ValueError, match=r"must match shape of object"): + orig.copy(data=new_data) # type: ignore[arg-type] + + def test_copy_index_with_data(self) -> None: + orig = IndexVariable("x", np.arange(5)) + new_data = np.arange(5, 10) + actual = orig.copy(data=new_data) + expected = IndexVariable("x", np.arange(5, 10)) + assert_identical(expected, actual) + + def test_copy_index_with_data_errors(self) -> None: + orig = IndexVariable("x", np.arange(5)) + new_data = np.arange(5, 20) + with pytest.raises(ValueError, match=r"must match shape of object"): + orig.copy(data=new_data) + with pytest.raises(ValueError, match=r"Cannot assign to the .data"): + orig.data = new_data + with pytest.raises(ValueError, match=r"Cannot assign to the .values"): + orig.values = new_data + + def test_replace(self): + var = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) + result = var._replace() + assert_identical(result, var) + + new_data = np.arange(4).reshape(2, 2) + result = var._replace(data=new_data) + assert_array_equal(result.data, new_data) + + def test_real_and_imag(self): + v = self.cls("x", np.arange(3) - 1j * np.arange(3), {"foo": "bar"}) + expected_re = self.cls("x", np.arange(3), {"foo": "bar"}) + assert_identical(v.real, expected_re) + + expected_im = self.cls("x", -np.arange(3), {"foo": "bar"}) + assert_identical(v.imag, expected_im) + + expected_abs = self.cls("x", np.sqrt(2 * np.arange(3) ** 2)).to_base_variable() + assert_allclose(abs(v), expected_abs) + + def test_aggregate_complex(self): + # should skip NaNs + v = self.cls("x", [1, 2j, np.nan]) + expected = Variable((), 0.5 + 1j) + assert_allclose(v.mean(), expected) + + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + print(v) # should not error + assert v.dtype == "int64" + + def test_pandas_datetime64_with_tz(self): + data = pd.date_range( + start="2000-01-01", + tz=pytz.timezone("America/New_York"), + periods=10, + freq="1h", + ) + v = self.cls("x", data) + print(v) # should not error + if "America/New_York" in str(data.dtype): + # pandas is new enough that it has datetime64 with timezone dtype + assert v.dtype == "object" + + def test_multiindex(self): + idx = pd.MultiIndex.from_product([list("abc"), [0, 1]]) + v = self.cls("x", idx) + assert_identical(Variable((), ("a", 0)), v[0]) + assert_identical(v, v[:]) + + def test_load(self): + array = self.cls("x", np.arange(5)) + orig_data = array._data + copied = array.copy(deep=True) + if array.chunks is None: + array.load() + assert type(array._data) is type(orig_data) + assert type(copied._data) is type(orig_data) + assert_identical(array, copied) + + def test_getitem_advanced(self): + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) + v_data = v.compute().data + + # orthogonal indexing + v_new = v[([0, 1], [1, 0])] + assert v_new.dims == ("x", "y") + assert_array_equal(v_new, v_data[[0, 1]][:, [1, 0]]) + + v_new = v[[0, 1]] + assert v_new.dims == ("x", "y") + assert_array_equal(v_new, v_data[[0, 1]]) + + # with mixed arguments + ind = Variable(["a"], [0, 1]) + v_new = v[dict(x=[0, 1], y=ind)] + assert v_new.dims == ("x", "a") + assert_array_equal(v_new, v_data[[0, 1]][:, [0, 1]]) + + # boolean indexing + v_new = v[dict(x=[True, False], y=[False, True, False])] + assert v_new.dims == ("x", "y") + assert_array_equal(v_new, v_data[0][1]) + + # with scalar variable + ind = Variable((), 2) + v_new = v[dict(y=ind)] + expected = v[dict(y=2)] + assert_array_equal(v_new, expected) + + # with boolean variable with wrong shape + ind = np.array([True, False]) + with pytest.raises(IndexError, match=r"Boolean array size 2 is "): + v[Variable(("a", "b"), [[0, 1]]), ind] + + # boolean indexing with different dimension + ind = Variable(["a"], [True, False, False]) + with pytest.raises(IndexError, match=r"Boolean indexer should be"): + v[dict(y=ind)] + + def test_getitem_uint_1d(self): + # regression test for #1405 + v = self.cls(["x"], [0, 1, 2]) + v_data = v.compute().data + + v_new = v[np.array([0])] + assert_array_equal(v_new, v_data[0]) + v_new = v[np.array([0], dtype="uint64")] + assert_array_equal(v_new, v_data[0]) + + def test_getitem_uint(self): + # regression test for #1405 + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) + v_data = v.compute().data + + v_new = v[np.array([0])] + assert_array_equal(v_new, v_data[[0], :]) + v_new = v[np.array([0], dtype="uint64")] + assert_array_equal(v_new, v_data[[0], :]) + + v_new = v[np.uint64(0)] + assert_array_equal(v_new, v_data[0, :]) + + def test_getitem_0d_array(self): + # make sure 0d-np.array can be used as an indexer + v = self.cls(["x"], [0, 1, 2]) + v_data = v.compute().data + + v_new = v[np.array([0])[0]] + assert_array_equal(v_new, v_data[0]) + + v_new = v[np.array(0)] + assert_array_equal(v_new, v_data[0]) + + v_new = v[Variable((), np.array(0))] + assert_array_equal(v_new, v_data[0]) + + def test_getitem_fancy(self): + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) + v_data = v.compute().data + + ind = Variable(["a", "b"], [[0, 1, 1], [1, 1, 0]]) + v_new = v[ind] + assert v_new.dims == ("a", "b", "y") + assert_array_equal(v_new, v_data[[[0, 1, 1], [1, 1, 0]], :]) + + # It would be ok if indexed with the multi-dimensional array including + # the same name + ind = Variable(["x", "b"], [[0, 1, 1], [1, 1, 0]]) + v_new = v[ind] + assert v_new.dims == ("x", "b", "y") + assert_array_equal(v_new, v_data[[[0, 1, 1], [1, 1, 0]], :]) + + ind = Variable(["a", "b"], [[0, 1, 2], [2, 1, 0]]) + v_new = v[dict(y=ind)] + assert v_new.dims == ("x", "a", "b") + assert_array_equal(v_new, v_data[:, ([0, 1, 2], [2, 1, 0])]) + + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) + v_new = v[dict(x=[1, 0], y=ind)] + assert v_new.dims == ("x", "a", "b") + assert_array_equal(v_new, v_data[[1, 0]][:, ind]) + + # along diagonal + ind = Variable(["a"], [0, 1]) + v_new = v[ind, ind] + assert v_new.dims == ("a",) + assert_array_equal(v_new, v_data[[0, 1], [0, 1]]) + + # with integer + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) + v_new = v[dict(x=0, y=ind)] + assert v_new.dims == ("a", "b") + assert_array_equal(v_new[0], v_data[0][[0, 0]]) + assert_array_equal(v_new[1], v_data[0][[1, 1]]) + + # with slice + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) + v_new = v[dict(x=slice(None), y=ind)] + assert v_new.dims == ("x", "a", "b") + assert_array_equal(v_new, v_data[:, [[0, 0], [1, 1]]]) + + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) + v_new = v[dict(x=ind, y=slice(None))] + assert v_new.dims == ("a", "b", "y") + assert_array_equal(v_new, v_data[[[0, 0], [1, 1]], :]) + + ind = Variable(["a", "b"], [[0, 0], [1, 1]]) + v_new = v[dict(x=ind, y=slice(None, 1))] + assert v_new.dims == ("a", "b", "y") + assert_array_equal(v_new, v_data[[[0, 0], [1, 1]], slice(None, 1)]) + + # slice matches explicit dimension + ind = Variable(["y"], [0, 1]) + v_new = v[ind, :2] + assert v_new.dims == ("y",) + assert_array_equal(v_new, v_data[[0, 1], [0, 1]]) + + # with multiple slices + v = self.cls(["x", "y", "z"], [[[1, 2, 3], [4, 5, 6]]]) + ind = Variable(["a", "b"], [[0]]) + v_new = v[ind, :, :] + expected = Variable(["a", "b", "y", "z"], v.data[np.newaxis, ...]) + assert_identical(v_new, expected) + + v = Variable(["w", "x", "y", "z"], [[[[1, 2, 3], [4, 5, 6]]]]) + ind = Variable(["y"], [0]) + v_new = v[ind, :, 1:2, 2] + expected = Variable(["y", "x"], [[6]]) + assert_identical(v_new, expected) + + # slice and vector mixed indexing resulting in the same dimension + v = Variable(["x", "y", "z"], np.arange(60).reshape(3, 4, 5)) + ind = Variable(["x"], [0, 1, 2]) + v_new = v[:, ind] + expected = Variable(("x", "z"), np.zeros((3, 5))) + expected[0] = v.data[0, 0] + expected[1] = v.data[1, 1] + expected[2] = v.data[2, 2] + assert_identical(v_new, expected) + + v_new = v[:, ind.data] + assert v_new.shape == (3, 3, 5) + + def test_getitem_error(self): + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) + + with pytest.raises(IndexError, match=r"labeled multi-"): + v[[[0, 1], [1, 2]]] + + ind_x = Variable(["a"], [0, 1, 1]) + ind_y = Variable(["a"], [0, 1]) + with pytest.raises(IndexError, match=r"Dimensions of indexers "): + v[ind_x, ind_y] + + ind = Variable(["a", "b"], [[True, False], [False, True]]) + with pytest.raises(IndexError, match=r"2-dimensional boolean"): + v[dict(x=ind)] + + v = Variable(["x", "y", "z"], np.arange(60).reshape(3, 4, 5)) + ind = Variable(["x"], [0, 1]) + with pytest.raises(IndexError, match=r"Dimensions of indexers mis"): + v[:, ind] + + @pytest.mark.parametrize( + "mode", + [ + "mean", + "median", + "reflect", + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.filterwarnings( + r"ignore:dask.array.pad.+? converts integers to floats." + ) + def test_pad(self, mode, xr_arg, np_arg): + data = np.arange(4 * 3 * 2).reshape(4, 3, 2) + v = self.cls(["x", "y", "z"], data) + + actual = v.pad(mode=mode, **xr_arg) + expected = np.pad(data, np_arg, mode=mode) + + assert_array_equal(actual, expected) + assert isinstance(actual._data, type(v._data)) + + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + def test_pad_constant_values(self, xr_arg, np_arg): + data = np.arange(4 * 3 * 2).reshape(4, 3, 2) + v = self.cls(["x", "y", "z"], data) + + actual = v.pad(**xr_arg) + expected = np.pad( + np.array(duck_array_ops.astype(v.data, float)), + np_arg, + mode="constant", + constant_values=np.nan, + ) + assert_array_equal(actual, expected) + assert isinstance(actual._data, type(v._data)) + + # for the boolean array, we pad False + data = np.full_like(data, False, dtype=bool).reshape(4, 3, 2) + v = self.cls(["x", "y", "z"], data) + + actual = v.pad(mode="constant", constant_values=False, **xr_arg) + expected = np.pad( + np.array(v.data), np_arg, mode="constant", constant_values=False + ) + assert_array_equal(actual, expected) + + @pytest.mark.parametrize( + ["keep_attrs", "attrs", "expected"], + [ + pytest.param(None, {"a": 1, "b": 2}, {"a": 1, "b": 2}, id="default"), + pytest.param(False, {"a": 1, "b": 2}, {}, id="False"), + pytest.param(True, {"a": 1, "b": 2}, {"a": 1, "b": 2}, id="True"), + ], + ) + def test_pad_keep_attrs(self, keep_attrs, attrs, expected): + data = np.arange(10, dtype=float) + v = self.cls(["x"], data, attrs) + + keep_attrs_ = "default" if keep_attrs is None else keep_attrs + + with set_options(keep_attrs=keep_attrs_): + actual = v.pad({"x": (1, 1)}, mode="constant", constant_values=np.nan) + + assert actual.attrs == expected + + actual = v.pad( + {"x": (1, 1)}, + mode="constant", + constant_values=np.nan, + keep_attrs=keep_attrs, + ) + assert actual.attrs == expected + + @pytest.mark.parametrize("d, w", (("x", 3), ("y", 5))) + def test_rolling_window(self, d, w): + # Just a working test. See test_nputils for the algorithm validation + v = self.cls(["x", "y", "z"], np.arange(40 * 30 * 2).reshape(40, 30, 2)) + v_rolling = v.rolling_window(d, w, d + "_window") + assert v_rolling.dims == ("x", "y", "z", d + "_window") + assert v_rolling.shape == v.shape + (w,) + + v_rolling = v.rolling_window(d, w, d + "_window", center=True) + assert v_rolling.dims == ("x", "y", "z", d + "_window") + assert v_rolling.shape == v.shape + (w,) + + # dask and numpy result should be the same + v_loaded = v.load().rolling_window(d, w, d + "_window", center=True) + assert_array_equal(v_rolling, v_loaded) + + # numpy backend should not be over-written + if isinstance(v._data, np.ndarray): + with pytest.raises(ValueError): + v_loaded[0] = 1.0 + + def test_rolling_1d(self): + x = self.cls("x", np.array([1, 2, 3, 4], dtype=float)) + + kwargs = dict(dim="x", window=3, window_dim="xw") + actual = x.rolling_window(**kwargs, center=True, fill_value=np.nan) + expected = Variable( + ("x", "xw"), + np.array( + [[np.nan, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, np.nan]], dtype=float + ), + ) + assert_equal(actual, expected) + + actual = x.rolling_window(**kwargs, center=False, fill_value=0.0) + expected = self.cls( + ("x", "xw"), + np.array([[0, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]], dtype=float), + ) + assert_equal(actual, expected) + + x = self.cls(("y", "x"), np.stack([x, x * 1.1])) + actual = x.rolling_window(**kwargs, center=False, fill_value=0.0) + expected = self.cls( + ("y", "x", "xw"), np.stack([expected.data, expected.data * 1.1], axis=0) + ) + assert_equal(actual, expected) + + @pytest.mark.parametrize("center", [[True, True], [False, False]]) + @pytest.mark.parametrize("dims", [("x", "y"), ("y", "z"), ("z", "x")]) + def test_nd_rolling(self, center, dims): + x = self.cls( + ("x", "y", "z"), + np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float), + ) + window = [3, 3] + actual = x.rolling_window( + dim=dims, + window=window, + window_dim=[f"{k}w" for k in dims], + center=center, + fill_value=np.nan, + ) + expected = x + for dim, win, cent in zip(dims, window, center): + expected = expected.rolling_window( + dim=dim, + window=win, + window_dim=f"{dim}w", + center=cent, + fill_value=np.nan, + ) + assert_equal(actual, expected) + + @pytest.mark.parametrize( + ("dim, window, window_dim, center"), + [ + ("x", [3, 3], "x_w", True), + ("x", 3, ("x_w", "x_w"), True), + ("x", 3, "x_w", [True, True]), + ], + ) + def test_rolling_window_errors(self, dim, window, window_dim, center): + x = self.cls( + ("x", "y", "z"), + np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float), + ) + with pytest.raises(ValueError): + x.rolling_window( + dim=dim, + window=window, + window_dim=window_dim, + center=center, + ) + + +class TestVariable(VariableSubclassobjects): + def cls(self, *args, **kwargs) -> Variable: + return Variable(*args, **kwargs) + + @pytest.fixture(autouse=True) + def setup(self): + self.d = np.random.random((10, 3)).astype(np.float64) + + def test_values(self): + v = Variable(["time", "x"], self.d) + assert_array_equal(v.values, self.d) + assert source_ndarray(v.values) is self.d + with pytest.raises(ValueError): + # wrong size + v.values = np.random.random(5) + d2 = np.random.random((10, 3)) + v.values = d2 + assert source_ndarray(v.values) is d2 + + def test_numpy_same_methods(self): + v = Variable([], np.float32(0.0)) + assert v.item() == 0 + assert type(v.item()) is float # noqa: E721 + + v = IndexVariable("x", np.arange(5)) + assert 2 == v.searchsorted(2) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_datetime64_conversion_scalar(self): + expected = np.datetime64("2000-01-01", "ns") + for values in [ + np.datetime64("2000-01-01"), + pd.Timestamp("2000-01-01T00"), + datetime(2000, 1, 1), + ]: + v = Variable([], values) + assert v.dtype == np.dtype("datetime64[ns]") + assert v.values == expected + assert v.values.dtype == np.dtype("datetime64[ns]") + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_timedelta64_conversion_scalar(self): + expected = np.timedelta64(24 * 60 * 60 * 10**9, "ns") + for values in [ + np.timedelta64(1, "D"), + pd.Timedelta("1 day"), + timedelta(days=1), + ]: + v = Variable([], values) + assert v.dtype == np.dtype("timedelta64[ns]") + assert v.values == expected + assert v.values.dtype == np.dtype("timedelta64[ns]") + + def test_0d_str(self): + v = Variable([], "foo") + assert v.dtype == np.dtype("U3") + assert v.values == "foo" + + v = Variable([], np.bytes_("foo")) + assert v.dtype == np.dtype("S3") + assert v.values == "foo".encode("ascii") + + def test_0d_datetime(self): + v = Variable([], pd.Timestamp("2000-01-01")) + assert v.dtype == np.dtype("datetime64[ns]") + assert v.values == np.datetime64("2000-01-01", "ns") + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_0d_timedelta(self): + for td in [pd.to_timedelta("1s"), np.timedelta64(1, "s")]: + v = Variable([], td) + assert v.dtype == np.dtype("timedelta64[ns]") + assert v.values == np.timedelta64(10**9, "ns") + + def test_equals_and_identical(self): + d = np.random.rand(10, 3) + d[0, 0] = np.nan + v1 = Variable(("dim1", "dim2"), data=d, attrs={"att1": 3, "att2": [1, 2, 3]}) + v2 = Variable(("dim1", "dim2"), data=d, attrs={"att1": 3, "att2": [1, 2, 3]}) + assert v1.equals(v2) + assert v1.identical(v2) + + v3 = Variable(("dim1", "dim3"), data=d) + assert not v1.equals(v3) + + v4 = Variable(("dim1", "dim2"), data=d) + assert v1.equals(v4) + assert not v1.identical(v4) + + v5 = deepcopy(v1) + v5.values[:] = np.random.rand(10, 3) + assert not v1.equals(v5) + + assert not v1.equals(None) + assert not v1.equals(d) + + assert not v1.identical(None) + assert not v1.identical(d) + + def test_broadcast_equals(self): + v1 = Variable((), np.nan) + v2 = Variable(("x"), [np.nan, np.nan]) + assert v1.broadcast_equals(v2) + assert not v1.equals(v2) + assert not v1.identical(v2) + + v3 = Variable(("x"), [np.nan]) + assert v1.broadcast_equals(v3) + assert not v1.equals(v3) + assert not v1.identical(v3) + + assert not v1.broadcast_equals(None) + + v4 = Variable(("x"), [np.nan] * 3) + assert not v2.broadcast_equals(v4) + + def test_no_conflicts(self): + v1 = Variable(("x"), [1, 2, np.nan, np.nan]) + v2 = Variable(("x"), [np.nan, 2, 3, np.nan]) + assert v1.no_conflicts(v2) + assert not v1.equals(v2) + assert not v1.broadcast_equals(v2) + assert not v1.identical(v2) + + assert not v1.no_conflicts(None) + + v3 = Variable(("y"), [np.nan, 2, 3, np.nan]) + assert not v3.no_conflicts(v1) + + d = np.array([1, 2, np.nan, np.nan]) + assert not v1.no_conflicts(d) + assert not v2.no_conflicts(d) + + v4 = Variable(("w", "x"), [d]) + assert v1.no_conflicts(v4) + + def test_as_variable(self): + data = np.arange(10) + expected = Variable("x", data) + expected_extra = Variable( + "x", data, attrs={"myattr": "val"}, encoding={"scale_factor": 1} + ) + + assert_identical(expected, as_variable(expected)) + + ds = Dataset({"x": expected}) + var = as_variable(ds["x"]).to_base_variable() + assert_identical(expected, var) + assert not isinstance(ds["x"], Variable) + assert isinstance(as_variable(ds["x"]), Variable) + + xarray_tuple = ( + expected_extra.dims, + expected_extra.values, + expected_extra.attrs, + expected_extra.encoding, + ) + assert_identical(expected_extra, as_variable(xarray_tuple)) + + with pytest.raises(TypeError, match=r"tuple of form"): + as_variable(tuple(data)) + with pytest.raises(ValueError, match=r"tuple of form"): # GH1016 + as_variable(("five", "six", "seven")) + with pytest.raises(TypeError, match=r"without an explicit list of dimensions"): + as_variable(data) + + with pytest.warns(FutureWarning, match="IndexVariable"): + actual = as_variable(data, name="x") + assert_identical(expected.to_index_variable(), actual) + + actual = as_variable(0) + expected = Variable([], 0) + assert_identical(expected, actual) + + data = np.arange(9).reshape((3, 3)) + expected = Variable(("x", "y"), data) + with pytest.raises(ValueError, match=r"without explicit dimension names"): + as_variable(data, name="x") + + # name of nD variable matches dimension name + actual = as_variable(expected, name="x") + assert_identical(expected, actual) + + # test datetime, timedelta conversion + dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) for x in range(10)]) + with pytest.warns(FutureWarning, match="IndexVariable"): + assert as_variable(dt, "time").dtype.kind == "M" + td = np.array([timedelta(days=x) for x in range(10)]) + with pytest.warns(FutureWarning, match="IndexVariable"): + assert as_variable(td, "time").dtype.kind == "m" + + with pytest.raises(TypeError): + as_variable(("x", DataArray([]))) + + def test_repr(self): + v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) + v = v.astype(np.uint64) + expected = dedent( + """ + Size: 48B + array([[1, 2, 3], + [4, 5, 6]], dtype=uint64) + Attributes: + foo: bar + """ + ).strip() + assert expected == repr(v) + + def test_repr_lazy_data(self): + v = Variable("x", LazilyIndexedArray(np.arange(2e5))) + assert "200000 values with dtype" in repr(v) + assert isinstance(v._data, LazilyIndexedArray) + + def test_detect_indexer_type(self): + """Tests indexer type was correctly detected.""" + data = np.random.random((10, 11)) + v = Variable(["x", "y"], data) + + _, ind, _ = v._broadcast_indexes((0, 1)) + assert type(ind) == indexing.BasicIndexer + + _, ind, _ = v._broadcast_indexes((0, slice(0, 8, 2))) + assert type(ind) == indexing.BasicIndexer + + _, ind, _ = v._broadcast_indexes((0, [0, 1])) + assert type(ind) == indexing.OuterIndexer + + _, ind, _ = v._broadcast_indexes(([0, 1], 1)) + assert type(ind) == indexing.OuterIndexer + + _, ind, _ = v._broadcast_indexes(([0, 1], [1, 2])) + assert type(ind) == indexing.OuterIndexer + + _, ind, _ = v._broadcast_indexes(([0, 1], slice(0, 8, 2))) + assert type(ind) == indexing.OuterIndexer + + vind = Variable(("a",), [0, 1]) + _, ind, _ = v._broadcast_indexes((vind, slice(0, 8, 2))) + assert type(ind) == indexing.OuterIndexer + + vind = Variable(("y",), [0, 1]) + _, ind, _ = v._broadcast_indexes((vind, 3)) + assert type(ind) == indexing.OuterIndexer + + vind = Variable(("a",), [0, 1]) + _, ind, _ = v._broadcast_indexes((vind, vind)) + assert type(ind) == indexing.VectorizedIndexer + + vind = Variable(("a", "b"), [[0, 2], [1, 3]]) + _, ind, _ = v._broadcast_indexes((vind, 3)) + assert type(ind) == indexing.VectorizedIndexer + + def test_indexer_type(self): + # GH:issue:1688. Wrong indexer type induces NotImplementedError + data = np.random.random((10, 11)) + v = Variable(["x", "y"], data) + + def assert_indexer_type(key, object_type): + dims, index_tuple, new_order = v._broadcast_indexes(key) + assert isinstance(index_tuple, object_type) + + # should return BasicIndexer + assert_indexer_type((0, 1), BasicIndexer) + assert_indexer_type((0, slice(None, None)), BasicIndexer) + assert_indexer_type((Variable([], 3), slice(None, None)), BasicIndexer) + assert_indexer_type((Variable([], 3), (Variable([], 6))), BasicIndexer) + + # should return OuterIndexer + assert_indexer_type(([0, 1], 1), OuterIndexer) + assert_indexer_type(([0, 1], [1, 2]), OuterIndexer) + assert_indexer_type((Variable(("x"), [0, 1]), 1), OuterIndexer) + assert_indexer_type((Variable(("x"), [0, 1]), slice(None, None)), OuterIndexer) + assert_indexer_type( + (Variable(("x"), [0, 1]), Variable(("y"), [0, 1])), OuterIndexer + ) + + # should return VectorizedIndexer + assert_indexer_type((Variable(("y"), [0, 1]), [0, 1]), VectorizedIndexer) + assert_indexer_type( + (Variable(("z"), [0, 1]), Variable(("z"), [0, 1])), VectorizedIndexer + ) + assert_indexer_type( + ( + Variable(("a", "b"), [[0, 1], [1, 2]]), + Variable(("a", "b"), [[0, 1], [1, 2]]), + ), + VectorizedIndexer, + ) + + def test_items(self): + data = np.random.random((10, 11)) + v = Variable(["x", "y"], data) + # test slicing + assert_identical(v, v[:]) + assert_identical(v, v[...]) + assert_identical(Variable(["y"], data[0]), v[0]) + assert_identical(Variable(["x"], data[:, 0]), v[:, 0]) + assert_identical(Variable(["x", "y"], data[:3, :2]), v[:3, :2]) + # test array indexing + x = Variable(["x"], np.arange(10)) + y = Variable(["y"], np.arange(11)) + assert_identical(v, v[x.values]) + assert_identical(v, v[x]) + assert_identical(v[:3], v[x < 3]) + assert_identical(v[:, 3:], v[:, y >= 3]) + assert_identical(v[:3, 3:], v[x < 3, y >= 3]) + assert_identical(v[:3, :2], v[x[:3], y[:2]]) + assert_identical(v[:3, :2], v[range(3), range(2)]) + # test iteration + for n, item in enumerate(v): + assert_identical(Variable(["y"], data[n]), item) + with pytest.raises(TypeError, match=r"iteration over a 0-d"): + iter(Variable([], 0)) + # test setting + v.values[:] = 0 + assert np.all(v.values == 0) + # test orthogonal setting + v[range(10), range(11)] = 1 + assert_array_equal(v.values, np.ones((10, 11))) + + def test_getitem_basic(self): + v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) + + # int argument + v_new = v[0] + assert v_new.dims == ("y",) + assert_array_equal(v_new, v._data[0]) + + # slice argument + v_new = v[:2] + assert v_new.dims == ("x", "y") + assert_array_equal(v_new, v._data[:2]) + + # list arguments + v_new = v[[0]] + assert v_new.dims == ("x", "y") + assert_array_equal(v_new, v._data[[0]]) + + v_new = v[[]] + assert v_new.dims == ("x", "y") + assert_array_equal(v_new, v._data[[]]) + + # dict arguments + v_new = v[dict(x=0)] + assert v_new.dims == ("y",) + assert_array_equal(v_new, v._data[0]) + + v_new = v[dict(x=0, y=slice(None))] + assert v_new.dims == ("y",) + assert_array_equal(v_new, v._data[0]) + + v_new = v[dict(x=0, y=1)] + assert v_new.dims == () + assert_array_equal(v_new, v._data[0, 1]) + + v_new = v[dict(y=1)] + assert v_new.dims == ("x",) + assert_array_equal(v_new, v._data[:, 1]) + + # tuple argument + v_new = v[(slice(None), 1)] + assert v_new.dims == ("x",) + assert_array_equal(v_new, v._data[:, 1]) + + # test that we obtain a modifiable view when taking a 0d slice + v_new = v[0, 0] + v_new[...] += 99 + assert_array_equal(v_new, v._data[0, 0]) + + def test_getitem_with_mask_2d_input(self): + v = Variable(("x", "y"), [[0, 1, 2], [3, 4, 5]]) + assert_identical( + v._getitem_with_mask(([-1, 0], [1, -1])), + Variable(("x", "y"), [[np.nan, np.nan], [1, np.nan]]), + ) + assert_identical(v._getitem_with_mask((slice(2), [0, 1, 2])), v) + + def test_isel(self): + v = Variable(["time", "x"], self.d) + assert_identical(v.isel(time=slice(None)), v) + assert_identical(v.isel(time=0), v[0]) + assert_identical(v.isel(time=slice(0, 3)), v[:3]) + assert_identical(v.isel(x=0), v[:, 0]) + assert_identical(v.isel(x=[0, 2]), v[:, [0, 2]]) + assert_identical(v.isel(time=[]), v[[]]) + with pytest.raises( + ValueError, + match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of " + r"\('time', 'x'\)", + ): + v.isel(not_a_dim=0) + with pytest.warns( + UserWarning, + match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of " + r"\('time', 'x'\)", + ): + v.isel(not_a_dim=0, missing_dims="warn") + assert_identical(v, v.isel(not_a_dim=0, missing_dims="ignore")) + + def test_index_0d_numpy_string(self): + # regression test to verify our work around for indexing 0d strings + v = Variable([], np.bytes_("asdf")) + assert_identical(v[()], v) + + v = Variable([], np.str_("asdf")) + assert_identical(v[()], v) + + def test_indexing_0d_unicode(self): + # regression test for GH568 + actual = Variable(("x"), ["tmax"])[0][()] + expected = Variable((), "tmax") + assert_identical(actual, expected) + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + def test_shift(self, fill_value): + v = Variable("x", [1, 2, 3, 4, 5]) + + assert_identical(v, v.shift(x=0)) + assert v is not v.shift(x=0) + + expected = Variable("x", [np.nan, np.nan, 1, 2, 3]) + assert_identical(expected, v.shift(x=2)) + + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_exp = np.nan + else: + fill_value_exp = fill_value + + expected = Variable("x", [fill_value_exp, 1, 2, 3, 4]) + assert_identical(expected, v.shift(x=1, fill_value=fill_value)) + + expected = Variable("x", [2, 3, 4, 5, fill_value_exp]) + assert_identical(expected, v.shift(x=-1, fill_value=fill_value)) + + expected = Variable("x", [fill_value_exp] * 5) + assert_identical(expected, v.shift(x=5, fill_value=fill_value)) + assert_identical(expected, v.shift(x=6, fill_value=fill_value)) + + with pytest.raises(ValueError, match=r"dimension"): + v.shift(z=0) + + v = Variable("x", [1, 2, 3, 4, 5], {"foo": "bar"}) + assert_identical(v, v.shift(x=0)) + + expected = Variable("x", [fill_value_exp, 1, 2, 3, 4], {"foo": "bar"}) + assert_identical(expected, v.shift(x=1, fill_value=fill_value)) + + def test_shift2d(self): + v = Variable(("x", "y"), [[1, 2], [3, 4]]) + expected = Variable(("x", "y"), [[np.nan, np.nan], [np.nan, 1]]) + assert_identical(expected, v.shift(x=1, y=1)) + + def test_roll(self): + v = Variable("x", [1, 2, 3, 4, 5]) + + assert_identical(v, v.roll(x=0)) + assert v is not v.roll(x=0) + + expected = Variable("x", [5, 1, 2, 3, 4]) + assert_identical(expected, v.roll(x=1)) + assert_identical(expected, v.roll(x=-4)) + assert_identical(expected, v.roll(x=6)) + + expected = Variable("x", [4, 5, 1, 2, 3]) + assert_identical(expected, v.roll(x=2)) + assert_identical(expected, v.roll(x=-3)) + + with pytest.raises(ValueError, match=r"dimension"): + v.roll(z=0) + + def test_roll_consistency(self): + v = Variable(("x", "y"), np.random.randn(5, 6)) + + for axis, dim in [(0, "x"), (1, "y")]: + for shift in [-3, 0, 1, 7, 11]: + expected = np.roll(v.values, shift, axis=axis) + actual = v.roll(**{dim: shift}).values + assert_array_equal(expected, actual) + + def test_transpose(self): + v = Variable(["time", "x"], self.d) + v2 = Variable(["x", "time"], self.d.T) + assert_identical(v, v2.transpose()) + assert_identical(v.transpose(), v.T) + x = np.random.randn(2, 3, 4, 5) + w = Variable(["a", "b", "c", "d"], x) + w2 = Variable(["d", "b", "c", "a"], np.einsum("abcd->dbca", x)) + assert w2.shape == (5, 3, 4, 2) + assert_identical(w2, w.transpose("d", "b", "c", "a")) + assert_identical(w2, w.transpose("d", ..., "a")) + assert_identical(w2, w.transpose("d", "b", "c", ...)) + assert_identical(w2, w.transpose(..., "b", "c", "a")) + assert_identical(w, w2.transpose("a", "b", "c", "d")) + w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x)) + assert_identical(w, w3.transpose("a", "b", "c", "d")) + + # test missing dimension, raise error + with pytest.raises(ValueError): + v.transpose(..., "not_a_dim") + + # test missing dimension, ignore error + actual = v.transpose(..., "not_a_dim", missing_dims="ignore") + expected_ell = v.transpose(...) + assert_identical(expected_ell, actual) + + # test missing dimension, raise warning + with pytest.warns(UserWarning): + v.transpose(..., "not_a_dim", missing_dims="warn") + assert_identical(expected_ell, actual) + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_transpose_0d(self): + for value in [ + 3.5, + ("a", 1), + np.datetime64("2000-01-01"), + np.timedelta64(1, "h"), + None, + object(), + ]: + variable = Variable([], value) + actual = variable.transpose() + assert_identical(actual, variable) + + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + print(v) # should not error + assert pd.api.types.is_extension_array_dtype(v.dtype) + + def test_pandas_cateogrical_no_chunk(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + with pytest.raises( + ValueError, match=r".*was found to be a Pandas ExtensionArray.*" + ): + v.chunk((5,)) + + def test_squeeze(self): + v = Variable(["x", "y"], [[1]]) + assert_identical(Variable([], 1), v.squeeze()) + assert_identical(Variable(["y"], [1]), v.squeeze("x")) + assert_identical(Variable(["y"], [1]), v.squeeze(["x"])) + assert_identical(Variable(["x"], [1]), v.squeeze("y")) + assert_identical(Variable([], 1), v.squeeze(["x", "y"])) + + v = Variable(["x", "y"], [[1, 2]]) + assert_identical(Variable(["y"], [1, 2]), v.squeeze()) + assert_identical(Variable(["y"], [1, 2]), v.squeeze("x")) + with pytest.raises(ValueError, match=r"cannot select a dimension"): + v.squeeze("y") + + def test_get_axis_num(self): + v = Variable(["x", "y", "z"], np.random.randn(2, 3, 4)) + assert v.get_axis_num("x") == 0 + assert v.get_axis_num(["x"]) == (0,) + assert v.get_axis_num(["x", "y"]) == (0, 1) + assert v.get_axis_num(["z", "y", "x"]) == (2, 1, 0) + with pytest.raises(ValueError, match=r"not found in array dim"): + v.get_axis_num("foobar") + + def test_set_dims(self): + v = Variable(["x"], [0, 1]) + actual = v.set_dims(["x", "y"]) + expected = Variable(["x", "y"], [[0], [1]]) + assert_identical(actual, expected) + + actual = v.set_dims(["y", "x"]) + assert_identical(actual, expected.T) + + actual = v.set_dims({"x": 2, "y": 2}) + expected = Variable(["x", "y"], [[0, 0], [1, 1]]) + assert_identical(actual, expected) + + v = Variable(["foo"], [0, 1]) + actual = v.set_dims("foo") + expected = v + assert_identical(actual, expected) + + with pytest.raises(ValueError, match=r"must be a superset"): + v.set_dims(["z"]) + + def test_set_dims_object_dtype(self): + v = Variable([], ("a", 1)) + actual = v.set_dims(("x",), (3,)) + exp_values = np.empty((3,), dtype=object) + for i in range(3): + exp_values[i] = ("a", 1) + expected = Variable(["x"], exp_values) + assert_identical(actual, expected) + + def test_stack(self): + v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) + actual = v.stack(z=("x", "y")) + expected = Variable("z", [0, 1, 2, 3], v.attrs) + assert_identical(actual, expected) + + actual = v.stack(z=("x",)) + expected = Variable(("y", "z"), v.data.T, v.attrs) + assert_identical(actual, expected) + + actual = v.stack(z=()) + assert_identical(actual, v) + + actual = v.stack(X=("x",), Y=("y",)).transpose("X", "Y") + expected = Variable(("X", "Y"), v.data, v.attrs) + assert_identical(actual, expected) + + def test_stack_errors(self): + v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) + + with pytest.raises(ValueError, match=r"invalid existing dim"): + v.stack(z=("x1",)) + with pytest.raises(ValueError, match=r"cannot create a new dim"): + v.stack(x=("x",)) + + def test_unstack(self): + v = Variable("z", [0, 1, 2, 3], {"foo": "bar"}) + actual = v.unstack(z={"x": 2, "y": 2}) + expected = Variable(("x", "y"), [[0, 1], [2, 3]], v.attrs) + assert_identical(actual, expected) + + actual = v.unstack(z={"x": 4, "y": 1}) + expected = Variable(("x", "y"), [[0], [1], [2], [3]], v.attrs) + assert_identical(actual, expected) + + actual = v.unstack(z={"x": 4}) + expected = Variable("x", [0, 1, 2, 3], v.attrs) + assert_identical(actual, expected) + + def test_unstack_errors(self): + v = Variable("z", [0, 1, 2, 3]) + with pytest.raises(ValueError, match=r"invalid existing dim"): + v.unstack(foo={"x": 4}) + with pytest.raises(ValueError, match=r"cannot create a new dim"): + v.stack(z=("z",)) + with pytest.raises(ValueError, match=r"the product of the new dim"): + v.unstack(z={"x": 5}) + + def test_unstack_2d(self): + v = Variable(["x", "y"], [[0, 1], [2, 3]]) + actual = v.unstack(y={"z": 2}) + expected = Variable(["x", "z"], v.data) + assert_identical(actual, expected) + + actual = v.unstack(x={"z": 2}) + expected = Variable(["y", "z"], v.data.T) + assert_identical(actual, expected) + + def test_stack_unstack_consistency(self): + v = Variable(["x", "y"], [[0, 1], [2, 3]]) + actual = v.stack(z=("x", "y")).unstack(z={"x": 2, "y": 2}) + assert_identical(actual, v) + + @pytest.mark.filterwarnings("error::RuntimeWarning") + def test_unstack_without_missing(self): + v = Variable(["z"], [0, 1, 2, 3]) + expected = Variable(["x", "y"], [[0, 1], [2, 3]]) + + actual = v.unstack(z={"x": 2, "y": 2}) + + assert_identical(actual, expected) + + def test_broadcasting_math(self): + x = np.random.randn(2, 3) + v = Variable(["a", "b"], x) + # 1d to 2d broadcasting + assert_identical(v * v, Variable(["a", "b"], np.einsum("ab,ab->ab", x, x))) + assert_identical(v * v[0], Variable(["a", "b"], np.einsum("ab,b->ab", x, x[0]))) + assert_identical(v[0] * v, Variable(["b", "a"], np.einsum("b,ab->ba", x[0], x))) + assert_identical( + v[0] * v[:, 0], Variable(["b", "a"], np.einsum("b,a->ba", x[0], x[:, 0])) + ) + # higher dim broadcasting + y = np.random.randn(3, 4, 5) + w = Variable(["b", "c", "d"], y) + assert_identical( + v * w, Variable(["a", "b", "c", "d"], np.einsum("ab,bcd->abcd", x, y)) + ) + assert_identical( + w * v, Variable(["b", "c", "d", "a"], np.einsum("bcd,ab->bcda", y, x)) + ) + assert_identical( + v * w[0], Variable(["a", "b", "c", "d"], np.einsum("ab,cd->abcd", x, y[0])) + ) + + @pytest.mark.filterwarnings("ignore:Duplicate dimension names") + def test_broadcasting_failures(self): + a = Variable(["x"], np.arange(10)) + b = Variable(["x"], np.arange(5)) + c = Variable(["x", "x"], np.arange(100).reshape(10, 10)) + with pytest.raises(ValueError, match=r"mismatched lengths"): + a + b + with pytest.raises(ValueError, match=r"duplicate dimensions"): + a + c + + def test_inplace_math(self): + x = np.arange(5) + v = Variable(["x"], x) + v2 = v + v2 += 1 + assert v is v2 + # since we provided an ndarray for data, it is also modified in-place + assert source_ndarray(v.values) is x + assert_array_equal(v.values, np.arange(5) + 1) + + with pytest.raises(ValueError, match=r"dimensions cannot change"): + v += Variable("y", np.arange(5)) + + def test_inplace_math_error(self): + x = np.arange(5) + v = IndexVariable(["x"], x) + with pytest.raises( + TypeError, match=r"Values of an IndexVariable are immutable" + ): + v += 1 + + def test_reduce(self): + v = Variable(["x", "y"], self.d, {"ignored": "attributes"}) + assert_identical(v.reduce(np.std, "x"), Variable(["y"], self.d.std(axis=0))) + assert_identical(v.reduce(np.std, axis=0), v.reduce(np.std, dim="x")) + assert_identical( + v.reduce(np.std, ["y", "x"]), Variable([], self.d.std(axis=(0, 1))) + ) + assert_identical(v.reduce(np.std), Variable([], self.d.std())) + assert_identical( + v.reduce(np.mean, "x").reduce(np.std, "y"), + Variable([], self.d.mean(axis=0).std()), + ) + assert_allclose(v.mean("x"), v.reduce(np.mean, "x")) + + with pytest.raises(ValueError, match=r"cannot supply both"): + v.mean(dim="x", axis=0) + + @requires_bottleneck + @pytest.mark.parametrize("compute_backend", ["bottleneck"], indirect=True) + def test_reduce_use_bottleneck(self, monkeypatch, compute_backend): + def raise_if_called(*args, **kwargs): + raise RuntimeError("should not have been called") + + import bottleneck as bn + + monkeypatch.setattr(bn, "nanmin", raise_if_called) + + v = Variable("x", [0.0, np.nan, 1.0]) + with pytest.raises(RuntimeError, match="should not have been called"): + with set_options(use_bottleneck=True): + v.min() + + with set_options(use_bottleneck=False): + v.min() + + @pytest.mark.parametrize("skipna", [True, False, None]) + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + @pytest.mark.parametrize( + "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) + ) + def test_quantile(self, q, axis, dim, skipna): + d = self.d.copy() + d[0, 0] = np.nan + + v = Variable(["x", "y"], d) + actual = v.quantile(q, dim=dim, skipna=skipna) + _percentile_func = np.nanpercentile if skipna in (True, None) else np.percentile + expected = _percentile_func(d, np.array(q) * 100, axis=axis) + np.testing.assert_allclose(actual.values, expected) + + @requires_dask + @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) + @pytest.mark.parametrize("axis, dim", [[1, "y"], [[1], ["y"]]]) + def test_quantile_dask(self, q, axis, dim): + v = Variable(["x", "y"], self.d).chunk({"x": 2}) + actual = v.quantile(q, dim=dim) + assert isinstance(actual.data, dask_array_type) + expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis) + np.testing.assert_allclose(actual.values, expected) + + @pytest.mark.parametrize("method", ["midpoint", "lower"]) + @pytest.mark.parametrize( + "use_dask", [pytest.param(True, marks=requires_dask), False] + ) + def test_quantile_method(self, method, use_dask) -> None: + v = Variable(["x", "y"], self.d) + if use_dask: + v = v.chunk({"x": 2}) + + q = np.array([0.25, 0.5, 0.75]) + actual = v.quantile(q, dim="y", method=method) + + expected = np.nanquantile(self.d, q, axis=1, method=method) + + if use_dask: + assert isinstance(actual.data, dask_array_type) + + np.testing.assert_allclose(actual.values, expected) + + @pytest.mark.parametrize("method", ["midpoint", "lower"]) + def test_quantile_interpolation_deprecation(self, method) -> None: + v = Variable(["x", "y"], self.d) + q = np.array([0.25, 0.5, 0.75]) + + with pytest.warns( + FutureWarning, + match="`interpolation` argument to quantile was renamed to `method`", + ): + actual = v.quantile(q, dim="y", interpolation=method) + + expected = v.quantile(q, dim="y", method=method) + + np.testing.assert_allclose(actual.values, expected.values) + + with warnings.catch_warnings(record=True): + with pytest.raises(TypeError, match="interpolation and method keywords"): + v.quantile(q, dim="y", interpolation=method, method=method) + + @requires_dask + def test_quantile_chunked_dim_error(self): + v = Variable(["x", "y"], self.d).chunk({"x": 2}) + + # this checks for ValueError in dask.array.apply_gufunc + with pytest.raises(ValueError, match=r"consists of multiple chunks"): + v.quantile(0.5, dim="x") + + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) + @pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]]) + def test_quantile_out_of_bounds(self, q, compute_backend): + v = Variable(["x", "y"], self.d) + + # escape special characters + with pytest.raises( + ValueError, + match=r"(Q|q)uantiles must be in the range \[0, 1\]", + ): + v.quantile(q, dim="x") + + @requires_dask + @requires_bottleneck + def test_rank_dask(self): + # Instead of a single test here, we could parameterize the other tests for both + # arrays. But this is sufficient. + v = Variable( + ["x", "y"], [[30.0, 1.0, np.nan, 20.0, 4.0], [30.0, 1.0, np.nan, 20.0, 4.0]] + ).chunk(x=1) + expected = Variable( + ["x", "y"], [[4.0, 1.0, np.nan, 3.0, 2.0], [4.0, 1.0, np.nan, 3.0, 2.0]] + ) + assert_equal(v.rank("y").compute(), expected) + + with pytest.raises( + ValueError, match=r" with dask='parallelized' consists of multiple chunks" + ): + v.rank("x") + + def test_rank_use_bottleneck(self): + v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]) + with set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + v.rank("x") + + @requires_bottleneck + def test_rank(self): + import bottleneck as bn + + # floats + v = Variable(["x", "y"], [[3, 4, np.nan, 1]]) + expect_0 = bn.nanrankdata(v.data, axis=0) + expect_1 = bn.nanrankdata(v.data, axis=1) + np.testing.assert_allclose(v.rank("x").values, expect_0) + np.testing.assert_allclose(v.rank("y").values, expect_1) + # int + v = Variable(["x"], [3, 2, 1]) + expect = bn.rankdata(v.data, axis=0) + np.testing.assert_allclose(v.rank("x").values, expect) + # str + v = Variable(["x"], ["c", "b", "a"]) + expect = bn.rankdata(v.data, axis=0) + np.testing.assert_allclose(v.rank("x").values, expect) + # pct + v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]) + v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0]) + assert_equal(v.rank("x", pct=True), v_expect) + # invalid dim + with pytest.raises(ValueError): + # apply_ufunc error message isn't great here — `ValueError: tuple.index(x): x not in tuple` + v.rank("y") + + def test_big_endian_reduce(self): + # regression test for GH489 + data = np.ones(5, dtype=">f4") + v = Variable(["x"], data) + expected = Variable([], 5) + assert_identical(expected, v.sum()) + + def test_reduce_funcs(self): + v = Variable("x", np.array([1, np.nan, 2, 3])) + assert_identical(v.mean(), Variable([], 2)) + assert_identical(v.mean(skipna=True), Variable([], 2)) + assert_identical(v.mean(skipna=False), Variable([], np.nan)) + assert_identical(np.mean(v), Variable([], 2)) + + assert_identical(v.prod(), Variable([], 6)) + assert_identical(v.cumsum(axis=0), Variable("x", np.array([1, 1, 3, 6]))) + assert_identical(v.cumprod(axis=0), Variable("x", np.array([1, 1, 2, 6]))) + assert_identical(v.var(), Variable([], 2.0 / 3)) + assert_identical(v.median(), Variable([], 2)) + + v = Variable("x", [True, False, False]) + assert_identical(v.any(), Variable([], True)) + assert_identical(v.all(dim="x"), Variable([], False)) + + v = Variable("t", pd.date_range("2000-01-01", periods=3)) + assert v.argmax(skipna=True, dim="t") == 2 + + assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03"))) + + def test_reduce_keepdims(self): + v = Variable(["x", "y"], self.d) + + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + ) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) + + v = Variable([], 1.0) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) + + @requires_dask + def test_reduce_keepdims_dask(self): + import dask.array + + v = Variable(["x", "y"], self.d).chunk() + + actual = v.mean(keepdims=True) + assert isinstance(actual.data, dask.array.Array) + + expected = Variable(v.dims, np.mean(self.d, keepdims=True)) + assert_identical(actual, expected) + + actual = v.mean(dim="y", keepdims=True) + assert isinstance(actual.data, dask.array.Array) + + expected = Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)) + assert_identical(actual, expected) + + def test_reduce_keep_attrs(self): + _attrs = {"units": "test", "long_name": "testing"} + + v = Variable(["x", "y"], self.d, _attrs) + + # Test dropped attrs + vm = v.mean() + assert len(vm.attrs) == 0 + assert vm.attrs == {} + + # Test kept attrs + vm = v.mean(keep_attrs=True) + assert len(vm.attrs) == len(_attrs) + assert vm.attrs == _attrs + + def test_binary_ops_keep_attrs(self): + _attrs = {"units": "test", "long_name": "testing"} + a = Variable(["x", "y"], np.random.randn(3, 3), _attrs) + b = Variable(["x", "y"], np.random.randn(3, 3), _attrs) + # Test dropped attrs + d = a - b # just one operation + assert d.attrs == {} + # Test kept attrs + with set_options(keep_attrs=True): + d = a - b + assert d.attrs == _attrs + + def test_count(self): + expected = Variable([], 3) + actual = Variable(["x"], [1, 2, 3, np.nan]).count() + assert_identical(expected, actual) + + v = Variable(["x"], np.array(["1", "2", "3", np.nan], dtype=object)) + actual = v.count() + assert_identical(expected, actual) + + actual = Variable(["x"], [True, False, True]).count() + assert_identical(expected, actual) + assert actual.dtype == int + + expected = Variable(["x"], [2, 3]) + actual = Variable(["x", "y"], [[1, 0, np.nan], [1, 1, 1]]).count("y") + assert_identical(expected, actual) + + def test_setitem(self): + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) + v[0, 1] = 1 + assert v[0, 1] == 1 + + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) + v[dict(x=[0, 1])] = 1 + assert_array_equal(v[[0, 1]], np.ones_like(v[[0, 1]])) + + # boolean indexing + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) + v[dict(x=[True, False])] = 1 + + assert_array_equal(v[0], np.ones_like(v[0])) + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) + v[dict(x=[True, False], y=[False, True, False])] = 1 + assert v[0, 1] == 1 + + def test_setitem_fancy(self): + # assignment which should work as np.ndarray does + def assert_assigned_2d(array, key_x, key_y, values): + expected = array.copy() + expected[key_x, key_y] = values + v = Variable(["x", "y"], array) + v[dict(x=key_x, y=key_y)] = values + assert_array_equal(expected, v) + + # 1d vectorized indexing + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a"], [0, 1]), + key_y=Variable(["a"], [0, 1]), + values=0, + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a"], [0, 1]), + key_y=Variable(["a"], [0, 1]), + values=Variable((), 0), + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a"], [0, 1]), + key_y=Variable(["a"], [0, 1]), + values=Variable(("a"), [3, 2]), + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=slice(None), + key_y=Variable(["a"], [0, 1]), + values=Variable(("a"), [3, 2]), + ) + + # 2d-vectorized indexing + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a", "b"], [[0, 1]]), + key_y=Variable(["a", "b"], [[1, 0]]), + values=0, + ) + assert_assigned_2d( + np.random.randn(4, 3), + key_x=Variable(["a", "b"], [[0, 1]]), + key_y=Variable(["a", "b"], [[1, 0]]), + values=[0], + ) + assert_assigned_2d( + np.random.randn(5, 4), + key_x=Variable(["a", "b"], [[0, 1], [2, 3]]), + key_y=Variable(["a", "b"], [[1, 0], [3, 3]]), + values=[2, 3], + ) + + # vindex with slice + v = Variable(["x", "y", "z"], np.ones((4, 3, 2))) + ind = Variable(["a"], [0, 1]) + v[dict(x=ind, z=ind)] = 0 + expected = Variable(["x", "y", "z"], np.ones((4, 3, 2))) + expected[0, :, 0] = 0 + expected[1, :, 1] = 0 + assert_identical(expected, v) + + # dimension broadcast + v = Variable(["x", "y"], np.ones((3, 2))) + ind = Variable(["a", "b"], [[0, 1]]) + v[ind, :] = 0 + expected = Variable(["x", "y"], [[0, 0], [0, 0], [1, 1]]) + assert_identical(expected, v) + + with pytest.raises(ValueError, match=r"shape mismatch"): + v[ind, ind] = np.zeros((1, 2, 1)) + + v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) + ind = Variable(["a"], [0, 1]) + v[dict(x=ind)] = Variable(["a", "y"], np.ones((2, 3), dtype=int) * 10) + assert_array_equal(v[0], np.ones_like(v[0]) * 10) + assert_array_equal(v[1], np.ones_like(v[1]) * 10) + assert v.dims == ("x", "y") # dimension should not change + + # increment + v = Variable(["x", "y"], np.arange(6).reshape(3, 2)) + ind = Variable(["a"], [0, 1]) + v[dict(x=ind)] += 1 + expected = Variable(["x", "y"], [[1, 2], [3, 4], [4, 5]]) + assert_identical(v, expected) + + ind = Variable(["a"], [0, 0]) + v[dict(x=ind)] += 1 + expected = Variable(["x", "y"], [[2, 3], [3, 4], [4, 5]]) + assert_identical(v, expected) + + def test_coarsen(self): + v = self.cls(["x"], [0, 1, 2, 3, 4]) + actual = v.coarsen({"x": 2}, boundary="pad", func="mean") + expected = self.cls(["x"], [0.5, 2.5, 4]) + assert_identical(actual, expected) + + actual = v.coarsen({"x": 2}, func="mean", boundary="pad", side="right") + expected = self.cls(["x"], [0, 1.5, 3.5]) + assert_identical(actual, expected) + + actual = v.coarsen({"x": 2}, func=np.mean, side="right", boundary="trim") + expected = self.cls(["x"], [1.5, 3.5]) + assert_identical(actual, expected) + + # working test + v = self.cls(["x", "y", "z"], np.arange(40 * 30 * 2).reshape(40, 30, 2)) + for windows, func, side, boundary in [ + ({"x": 2}, np.mean, "left", "trim"), + ({"x": 2}, np.median, {"x": "left"}, "pad"), + ({"x": 2, "y": 3}, np.max, "left", {"x": "pad", "y": "trim"}), + ]: + v.coarsen(windows, func, boundary, side) + + def test_coarsen_2d(self): + # 2d-mean should be the same with the successive 1d-mean + v = self.cls(["x", "y"], np.arange(6 * 12).reshape(6, 12)) + actual = v.coarsen({"x": 3, "y": 4}, func="mean") + expected = v.coarsen({"x": 3}, func="mean").coarsen({"y": 4}, func="mean") + assert_equal(actual, expected) + + v = self.cls(["x", "y"], np.arange(7 * 12).reshape(7, 12)) + actual = v.coarsen({"x": 3, "y": 4}, func="mean", boundary="trim") + expected = v.coarsen({"x": 3}, func="mean", boundary="trim").coarsen( + {"y": 4}, func="mean", boundary="trim" + ) + assert_equal(actual, expected) + + # if there is nan, the two should be different + v = self.cls(["x", "y"], 1.0 * np.arange(6 * 12).reshape(6, 12)) + v[2, 4] = np.nan + v[3, 5] = np.nan + actual = v.coarsen({"x": 3, "y": 4}, func="mean", boundary="trim") + expected = ( + v.coarsen({"x": 3}, func="sum", boundary="trim").coarsen( + {"y": 4}, func="sum", boundary="trim" + ) + / 12 + ) + assert not actual.equals(expected) + # adjusting the nan count + expected[0, 1] *= 12 / 11 + expected[1, 1] *= 12 / 11 + assert_allclose(actual, expected) + + v = self.cls(("x", "y"), np.arange(4 * 4, dtype=np.float32).reshape(4, 4)) + actual = v.coarsen(dict(x=2, y=2), func="count", boundary="exact") + expected = self.cls(("x", "y"), 4 * np.ones((2, 2))) + assert_equal(actual, expected) + + v[0, 0] = np.nan + v[-1, -1] = np.nan + expected[0, 0] = 3 + expected[-1, -1] = 3 + actual = v.coarsen(dict(x=2, y=2), func="count", boundary="exact") + assert_equal(actual, expected) + + actual = v.coarsen(dict(x=2, y=2), func="sum", boundary="exact", skipna=False) + expected = self.cls(("x", "y"), [[np.nan, 18], [42, np.nan]]) + assert_equal(actual, expected) + + actual = v.coarsen(dict(x=2, y=2), func="sum", boundary="exact", skipna=True) + expected = self.cls(("x", "y"), [[10, 18], [42, 35]]) + assert_equal(actual, expected) + + # perhaps @pytest.mark.parametrize("operation", [f for f in duck_array_ops]) + def test_coarsen_keep_attrs(self, operation="mean"): + _attrs = {"units": "test", "long_name": "testing"} + + test_func = getattr(duck_array_ops, operation, None) + + # Test dropped attrs + with set_options(keep_attrs=False): + new = Variable(["coord"], np.linspace(1, 10, 100), attrs=_attrs).coarsen( + windows={"coord": 1}, func=test_func, boundary="exact", side="left" + ) + assert new.attrs == {} + + # Test kept attrs + with set_options(keep_attrs=True): + new = Variable(["coord"], np.linspace(1, 10, 100), attrs=_attrs).coarsen( + windows={"coord": 1}, + func=test_func, + boundary="exact", + side="left", + ) + assert new.attrs == _attrs + + +@requires_dask +class TestVariableWithDask(VariableSubclassobjects): + def cls(self, *args, **kwargs) -> Variable: + return Variable(*args, **kwargs).chunk() + + def test_chunk(self): + unblocked = Variable(["dim_0", "dim_1"], np.ones((3, 4))) + assert unblocked.chunks is None + + blocked = unblocked.chunk() + assert blocked.chunks == ((3,), (4,)) + first_dask_name = blocked.data.name + + blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) + assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name + + blocked = unblocked.chunk(chunks=(3, 3)) + assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + # name doesn't change when rechunking by same amount + # this fails if ReprObject doesn't have __dask_tokenize__ defined + assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name + + assert blocked.load().chunks is None + + # Check that kwargs are passed + import dask.array as da + + blocked = unblocked.chunk(name="testname_") + assert isinstance(blocked.data, da.Array) + assert "testname_" in blocked.data.name + + # test kwargs form of chunks + blocked = unblocked.chunk(dim_0=3, dim_1=3) + assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + @pytest.mark.xfail + def test_0d_object_array_with_list(self): + super().test_0d_object_array_with_list() + + @pytest.mark.xfail + def test_array_interface(self): + # dask array does not have `argsort` + super().test_array_interface() + + @pytest.mark.xfail + def test_copy_index(self): + super().test_copy_index() + + @pytest.mark.xfail + @pytest.mark.filterwarnings("ignore:elementwise comparison failed.*:FutureWarning") + def test_eq_all_dtypes(self): + super().test_eq_all_dtypes() + + def test_getitem_fancy(self): + super().test_getitem_fancy() + + def test_getitem_1d_fancy(self): + super().test_getitem_1d_fancy() + + def test_getitem_with_mask_nd_indexer(self): + import dask.array as da + + v = Variable(["x"], da.arange(3, chunks=3)) + indexer = Variable(("x", "y"), [[0, -1], [-1, 2]]) + assert_identical( + v._getitem_with_mask(indexer, fill_value=-1), + self.cls(("x", "y"), [[0, -1], [-1, 2]]), + ) + + @pytest.mark.parametrize("dim", ["x", "y"]) + @pytest.mark.parametrize("window", [3, 8, 11]) + @pytest.mark.parametrize("center", [True, False]) + def test_dask_rolling(self, dim, window, center): + import dask + import dask.array as da + + dask.config.set(scheduler="single-threaded") + + x = Variable(("x", "y"), np.array(np.random.randn(100, 40), dtype=float)) + dx = Variable(("x", "y"), da.from_array(x, chunks=[(6, 30, 30, 20, 14), 8])) + + expected = x.rolling_window( + dim, window, "window", center=center, fill_value=np.nan + ) + with raise_if_dask_computes(): + actual = dx.rolling_window( + dim, window, "window", center=center, fill_value=np.nan + ) + assert isinstance(actual.data, da.Array) + assert actual.shape == expected.shape + assert_equal(actual, expected) + + def test_multiindex(self): + super().test_multiindex() + + @pytest.mark.parametrize( + "mode", + [ + "mean", + pytest.param( + "median", + marks=pytest.mark.xfail(reason="median is not implemented by Dask"), + ), + pytest.param( + "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") + ), + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.filterwarnings( + r"ignore:dask.array.pad.+? converts integers to floats." + ) + def test_pad(self, mode, xr_arg, np_arg): + super().test_pad(mode, xr_arg, np_arg) + + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): + self.cls("x", data) + + +@requires_sparse +class TestVariableWithSparse: + # TODO inherit VariableSubclassobjects to cover more tests + + def test_as_sparse(self): + data = np.arange(12).reshape(3, 4) + var = Variable(("x", "y"), data)._as_sparse(fill_value=-1) + actual = var._to_dense() + assert_identical(var, actual) + + +class TestIndexVariable(VariableSubclassobjects): + def cls(self, *args, **kwargs) -> IndexVariable: + return IndexVariable(*args, **kwargs) + + def test_init(self): + with pytest.raises(ValueError, match=r"must be 1-dimensional"): + IndexVariable((), 0) + + def test_to_index(self): + data = 0.5 * np.arange(10) + v = IndexVariable(["time"], data, {"foo": "bar"}) + assert pd.Index(data, name="time").identical(v.to_index()) + + def test_to_index_multiindex_level(self): + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + ds = Dataset(coords={"x": midx}) + assert ds.one.variable.to_index().equals(midx.get_level_values("one")) + + def test_multiindex_default_level_names(self): + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]]) + v = IndexVariable(["x"], midx, {"foo": "bar"}) + assert v.to_index().names == ("x_level_0", "x_level_1") + + def test_data(self): + x = IndexVariable("x", np.arange(3.0)) + assert isinstance(x._data, PandasIndexingAdapter) + assert isinstance(x.data, np.ndarray) + assert float == x.dtype + assert_array_equal(np.arange(3), x) + assert float == x.values.dtype + with pytest.raises(TypeError, match=r"cannot be modified"): + x[:] = 0 + + def test_name(self): + coord = IndexVariable("x", [10.0]) + assert coord.name == "x" + + with pytest.raises(AttributeError): + coord.name = "y" + + def test_level_names(self): + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["level_1", "level_2"] + ) + x = IndexVariable("x", midx) + assert x.level_names == midx.names + + assert IndexVariable("y", [10.0]).level_names is None + + def test_get_level_variable(self): + midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["level_1", "level_2"] + ) + x = IndexVariable("x", midx) + level_1 = IndexVariable("x", midx.get_level_values("level_1")) + assert_identical(x.get_level_variable("level_1"), level_1) + + with pytest.raises(ValueError, match=r"has no MultiIndex"): + IndexVariable("y", [10.0]).get_level_variable("level") + + def test_concat_periods(self): + periods = pd.period_range("2000-01-01", periods=10) + coords = [IndexVariable("t", periods[:5]), IndexVariable("t", periods[5:])] + expected = IndexVariable("t", periods) + actual = IndexVariable.concat(coords, dim="t") + assert_identical(actual, expected) + assert isinstance(actual.to_index(), pd.PeriodIndex) + + positions = [list(range(5)), list(range(5, 10))] + actual = IndexVariable.concat(coords, dim="t", positions=positions) + assert_identical(actual, expected) + assert isinstance(actual.to_index(), pd.PeriodIndex) + + def test_concat_multiindex(self): + idx = pd.MultiIndex.from_product([[0, 1, 2], ["a", "b"]]) + coords = [IndexVariable("x", idx[:2]), IndexVariable("x", idx[2:])] + expected = IndexVariable("x", idx) + actual = IndexVariable.concat(coords, dim="x") + assert_identical(actual, expected) + assert isinstance(actual.to_index(), pd.MultiIndex) + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_concat_str_dtype(self, dtype): + a = IndexVariable("x", np.array(["a"], dtype=dtype)) + b = IndexVariable("x", np.array(["b"], dtype=dtype)) + expected = IndexVariable("x", np.array(["a", "b"], dtype=dtype)) + + actual = IndexVariable.concat([a, b]) + assert actual.identical(expected) + assert np.issubdtype(actual.dtype, dtype) + + def test_datetime64(self): + # GH:1932 Make sure indexing keeps precision + t = np.array([1518418799999986560, 1518418799999996560], dtype="datetime64[ns]") + v = IndexVariable("t", t) + assert v[0].data == t[0] + + # These tests make use of multi-dimensional variables, which are not valid + # IndexVariable objects: + @pytest.mark.skip + def test_getitem_error(self): + super().test_getitem_error() + + @pytest.mark.skip + def test_getitem_advanced(self): + super().test_getitem_advanced() + + @pytest.mark.skip + def test_getitem_fancy(self): + super().test_getitem_fancy() + + @pytest.mark.skip + def test_getitem_uint(self): + super().test_getitem_fancy() + + @pytest.mark.skip + @pytest.mark.parametrize( + "mode", + [ + "mean", + "median", + "reflect", + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + def test_pad(self, mode, xr_arg, np_arg): + super().test_pad(mode, xr_arg, np_arg) + + @pytest.mark.skip + def test_pad_constant_values(self, xr_arg, np_arg): + super().test_pad_constant_values(xr_arg, np_arg) + + @pytest.mark.skip + def test_rolling_window(self): + super().test_rolling_window() + + @pytest.mark.skip + def test_rolling_1d(self): + super().test_rolling_1d() + + @pytest.mark.skip + def test_nd_rolling(self): + super().test_nd_rolling() + + @pytest.mark.skip + def test_rolling_window_errors(self): + super().test_rolling_window_errors() + + @pytest.mark.skip + def test_coarsen_2d(self): + super().test_coarsen_2d() + + def test_to_index_variable_copy(self) -> None: + # to_index_variable should return a copy + # https://github.com/pydata/xarray/issues/6931 + a = IndexVariable("x", ["a"]) + b = a.to_index_variable() + assert a is not b + b.dims = ("y",) + assert a.dims == ("x",) + + +class TestAsCompatibleData(Generic[T_DuckArray]): + def test_unchanged_types(self): + types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray) + for t in types: + for data in [ + np.arange(3), + pd.date_range("2000-01-01", periods=3), + pd.date_range("2000-01-01", periods=3).values, + ]: + x = t(data) + assert source_ndarray(x) is source_ndarray(as_compatible_data(x)) + + def test_converted_types(self): + for input_array in [[[0, 1, 2]], pd.DataFrame([[0, 1, 2]])]: + actual = as_compatible_data(input_array) + assert_array_equal(np.asarray(input_array), actual) + assert np.ndarray == type(actual) + assert np.asarray(input_array).dtype == actual.dtype + + def test_masked_array(self): + original = np.ma.MaskedArray(np.arange(5)) + expected = np.arange(5) + actual = as_compatible_data(original) + assert_array_equal(expected, actual) + assert np.dtype(int) == actual.dtype + + original = np.ma.MaskedArray(np.arange(5), mask=4 * [False] + [True]) + expected = np.arange(5.0) + expected[-1] = np.nan + actual = as_compatible_data(original) + assert_array_equal(expected, actual) + assert np.dtype(float) == actual.dtype + + original = np.ma.MaskedArray([1.0, 2.0], mask=[True, False]) + original.flags.writeable = False + expected = [np.nan, 2.0] + actual = as_compatible_data(original) + assert_array_equal(expected, actual) + assert np.dtype(float) == actual.dtype + + # GH2377 + actual = Variable(dims=tuple(), data=np.ma.masked) + expected = Variable(dims=tuple(), data=np.nan) + assert_array_equal(expected, actual) + assert actual.dtype == expected.dtype + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") + def test_datetime(self): + expected = np.datetime64("2000-01-01") + actual = as_compatible_data(expected) + assert expected == actual + assert np.ndarray == type(actual) + assert np.dtype("datetime64[ns]") == actual.dtype + + expected = np.array([np.datetime64("2000-01-01")]) + actual = as_compatible_data(expected) + assert np.asarray(expected) == actual + assert np.ndarray == type(actual) + assert np.dtype("datetime64[ns]") == actual.dtype + + expected = np.array([np.datetime64("2000-01-01", "ns")]) + actual = as_compatible_data(expected) + assert np.asarray(expected) == actual + assert np.ndarray == type(actual) + assert np.dtype("datetime64[ns]") == actual.dtype + assert expected is source_ndarray(np.asarray(actual)) + + expected = np.datetime64("2000-01-01", "ns") + actual = as_compatible_data(datetime(2000, 1, 1)) + assert np.asarray(expected) == actual + assert np.ndarray == type(actual) + assert np.dtype("datetime64[ns]") == actual.dtype + + def test_tz_datetime(self) -> None: + tz = pytz.timezone("America/New_York") + times_ns = pd.date_range("2000", periods=1, tz=tz) + + times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + actual: T_DuckArray = as_compatible_data(times_s) + assert actual.array == times_s + assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz) + + series = pd.Series(times_s) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + actual2: T_DuckArray = as_compatible_data(series) + + np.testing.assert_array_equal(actual2, series.values) + assert actual2.dtype == np.dtype("datetime64[ns]") + + def test_full_like(self) -> None: + # For more thorough tests, see test_variable.py + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ) + + expect = orig.copy(deep=True) + expect.values = [[2.0, 2.0], [2.0, 2.0]] + assert_identical(expect, full_like(orig, 2)) + + # override dtype + expect.values = [[True, True], [True, True]] + assert expect.dtype == bool + assert_identical(expect, full_like(orig, True, dtype=bool)) + + # raise error on non-scalar fill_value + with pytest.raises(ValueError, match=r"must be scalar"): + full_like(orig, [1.0, 2.0]) + + with pytest.raises(ValueError, match="'dtype' cannot be dict-like"): + full_like(orig, True, dtype={"x": bool}) + + @requires_dask + def test_full_like_dask(self) -> None: + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ).chunk(dict(x=(1, 1), y=(2,))) + + def check(actual, expect_dtype, expect_values): + assert actual.dtype == expect_dtype + assert actual.shape == orig.shape + assert actual.dims == orig.dims + assert actual.attrs == orig.attrs + assert actual.chunks == orig.chunks + assert_array_equal(actual.values, expect_values) + + check(full_like(orig, 2), orig.dtype, np.full_like(orig.values, 2)) + # override dtype + check( + full_like(orig, True, dtype=bool), + bool, + np.full_like(orig.values, True, dtype=bool), + ) + + # Check that there's no array stored inside dask + # (e.g. we didn't create a numpy array and then we chunked it!) + dsk = full_like(orig, 1).data.dask + for v in dsk.values(): + if isinstance(v, tuple): + for vi in v: + assert not isinstance(vi, np.ndarray) + else: + assert not isinstance(v, np.ndarray) + + def test_zeros_like(self) -> None: + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ) + assert_identical(zeros_like(orig), full_like(orig, 0)) + assert_identical(zeros_like(orig, dtype=int), full_like(orig, 0, dtype=int)) + + def test_ones_like(self) -> None: + orig = Variable( + dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} + ) + assert_identical(ones_like(orig), full_like(orig, 1)) + assert_identical(ones_like(orig, dtype=int), full_like(orig, 1, dtype=int)) + + def test_unsupported_type(self): + # Non indexable type + class CustomArray(NDArrayMixin): + def __init__(self, array): + self.array = array + + class CustomIndexable(CustomArray, indexing.ExplicitlyIndexed): + pass + + # Type with data stored in values attribute + class CustomWithValuesAttr: + def __init__(self, array): + self.values = array + + array = CustomArray(np.arange(3)) + orig = Variable(dims=("x"), data=array, attrs={"foo": "bar"}) + assert isinstance(orig._data, np.ndarray) # should not be CustomArray + + array = CustomIndexable(np.arange(3)) + orig = Variable(dims=("x"), data=array, attrs={"foo": "bar"}) + assert isinstance(orig._data, CustomIndexable) + + array = CustomWithValuesAttr(np.arange(3)) + orig = Variable(dims=(), data=array) + assert isinstance(orig._data.item(), CustomWithValuesAttr) + + +def test_raise_no_warning_for_nan_in_binary_ops(): + with assert_no_warnings(): + Variable("x", [1, 2, np.nan]) > 0 + + +class TestBackendIndexing: + """Make sure all the array wrappers can be indexed.""" + + @pytest.fixture(autouse=True) + def setUp(self): + self.d = np.random.random((10, 3)).astype(np.float64) + + def check_orthogonal_indexing(self, v): + assert np.allclose(v.isel(x=[8, 3], y=[2, 1]), self.d[[8, 3]][:, [2, 1]]) + + def check_vectorized_indexing(self, v): + ind_x = Variable("z", [0, 2]) + ind_y = Variable("z", [2, 1]) + assert np.allclose(v.isel(x=ind_x, y=ind_y), self.d[ind_x, ind_y]) + + def test_NumpyIndexingAdapter(self): + v = Variable(dims=("x", "y"), data=NumpyIndexingAdapter(self.d)) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + # could not doubly wrapping + with pytest.raises(TypeError, match=r"NumpyIndexingAdapter only wraps "): + v = Variable( + dims=("x", "y"), data=NumpyIndexingAdapter(NumpyIndexingAdapter(self.d)) + ) + + def test_LazilyIndexedArray(self): + v = Variable(dims=("x", "y"), data=LazilyIndexedArray(self.d)) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + # doubly wrapping + v = Variable( + dims=("x", "y"), + data=LazilyIndexedArray(LazilyIndexedArray(self.d)), + ) + self.check_orthogonal_indexing(v) + # hierarchical wrapping + v = Variable( + dims=("x", "y"), data=LazilyIndexedArray(NumpyIndexingAdapter(self.d)) + ) + self.check_orthogonal_indexing(v) + + def test_CopyOnWriteArray(self): + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(self.d)) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + # doubly wrapping + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(LazilyIndexedArray(self.d))) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + + def test_MemoryCachedArray(self): + v = Variable(dims=("x", "y"), data=MemoryCachedArray(self.d)) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + # doubly wrapping + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(MemoryCachedArray(self.d))) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + + @requires_dask + def test_DaskIndexingAdapter(self): + import dask.array as da + + da = da.asarray(self.d) + v = Variable(dims=("x", "y"), data=DaskIndexingAdapter(da)) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + # doubly wrapping + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(DaskIndexingAdapter(da))) + self.check_orthogonal_indexing(v) + self.check_vectorized_indexing(v) + + +def test_clip(var): + # Copied from test_dataarray (would there be a way to combine the tests?) + result = var.clip(min=0.5) + assert result.min(...) >= 0.5 + + result = var.clip(max=0.5) + assert result.max(...) <= 0.5 + + result = var.clip(min=0.25, max=0.75) + assert result.min(...) >= 0.25 + assert result.max(...) <= 0.75 + + result = var.clip(min=var.mean("x"), max=var.mean("z")) + assert result.dims == var.dims + assert_array_equal( + result.data, + np.clip( + var.data, + var.mean("x").data[np.newaxis, :, :], + var.mean("z").data[:, :, np.newaxis], + ), + ) + + +@pytest.mark.parametrize("Var", [Variable, IndexVariable]) +class TestNumpyCoercion: + def test_from_numpy(self, Var): + v = Var("x", [1, 2, 3]) + + assert_identical(v.as_numpy(), v) + np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3])) + + @requires_dask + def test_from_dask(self, Var): + v = Var("x", [1, 2, 3]) + v_chunked = v.chunk(1) + + assert_identical(v_chunked.as_numpy(), v.compute()) + np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3])) + + @requires_pint + def test_from_pint(self, Var): + import pint + + arr = np.array([1, 2, 3]) + + # IndexVariable strips the unit + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=pint.UnitStrippedWarning) + v = Var("x", pint.Quantity(arr, units="m")) + + assert_identical(v.as_numpy(), Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_sparse + def test_from_sparse(self, Var): + if Var is IndexVariable: + pytest.skip("Can't have 2D IndexVariables") + + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO(coords=[[0, 1, 2], [0, 1, 2]], data=[1, 2, 3]) + v = Variable(["x", "y"], sparr) + + assert_identical(v.as_numpy(), Variable(["x", "y"], arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_cupy + def test_from_cupy(self, Var): + if Var is IndexVariable: + pytest.skip("cupy in default indexes is not supported at the moment") + import cupy as cp + + arr = np.array([1, 2, 3]) + v = Var("x", cp.array(arr)) + + assert_identical(v.as_numpy(), Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_dask + @requires_pint + def test_from_pint_wrapping_dask(self, Var): + import dask + import pint + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(np.array([1, 2, 3])) + + # IndexVariable strips the unit + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=pint.UnitStrippedWarning) + v = Var("x", pint.Quantity(d, units="m")) + + result = v.as_numpy() + assert_identical(result, Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + +@pytest.mark.parametrize( + ("values", "warns"), + [ + (np.datetime64("2000-01-01", "ns"), False), + (np.datetime64("2000-01-01", "s"), True), + (np.array([np.datetime64("2000-01-01", "ns")]), False), + (np.array([np.datetime64("2000-01-01", "s")]), True), + (pd.date_range("2000", periods=1), False), + (datetime(2000, 1, 1), has_pandas_3), + (np.array([datetime(2000, 1, 1)]), has_pandas_3), + (pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), False), + ( + pd.Series( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) + ), + False, + ), + ], + ids=lambda x: f"{x}", +) +def test_datetime_conversion_warning(values, warns) -> None: + dims = ["time"] if isinstance(values, (np.ndarray, pd.Index, pd.Series)) else [] + if warns: + with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): + var = Variable(dims, values) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error") + var = Variable(dims, values) + + if var.dtype.kind == "M": + assert var.dtype == np.dtype("datetime64[ns]") + else: + # The only case where a non-datetime64 dtype can occur currently is in + # the case that the variable is backed by a timezone-aware + # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. + assert isinstance(var._data, PandasIndexingAdapter) + assert var._data.array.dtype == pd.DatetimeTZDtype( + "ns", pytz.timezone("America/New_York") + ) + + +def test_pandas_two_only_datetime_conversion_warnings() -> None: + # Note these tests rely on pandas features that are only present in pandas + # 2.0.0 and above, and so for now cannot be parametrized. + cases = [ + (pd.date_range("2000", periods=1), "datetime64[s]"), + (pd.Series(pd.date_range("2000", periods=1)), "datetime64[s]"), + ( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), + pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), + ), + ( + pd.Series( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) + ), + pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), + ), + ] + for data, dtype in cases: + with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): + var = Variable(["time"], data.astype(dtype)) + + if var.dtype.kind == "M": + assert var.dtype == np.dtype("datetime64[ns]") + else: + # The only case where a non-datetime64 dtype can occur currently is in + # the case that the variable is backed by a timezone-aware + # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. + assert isinstance(var._data, PandasIndexingAdapter) + assert var._data.array.dtype == pd.DatetimeTZDtype( + "ns", pytz.timezone("America/New_York") + ) + + +@pytest.mark.parametrize( + ("values", "warns"), + [ + (np.timedelta64(10, "ns"), False), + (np.timedelta64(10, "s"), True), + (np.array([np.timedelta64(10, "ns")]), False), + (np.array([np.timedelta64(10, "s")]), True), + (pd.timedelta_range("1", periods=1), False), + (timedelta(days=1), False), + (np.array([timedelta(days=1)]), False), + ], + ids=lambda x: f"{x}", +) +def test_timedelta_conversion_warning(values, warns) -> None: + dims = ["time"] if isinstance(values, (np.ndarray, pd.Index)) else [] + if warns: + with pytest.warns(UserWarning, match="non-nanosecond precision timedelta"): + var = Variable(dims, values) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error") + var = Variable(dims, values) + + assert var.dtype == np.dtype("timedelta64[ns]") + + +def test_pandas_two_only_timedelta_conversion_warning() -> None: + # Note this test relies on a pandas feature that is only present in pandas + # 2.0.0 and above, and so for now cannot be parametrized. + data = pd.timedelta_range("1", periods=1).astype("timedelta64[s]") + with pytest.warns(UserWarning, match="non-nanosecond precision timedelta"): + var = Variable(["time"], data) + + assert var.dtype == np.dtype("timedelta64[ns]") + + +@pytest.mark.parametrize( + ("index", "dtype"), + [ + (pd.date_range("2000", periods=1), "datetime64"), + (pd.timedelta_range("1", periods=1), "timedelta64"), + ], + ids=lambda x: f"{x}", +) +def test_pandas_indexing_adapter_non_nanosecond_conversion(index, dtype) -> None: + data = PandasIndexingAdapter(index.astype(f"{dtype}[s]")) + with pytest.warns(UserWarning, match="non-nanosecond precision"): + var = Variable(["time"], data) + assert var.dtype == np.dtype(f"{dtype}[ns]") diff --git a/test/fixtures/whole_applications/xarray/xarray/tests/test_weighted.py b/test/fixtures/whole_applications/xarray/xarray/tests/test_weighted.py new file mode 100644 index 0000000..f3337d7 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tests/test_weighted.py @@ -0,0 +1,793 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +import xarray as xr +from xarray import DataArray, Dataset +from xarray.tests import ( + assert_allclose, + assert_equal, + raise_if_dask_computes, + requires_cftime, + requires_dask, +) + + +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_non_DataArray_weights(as_dataset: bool) -> None: + data: DataArray | Dataset = DataArray([1, 2]) + if as_dataset: + data = data.to_dataset(name="data") + + with pytest.raises(ValueError, match=r"`weights` must be a DataArray"): + data.weighted([1, 2]) # type: ignore + + +@pytest.mark.parametrize("as_dataset", (True, False)) +@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) +def test_weighted_weights_nan_raises(as_dataset: bool, weights: list[float]) -> None: + data: DataArray | Dataset = DataArray([1, 2]) + if as_dataset: + data = data.to_dataset(name="data") + + with pytest.raises(ValueError, match="`weights` cannot contain missing values."): + data.weighted(DataArray(weights)) + + +@requires_dask +@pytest.mark.parametrize("as_dataset", (True, False)) +@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) +def test_weighted_weights_nan_raises_dask(as_dataset, weights): + data = DataArray([1, 2]).chunk({"dim_0": -1}) + if as_dataset: + data = data.to_dataset(name="data") + + weights = DataArray(weights).chunk({"dim_0": -1}) + + with raise_if_dask_computes(): + weighted = data.weighted(weights) + + with pytest.raises(ValueError, match="`weights` cannot contain missing values."): + weighted.sum().load() + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize("time_chunks", (1, 5)) +@pytest.mark.parametrize("resample_spec", ("1YS", "5YS", "10YS")) +def test_weighted_lazy_resample(time_chunks, resample_spec): + # https://github.com/pydata/xarray/issues/4625 + + # simple customized weighted mean function + def mean_func(ds): + return ds.weighted(ds.weights).mean("time") + + # example dataset + t = xr.cftime_range(start="2000", periods=20, freq="1YS") + weights = xr.DataArray(np.random.rand(len(t)), dims=["time"], coords={"time": t}) + data = xr.DataArray( + np.random.rand(len(t)), dims=["time"], coords={"time": t, "weights": weights} + ) + ds = xr.Dataset({"data": data}).chunk({"time": time_chunks}) + + with raise_if_dask_computes(): + ds.resample(time=resample_spec).map(mean_func) + + +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)), +) +def test_weighted_sum_of_weights_no_nan(weights, expected): + da = DataArray([1, 2]) + weights = DataArray(weights) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)), +) +def test_weighted_sum_of_weights_nan(weights, expected): + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(expected) + + assert_equal(expected, result) + + +def test_weighted_sum_of_weights_bool(): + # https://github.com/pydata/xarray/issues/4074 + + da = DataArray([1, 2]) + weights = DataArray([True, True]) + result = da.weighted(weights).sum_of_weights() + + expected = DataArray(2) + + assert_equal(expected, result) + + +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("factor", [0, 1, 3.14]) +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_sum_equal_weights(da, factor, skipna): + # if all weights are 'f'; weighted sum is f times the ordinary sum + + da = DataArray(da) + weights = xr.full_like(da, factor) + + expected = da.sum(skipna=skipna) * factor + result = da.weighted(weights).sum(skipna=skipna) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0)) +) +def test_weighted_sum_no_nan(weights, expected): + da = DataArray([1, 2]) + + weights = DataArray(weights) + result = da.weighted(weights).sum() + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0)) +) +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_sum_nan(weights, expected, skipna): + da = DataArray([np.nan, 2]) + + weights = DataArray(weights) + result = da.weighted(weights).sum(skipna=skipna) + + if skipna: + expected = DataArray(expected) + else: + expected = DataArray(np.nan) + + assert_equal(expected, result) + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan])) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize("factor", [1, 2, 3.14]) +def test_weighted_mean_equal_weights(da, skipna, factor): + # if all weights are equal (!= 0), should yield the same result as mean + + da = DataArray(da) + + # all weights as 1. + weights = xr.full_like(da, factor) + + expected = da.mean(skipna=skipna) + result = da.weighted(weights).mean(skipna=skipna) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan)) +) +def test_weighted_mean_no_nan(weights, expected): + da = DataArray([1, 2]) + weights = DataArray(weights) + expected = DataArray(expected) + + result = da.weighted(weights).mean() + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), + ( + ( + [0.25, 0.05, 0.15, 0.25, 0.15, 0.1, 0.05], + [1.554595, 2.463784, 3.000000, 3.518378], + ), + ( + [0.05, 0.05, 0.1, 0.15, 0.15, 0.25, 0.25], + [2.840000, 3.632973, 4.076216, 4.523243], + ), + ), +) +def test_weighted_quantile_no_nan(weights, expected): + # Expected values were calculated by running the reference implementation + # proposed in https://aakinshin.net/posts/weighted-quantiles/ + + da = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5]) + q = [0.2, 0.4, 0.6, 0.8] + weights = DataArray(weights) + + expected = DataArray(expected, coords={"quantile": q}) + result = da.weighted(weights).quantile(q) + + assert_allclose(expected, result) + + +def test_weighted_quantile_zero_weights(): + da = DataArray([0, 1, 2, 3]) + weights = DataArray([1, 0, 1, 0]) + q = 0.75 + + result = da.weighted(weights).quantile(q) + expected = DataArray([0, 2]).quantile(0.75) + + assert_allclose(expected, result) + + +def test_weighted_quantile_simple(): + # Check that weighted quantiles return the same value as numpy quantiles + da = DataArray([0, 1, 2, 3]) + w = DataArray([1, 0, 1, 0]) + + w_eps = DataArray([1, 0.0001, 1, 0.0001]) + q = 0.75 + + expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q}) # 1.5 + + assert_equal(expected, da.weighted(w).quantile(q)) + assert_allclose(expected, da.weighted(w_eps).quantile(q), rtol=0.001) + + +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_quantile_nan(skipna): + # Check skipna behavior + da = DataArray([0, 1, 2, 3, np.nan]) + w = DataArray([1, 0, 1, 0, 1]) + q = [0.5, 0.75] + + result = da.weighted(w).quantile(q, skipna=skipna) + + if skipna: + expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q}) + else: + expected = DataArray(np.full(len(q), np.nan), coords={"quantile": q}) + + assert_allclose(expected, result) + + +@pytest.mark.parametrize( + "da", + ( + pytest.param([1, 1.9, 2.2, 3, 3.7, 4.1, 5], id="nonan"), + pytest.param([1, 1.9, 2.2, 3, 3.7, 4.1, np.nan], id="singlenan"), + pytest.param( + [np.nan, np.nan, np.nan], + id="allnan", + marks=pytest.mark.filterwarnings( + "ignore:All-NaN slice encountered:RuntimeWarning" + ), + ), + ), +) +@pytest.mark.parametrize("q", (0.5, (0.2, 0.8))) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize("factor", [1, 3.14]) +def test_weighted_quantile_equal_weights( + da: list[float], q: float | tuple[float, ...], skipna: bool, factor: float +) -> None: + # if all weights are equal (!= 0), should yield the same result as quantile + + data = DataArray(da) + weights = xr.full_like(data, factor) + + expected = data.quantile(q, skipna=skipna) + result = data.weighted(weights).quantile(q, skipna=skipna) + + assert_allclose(expected, result) + + +@pytest.mark.skip(reason="`method` argument is not currently exposed") +@pytest.mark.parametrize( + "da", + ( + [1, 1.9, 2.2, 3, 3.7, 4.1, 5], + [1, 1.9, 2.2, 3, 3.7, 4.1, np.nan], + [np.nan, np.nan, np.nan], + ), +) +@pytest.mark.parametrize("q", (0.5, (0.2, 0.8))) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize( + "method", + [ + "linear", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "median_unbiased", + "normal_unbiased2", + ], +) +def test_weighted_quantile_equal_weights_all_methods(da, q, skipna, factor, method): + # If all weights are equal (!= 0), should yield the same result as numpy quantile + + da = DataArray(da) + weights = xr.full_like(da, 3.14) + + expected = da.quantile(q, skipna=skipna, method=method) + result = da.weighted(weights).quantile(q, skipna=skipna, method=method) + + assert_allclose(expected, result) + + +def test_weighted_quantile_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + q = 0.5 + + expected = DataArray([1], coords={"quantile": [q]}).squeeze() + result = da.weighted(weights).quantile(q) + + assert_equal(expected, result) + + +@pytest.mark.parametrize("q", (-1, 1.1, (0.5, 1.1), ((0.2, 0.4), (0.6, 0.8)))) +def test_weighted_quantile_with_invalid_q(q): + da = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5]) + q = np.asarray(q) + weights = xr.ones_like(da) + + if q.ndim <= 1: + with pytest.raises(ValueError, match="q values must be between 0 and 1"): + da.weighted(weights).quantile(q) + else: + with pytest.raises(ValueError, match="q must be a scalar or 1d"): + da.weighted(weights).quantile(q) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan)) +) +@pytest.mark.parametrize("skipna", (True, False)) +def test_weighted_mean_nan(weights, expected, skipna): + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + + if skipna: + expected = DataArray(expected) + else: + expected = DataArray(np.nan) + + result = da.weighted(weights).mean(skipna=skipna) + + assert_equal(expected, result) + + +def test_weighted_mean_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + expected = DataArray(1) + + result = da.weighted(weights).mean() + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 2 / 3), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)), +) +def test_weighted_sum_of_squares_no_nan(weights, expected): + da = DataArray([1, 2]) + weights = DataArray(weights) + result = da.weighted(weights).sum_of_squares() + + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), + (([1, 2], 0), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)), +) +def test_weighted_sum_of_squares_nan(weights, expected): + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + result = da.weighted(weights).sum_of_squares() + + expected = DataArray(expected) + + assert_equal(expected, result) + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan])) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize("factor", [1, 2, 3.14]) +def test_weighted_var_equal_weights(da, skipna, factor): + # if all weights are equal (!= 0), should yield the same result as var + + da = DataArray(da) + + # all weights as 1. + weights = xr.full_like(da, factor) + + expected = da.var(skipna=skipna) + result = da.weighted(weights).var(skipna=skipna) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 0.24), ([1, 0], 0.0), ([0, 0], np.nan)) +) +def test_weighted_var_no_nan(weights, expected): + da = DataArray([1, 2]) + weights = DataArray(weights) + expected = DataArray(expected) + + result = da.weighted(weights).var() + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan)) +) +def test_weighted_var_nan(weights, expected): + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + expected = DataArray(expected) + + result = da.weighted(weights).var() + + assert_equal(expected, result) + + +def test_weighted_var_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + expected = DataArray(0) + + result = da.weighted(weights).var() + + assert_equal(expected, result) + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan])) +@pytest.mark.parametrize("skipna", (True, False)) +@pytest.mark.parametrize("factor", [1, 2, 3.14]) +def test_weighted_std_equal_weights(da, skipna, factor): + # if all weights are equal (!= 0), should yield the same result as std + + da = DataArray(da) + + # all weights as 1. + weights = xr.full_like(da, factor) + + expected = da.std(skipna=skipna) + result = da.weighted(weights).std(skipna=skipna) + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], np.sqrt(0.24)), ([1, 0], 0.0), ([0, 0], np.nan)) +) +def test_weighted_std_no_nan(weights, expected): + da = DataArray([1, 2]) + weights = DataArray(weights) + expected = DataArray(expected) + + result = da.weighted(weights).std() + + assert_equal(expected, result) + + +@pytest.mark.parametrize( + ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan)) +) +def test_weighted_std_nan(weights, expected): + da = DataArray([np.nan, 2]) + weights = DataArray(weights) + expected = DataArray(expected) + + result = da.weighted(weights).std() + + assert_equal(expected, result) + + +def test_weighted_std_bool(): + # https://github.com/pydata/xarray/issues/4074 + da = DataArray([1, 1]) + weights = DataArray([True, True]) + expected = DataArray(0) + + result = da.weighted(weights).std() + + assert_equal(expected, result) + + +def expected_weighted(da, weights, dim, skipna, operation): + """ + Generate expected result using ``*`` and ``sum``. This is checked against + the result of da.weighted which uses ``dot`` + """ + + weighted_sum = (da * weights).sum(dim=dim, skipna=skipna) + + if operation == "sum": + return weighted_sum + + masked_weights = weights.where(da.notnull()) + sum_of_weights = masked_weights.sum(dim=dim, skipna=True) + valid_weights = sum_of_weights != 0 + sum_of_weights = sum_of_weights.where(valid_weights) + + if operation == "sum_of_weights": + return sum_of_weights + + weighted_mean = weighted_sum / sum_of_weights + + if operation == "mean": + return weighted_mean + + demeaned = da - weighted_mean + sum_of_squares = ((demeaned**2) * weights).sum(dim=dim, skipna=skipna) + + if operation == "sum_of_squares": + return sum_of_squares + + var = sum_of_squares / sum_of_weights + + if operation == "var": + return var + + if operation == "std": + return np.sqrt(var) + + +def check_weighted_operations(data, weights, dim, skipna): + # check sum of weights + result = data.weighted(weights).sum_of_weights(dim) + expected = expected_weighted(data, weights, dim, skipna, "sum_of_weights") + assert_allclose(expected, result) + + # check weighted sum + result = data.weighted(weights).sum(dim, skipna=skipna) + expected = expected_weighted(data, weights, dim, skipna, "sum") + assert_allclose(expected, result) + + # check weighted mean + result = data.weighted(weights).mean(dim, skipna=skipna) + expected = expected_weighted(data, weights, dim, skipna, "mean") + assert_allclose(expected, result) + + # check weighted sum of squares + result = data.weighted(weights).sum_of_squares(dim, skipna=skipna) + expected = expected_weighted(data, weights, dim, skipna, "sum_of_squares") + assert_allclose(expected, result) + + # check weighted var + result = data.weighted(weights).var(dim, skipna=skipna) + expected = expected_weighted(data, weights, dim, skipna, "var") + assert_allclose(expected, result) + + # check weighted std + result = data.weighted(weights).std(dim, skipna=skipna) + expected = expected_weighted(data, weights, dim, skipna, "std") + assert_allclose(expected, result) + + +@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None)) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +@pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt") +def test_weighted_operations_3D(dim, add_nans, skipna): + dims = ("a", "b", "c") + coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3]) + + weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords) + + data = np.random.randn(4, 4, 4) + + # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan + + data = DataArray(data, dims=dims, coords=coords) + + check_weighted_operations(data, weights, dim, skipna) + + data = data.to_dataset(name="data") + check_weighted_operations(data, weights, dim, skipna) + + +@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None)) +@pytest.mark.parametrize("q", (0.5, (0.1, 0.9), (0.2, 0.4, 0.6, 0.8))) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +def test_weighted_quantile_3D(dim, q, add_nans, skipna): + dims = ("a", "b", "c") + coords = dict(a=[0, 1, 2], b=[0, 1, 2, 3], c=[0, 1, 2, 3, 4]) + + data = np.arange(60).reshape(3, 4, 5).astype(float) + + # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan + + da = DataArray(data, dims=dims, coords=coords) + + # Weights are all ones, because we will compare against DataArray.quantile (non-weighted) + weights = xr.ones_like(da) + + result = da.weighted(weights).quantile(q, dim=dim, skipna=skipna) + expected = da.quantile(q, dim=dim, skipna=skipna) + + assert_allclose(expected, result) + + ds = da.to_dataset(name="data") + result2 = ds.weighted(weights).quantile(q, dim=dim, skipna=skipna) + + assert_allclose(expected, result2.data) + + +@pytest.mark.parametrize( + "coords_weights, coords_data, expected_value_at_weighted_quantile", + [ + ([0, 1, 2, 3], [1, 2, 3, 4], 2.5), # no weights for coord a == 4 + ([0, 1, 2, 3], [2, 3, 4, 5], 1.8), # no weights for coord a == 4 or 5 + ([2, 3, 4, 5], [0, 1, 2, 3], 3.8), # no weights for coord a == 0 or 1 + ], +) +def test_weighted_operations_nonequal_coords( + coords_weights: Iterable[Any], + coords_data: Iterable[Any], + expected_value_at_weighted_quantile: float, +) -> None: + """Check that weighted operations work with unequal coords. + + + Parameters + ---------- + coords_weights : Iterable[Any] + The coords for the weights. + coords_data : Iterable[Any] + The coords for the data. + expected_value_at_weighted_quantile : float + The expected value for the quantile of the weighted data. + """ + da_weights = DataArray( + [0.5, 1.0, 1.0, 2.0], dims=("a",), coords=dict(a=coords_weights) + ) + da_data = DataArray([1, 2, 3, 4], dims=("a",), coords=dict(a=coords_data)) + check_weighted_operations(da_data, da_weights, dim="a", skipna=None) + + quantile = 0.5 + da_actual = da_data.weighted(da_weights).quantile(quantile, dim="a") + da_expected = DataArray( + [expected_value_at_weighted_quantile], coords={"quantile": [quantile]} + ).squeeze() + assert_allclose(da_actual, da_expected) + + ds_data = da_data.to_dataset(name="data") + check_weighted_operations(ds_data, da_weights, dim="a", skipna=None) + + ds_actual = ds_data.weighted(da_weights).quantile(quantile, dim="a") + assert_allclose(ds_actual, da_expected.to_dataset(name="data")) + + +@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4))) +@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4))) +@pytest.mark.parametrize("add_nans", (True, False)) +@pytest.mark.parametrize("skipna", (None, True, False)) +@pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt") +def test_weighted_operations_different_shapes( + shape_data, shape_weights, add_nans, skipna +): + weights = DataArray(np.random.randn(*shape_weights)) + + data = np.random.randn(*shape_data) + + # add approximately 25 % NaNs + if add_nans: + c = int(data.size * 0.25) + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan + + data = DataArray(data) + + check_weighted_operations(data, weights, "dim_0", skipna) + check_weighted_operations(data, weights, None, skipna) + + data = data.to_dataset(name="data") + check_weighted_operations(data, weights, "dim_0", skipna) + check_weighted_operations(data, weights, None, skipna) + + +@pytest.mark.parametrize( + "operation", + ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"), +) +@pytest.mark.parametrize("as_dataset", (True, False)) +@pytest.mark.parametrize("keep_attrs", (True, False, None)) +def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs): + weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights")) + data = DataArray(np.random.randn(2, 2)) + + if as_dataset: + data = data.to_dataset(name="data") + + data.attrs = dict(attr="weights") + + kwargs = {"keep_attrs": keep_attrs} + if operation == "quantile": + kwargs["q"] = 0.5 + + result = getattr(data.weighted(weights), operation)(**kwargs) + + if operation == "sum_of_weights": + assert result.attrs == (weights.attrs if keep_attrs else {}) + assert result.attrs == (weights.attrs if keep_attrs else {}) + else: + assert result.attrs == (weights.attrs if keep_attrs else {}) + assert result.attrs == (data.attrs if keep_attrs else {}) + + +@pytest.mark.parametrize( + "operation", + ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"), +) +def test_weighted_operations_keep_attr_da_in_ds(operation): + # GH #3595 + + weights = DataArray(np.random.randn(2, 2)) + data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data")) + data = data.to_dataset(name="a") + + kwargs = {"keep_attrs": True} + if operation == "quantile": + kwargs["q"] = 0.5 + + result = getattr(data.weighted(weights), operation)(**kwargs) + + assert data.a.attrs == result.a.attrs + + +@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile")) +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_bad_dim(operation, as_dataset): + data = DataArray(np.random.randn(2, 2)) + weights = xr.ones_like(data) + if as_dataset: + data = data.to_dataset(name="data") + + kwargs = {"dim": "bad_dim"} + if operation == "quantile": + kwargs["q"] = 0.5 + + with pytest.raises( + ValueError, + match=( + f"Dimensions \\('bad_dim',\\) not found in {data.__class__.__name__}Weighted " + # the order of (dim_0, dim_1) varies + "dimensions \\(('dim_0', 'dim_1'|'dim_1', 'dim_0')\\)" + ), + ): + getattr(data.weighted(weights), operation)(**kwargs) diff --git a/test/fixtures/whole_applications/xarray/xarray/tutorial.py b/test/fixtures/whole_applications/xarray/xarray/tutorial.py new file mode 100644 index 0000000..82bb394 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/tutorial.py @@ -0,0 +1,244 @@ +""" +Useful for: + +* users learning xarray +* building tutorials in the documentation. + +""" + +from __future__ import annotations + +import os +import pathlib +from typing import TYPE_CHECKING + +import numpy as np + +from xarray.backends.api import open_dataset as _open_dataset +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset + +if TYPE_CHECKING: + from xarray.backends.api import T_Engine + + +_default_cache_dir_name = "xarray_tutorial_data" +base_url = "https://github.com/pydata/xarray-data" +version = "master" + + +def _construct_cache_dir(path): + import pooch + + if isinstance(path, os.PathLike): + path = os.fspath(path) + elif path is None: + path = pooch.os_cache(_default_cache_dir_name) + + return path + + +external_urls = {} # type: dict +file_formats = { + "air_temperature": 3, + "air_temperature_gradient": 4, + "ASE_ice_velocity": 4, + "basin_mask": 4, + "ersstv5": 4, + "rasm": 3, + "ROMS_example": 4, + "tiny": 3, + "eraint_uvz": 3, +} + + +def _check_netcdf_engine_installed(name): + version = file_formats.get(name) + if version == 3: + try: + import scipy # noqa + except ImportError: + try: + import netCDF4 # noqa + except ImportError: + raise ImportError( + f"opening tutorial dataset {name} requires either scipy or " + "netCDF4 to be installed." + ) + if version == 4: + try: + import h5netcdf # noqa + except ImportError: + try: + import netCDF4 # noqa + except ImportError: + raise ImportError( + f"opening tutorial dataset {name} requires either h5netcdf " + "or netCDF4 to be installed." + ) + + +# idea borrowed from Seaborn +def open_dataset( + name: str, + cache: bool = True, + cache_dir: None | str | os.PathLike = None, + *, + engine: T_Engine = None, + **kws, +) -> Dataset: + """ + Open a dataset from the online repository (requires internet). + + If a local copy is found then always use that to avoid network traffic. + + Available datasets: + + * ``"air_temperature"``: NCEP reanalysis subset + * ``"air_temperature_gradient"``: NCEP reanalysis subset with approximate x,y gradients + * ``"basin_mask"``: Dataset with ocean basins marked using integers + * ``"ASE_ice_velocity"``: MEaSUREs InSAR-Based Ice Velocity of the Amundsen Sea Embayment, Antarctica, Version 1 + * ``"rasm"``: Output of the Regional Arctic System Model (RASM) + * ``"ROMS_example"``: Regional Ocean Model System (ROMS) output + * ``"tiny"``: small synthetic dataset with a 1D data variable + * ``"era5-2mt-2019-03-uk.grib"``: ERA5 temperature data over the UK + * ``"eraint_uvz"``: data from ERA-Interim reanalysis, monthly averages of upper level data + * ``"ersstv5"``: NOAA's Extended Reconstructed Sea Surface Temperature monthly averages + + Parameters + ---------- + name : str + Name of the file containing the dataset. + e.g. 'air_temperature' + cache_dir : path-like, optional + The directory in which to search for and write cached data. + cache : bool, optional + If True, then cache data locally for use on subsequent calls + **kws : dict, optional + Passed to xarray.open_dataset + + See Also + -------- + tutorial.load_dataset + open_dataset + load_dataset + """ + try: + import pooch + except ImportError as e: + raise ImportError( + "tutorial.open_dataset depends on pooch to download and manage datasets." + " To proceed please install pooch." + ) from e + + logger = pooch.get_logger() + logger.setLevel("WARNING") + + cache_dir = _construct_cache_dir(cache_dir) + if name in external_urls: + url = external_urls[name] + else: + path = pathlib.Path(name) + if not path.suffix: + # process the name + default_extension = ".nc" + if engine is None: + _check_netcdf_engine_installed(name) + path = path.with_suffix(default_extension) + elif path.suffix == ".grib": + if engine is None: + engine = "cfgrib" + try: + import cfgrib # noqa + except ImportError as e: + raise ImportError( + "Reading this tutorial dataset requires the cfgrib package." + ) from e + + url = f"{base_url}/raw/{version}/{path.name}" + + # retrieve the file + filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir) + ds = _open_dataset(filepath, engine=engine, **kws) + if not cache: + ds = ds.load() + pathlib.Path(filepath).unlink() + + return ds + + +def load_dataset(*args, **kwargs) -> Dataset: + """ + Open, load into memory, and close a dataset from the online repository + (requires internet). + + If a local copy is found then always use that to avoid network traffic. + + Available datasets: + + * ``"air_temperature"``: NCEP reanalysis subset + * ``"air_temperature_gradient"``: NCEP reanalysis subset with approximate x,y gradients + * ``"basin_mask"``: Dataset with ocean basins marked using integers + * ``"rasm"``: Output of the Regional Arctic System Model (RASM) + * ``"ROMS_example"``: Regional Ocean Model System (ROMS) output + * ``"tiny"``: small synthetic dataset with a 1D data variable + * ``"era5-2mt-2019-03-uk.grib"``: ERA5 temperature data over the UK + * ``"eraint_uvz"``: data from ERA-Interim reanalysis, monthly averages of upper level data + * ``"ersstv5"``: NOAA's Extended Reconstructed Sea Surface Temperature monthly averages + + Parameters + ---------- + name : str + Name of the file containing the dataset. + e.g. 'air_temperature' + cache_dir : path-like, optional + The directory in which to search for and write cached data. + cache : bool, optional + If True, then cache data locally for use on subsequent calls + **kws : dict, optional + Passed to xarray.open_dataset + + See Also + -------- + tutorial.open_dataset + open_dataset + load_dataset + """ + with open_dataset(*args, **kwargs) as ds: + return ds.load() + + +def scatter_example_dataset(*, seed: None | int = None) -> Dataset: + """ + Create an example dataset. + + Parameters + ---------- + seed : int, optional + Seed for the random number generation. + """ + rng = np.random.default_rng(seed) + A = DataArray( + np.zeros([3, 11, 4, 4]), + dims=["x", "y", "z", "w"], + coords={ + "x": np.arange(3), + "y": np.linspace(0, 1, 11), + "z": np.arange(4), + "w": 0.1 * rng.standard_normal(4), + }, + ) + B = 0.1 * A.x**2 + A.y**2.5 + 0.1 * A.z * A.w + A = -0.1 * A.x + A.y / (5 + A.z) + A.w + ds = Dataset({"A": A, "B": B}) + ds["w"] = ["one", "two", "three", "five"] + + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.z.attrs["units"] = "zunits" + ds.w.attrs["units"] = "wunits" + + ds.A.attrs["units"] = "Aunits" + ds.B.attrs["units"] = "Bunits" + + return ds diff --git a/test/fixtures/whole_applications/xarray/xarray/util/__init__.py b/test/fixtures/whole_applications/xarray/xarray/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/whole_applications/xarray/xarray/util/deprecation_helpers.py b/test/fixtures/whole_applications/xarray/xarray/util/deprecation_helpers.py new file mode 100644 index 0000000..3d52253 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/util/deprecation_helpers.py @@ -0,0 +1,144 @@ +# For reference, here is a copy of the scikit-learn copyright notice: + +# BSD 3-Clause License + +# Copyright (c) 2007-2021 The scikit-learn developers. +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE + + +import inspect +import warnings +from functools import wraps +from typing import Callable, TypeVar + +from xarray.core.utils import emit_user_level_warning + +T = TypeVar("T", bound=Callable) + +POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD +KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY +POSITIONAL_ONLY = inspect.Parameter.POSITIONAL_ONLY +EMPTY = inspect.Parameter.empty + + +def _deprecate_positional_args(version) -> Callable[[T], T]: + """Decorator for methods that issues warnings for positional arguments + + Using the keyword-only argument syntax in pep 3102, arguments after the + ``*`` will issue a warning when passed as a positional argument. + + Parameters + ---------- + version : str + version of the library when the positional arguments were deprecated + + Examples + -------- + Deprecate passing `b` as positional argument: + + def func(a, b=1): + pass + + @_deprecate_positional_args("v0.1.0") + def func(a, *, b=2): + pass + + func(1, 2) + + Notes + ----- + This function is adapted from scikit-learn under the terms of its license. See + licences/SCIKIT_LEARN_LICENSE + """ + + def _decorator(func): + signature = inspect.signature(func) + + pos_or_kw_args = [] + kwonly_args = [] + for name, param in signature.parameters.items(): + if param.kind in (POSITIONAL_OR_KEYWORD, POSITIONAL_ONLY): + pos_or_kw_args.append(name) + elif param.kind == KEYWORD_ONLY: + kwonly_args.append(name) + if param.default is EMPTY: + # IMHO `def f(a, *, b):` does not make sense -> disallow it + # if removing this constraint -> need to add these to kwargs as well + raise TypeError("Keyword-only param without default disallowed.") + + @wraps(func) + def inner(*args, **kwargs): + name = func.__name__ + n_extra_args = len(args) - len(pos_or_kw_args) + if n_extra_args > 0: + extra_args = ", ".join(kwonly_args[:n_extra_args]) + + warnings.warn( + f"Passing '{extra_args}' as positional argument(s) to {name} " + f"was deprecated in version {version} and will raise an error two " + "releases later. Please pass them as keyword arguments." + "", + FutureWarning, + stacklevel=2, + ) + + zip_args = zip(kwonly_args[:n_extra_args], args[-n_extra_args:]) + kwargs.update({name: arg for name, arg in zip_args}) + + return func(*args[:-n_extra_args], **kwargs) + + return func(*args, **kwargs) + + return inner + + return _decorator + + +def deprecate_dims(func: T, old_name="dims") -> T: + """ + For functions that previously took `dims` as a kwarg, and have now transitioned to + `dim`. This decorator will issue a warning if `dims` is passed while forwarding it + to `dim`. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if old_name in kwargs: + emit_user_level_warning( + f"The `{old_name}` argument has been renamed to `dim`, and will be removed " + "in the future. This renaming is taking place throughout xarray over the " + "next few releases.", + # Upgrade to `DeprecationWarning` in the future, when the renaming is complete. + PendingDeprecationWarning, + ) + kwargs["dim"] = kwargs.pop(old_name) + return func(*args, **kwargs) + + # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing + # within the function. + return wrapper # type: ignore diff --git a/test/fixtures/whole_applications/xarray/xarray/util/generate_aggregations.py b/test/fixtures/whole_applications/xarray/xarray/util/generate_aggregations.py new file mode 100644 index 0000000..b59dc36 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/util/generate_aggregations.py @@ -0,0 +1,682 @@ +"""Generate module and stub file for arithmetic operators of various xarray classes. + +For internal xarray development use only. + +Usage: + python xarray/util/generate_aggregations.py + pytest --doctest-modules xarray/core/_aggregations.py --accept || true + pytest --doctest-modules xarray/core/_aggregations.py + +This requires [pytest-accept](https://github.com/max-sixty/pytest-accept). +The second run of pytest is deliberate, since the first will return an error +while replacing the doctests. + +""" + +import collections +import textwrap +from dataclasses import dataclass, field + +MODULE_PREAMBLE = '''\ +"""Mixin classes with reduction operations.""" + +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable + +from xarray.core import duck_array_ops +from xarray.core.options import OPTIONS +from xarray.core.types import Dims, Self +from xarray.core.utils import contains_only_chunked_or_numpy, module_available + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + +flox_available = module_available("flox") +''' + +NAMED_ARRAY_MODULE_PREAMBLE = '''\ +"""Mixin classes with reduction operations.""" +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable + +from xarray.core import duck_array_ops +from xarray.core.types import Dims, Self +''' + +AGGREGATIONS_PREAMBLE = """ + +class {obj}{cls}Aggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError()""" + +NAMED_ARRAY_AGGREGATIONS_PREAMBLE = """ + +class {obj}{cls}Aggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError()""" + + +GROUPBY_PREAMBLE = """ + +class {obj}{cls}Aggregations: + _obj: {obj} + + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError() + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Dims, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError()""" + +RESAMPLE_PREAMBLE = """ + +class {obj}{cls}Aggregations: + _obj: {obj} + + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError() + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Dims, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError()""" + +TEMPLATE_REDUCTION_SIGNATURE = ''' + def {method}( + self, + dim: Dims = None,{kw_only}{extra_kwargs}{keep_attrs} + **kwargs: Any, + ) -> Self: + """ + Reduce this {obj}'s data by applying ``{method}`` along some dimension(s). + + Parameters + ----------''' + +TEMPLATE_REDUCTION_SIGNATURE_GROUPBY = ''' + def {method}( + self, + dim: Dims = None, + *,{extra_kwargs} + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> {obj}: + """ + Reduce this {obj}'s data by applying ``{method}`` along some dimension(s). + + Parameters + ----------''' + +TEMPLATE_RETURNS = """ + Returns + ------- + reduced : {obj} + New {obj} with ``{method}`` applied to its data and the + indicated dimension(s) removed""" + +TEMPLATE_SEE_ALSO = """ + See Also + -------- +{see_also_methods} + :ref:`{docref}` + User guide on {docref_description}.""" + +TEMPLATE_NOTES = """ + Notes + ----- +{notes}""" + +_DIM_DOCSTRING = """dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``{method}``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions.""" + +_DIM_DOCSTRING_GROUPBY = """dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``{method}``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over the {cls} dimensions. + If "...", will reduce over all dimensions.""" + +_SKIPNA_DOCSTRING = """skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64).""" + +_MINCOUNT_DOCSTRING = """min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array.""" + +_DDOF_DOCSTRING = """ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements.""" + +_KEEP_ATTRS_DOCSTRING = """keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes.""" + +_KWARGS_DOCSTRING = """**kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``{method}`` on this object's data. + These could include dask-specific kwargs like ``split_every``.""" + +_NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing." + +_FLOX_NOTES_TEMPLATE = """Use the ``flox`` package to significantly speed up {kind} computations, +especially with dask arrays. Xarray will use flox by default if installed. +Pass flox-specific keyword arguments in ``**kwargs``. +See the `flox documentation `_ for more.""" +_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="groupby") +_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="resampling") + +ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") +skipna = ExtraKwarg( + docs=_SKIPNA_DOCSTRING, + kwarg="skipna: bool | None = None,", + call="skipna=skipna,", + example="""\n + Use ``skipna`` to control whether NaNs are ignored. + + >>> {calculation}(skipna=False)""", +) +min_count = ExtraKwarg( + docs=_MINCOUNT_DOCSTRING, + kwarg="min_count: int | None = None,", + call="min_count=min_count,", + example="""\n + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> {calculation}(skipna=True, min_count=2)""", +) +ddof = ExtraKwarg( + docs=_DDOF_DOCSTRING, + kwarg="ddof: int = 0,", + call="ddof=ddof,", + example="""\n + Specify ``ddof=1`` for an unbiased estimate. + + >>> {calculation}(skipna=True, ddof=1)""", +) + + +@dataclass +class DataStructure: + name: str + create_example: str + example_var_name: str + numeric_only: bool = False + see_also_modules: tuple[str] = tuple + + +class Method: + def __init__( + self, + name, + bool_reduce=False, + extra_kwargs=tuple(), + numeric_only=False, + see_also_modules=("numpy", "dask.array"), + min_flox_version=None, + ): + self.name = name + self.extra_kwargs = extra_kwargs + self.numeric_only = numeric_only + self.see_also_modules = see_also_modules + self.min_flox_version = min_flox_version + if bool_reduce: + self.array_method = f"array_{name}" + self.np_example_array = """ + ... np.array([True, True, True, True, True, False], dtype=bool)""" + + else: + self.array_method = name + self.np_example_array = """ + ... np.array([1, 2, 3, 0, 2, np.nan])""" + + +@dataclass +class AggregationGenerator: + _dim_docstring = _DIM_DOCSTRING + _template_signature = TEMPLATE_REDUCTION_SIGNATURE + + cls: str + datastructure: DataStructure + methods: tuple[Method, ...] + docref: str + docref_description: str + example_call_preamble: str + definition_preamble: str + has_keep_attrs: bool = True + notes: str = "" + preamble: str = field(init=False) + + def __post_init__(self): + self.preamble = self.definition_preamble.format( + obj=self.datastructure.name, cls=self.cls + ) + + def generate_methods(self): + yield [self.preamble] + for method in self.methods: + yield self.generate_method(method) + + def generate_method(self, method): + has_kw_only = method.extra_kwargs or self.has_keep_attrs + + template_kwargs = dict( + obj=self.datastructure.name, + method=method.name, + keep_attrs=( + "\n keep_attrs: bool | None = None," + if self.has_keep_attrs + else "" + ), + kw_only="\n *," if has_kw_only else "", + ) + + if method.extra_kwargs: + extra_kwargs = "\n " + "\n ".join( + [kwarg.kwarg for kwarg in method.extra_kwargs if kwarg.kwarg] + ) + else: + extra_kwargs = "" + + yield self._template_signature.format( + **template_kwargs, + extra_kwargs=extra_kwargs, + ) + + for text in [ + self._dim_docstring.format(method=method.name, cls=self.cls), + *(kwarg.docs for kwarg in method.extra_kwargs if kwarg.docs), + _KEEP_ATTRS_DOCSTRING if self.has_keep_attrs else None, + _KWARGS_DOCSTRING.format(method=method.name), + ]: + if text: + yield textwrap.indent(text, 8 * " ") + + yield TEMPLATE_RETURNS.format(**template_kwargs) + + # we want Dataset.count to refer to DataArray.count + # but we also want DatasetGroupBy.count to refer to Dataset.count + # The generic aggregations have self.cls == '' + others = ( + self.datastructure.see_also_modules + if self.cls == "" + else (self.datastructure.name,) + ) + see_also_methods = "\n".join( + " " * 8 + f"{mod}.{method.name}" + for mod in (method.see_also_modules + others) + ) + # Fixes broken links mentioned in #8055 + yield TEMPLATE_SEE_ALSO.format( + **template_kwargs, + docref=self.docref, + docref_description=self.docref_description, + see_also_methods=see_also_methods, + ) + + notes = self.notes + if method.numeric_only: + if notes != "": + notes += "\n\n" + notes += _NUMERIC_ONLY_NOTES + + if notes != "": + yield TEMPLATE_NOTES.format(notes=textwrap.indent(notes, 8 * " ")) + + yield textwrap.indent(self.generate_example(method=method), "") + yield ' """' + + yield self.generate_code(method, self.has_keep_attrs) + + def generate_example(self, method): + created = self.datastructure.create_example.format( + example_array=method.np_example_array + ) + calculation = f"{self.datastructure.example_var_name}{self.example_call_preamble}.{method.name}" + if method.extra_kwargs: + extra_examples = "".join( + kwarg.example for kwarg in method.extra_kwargs if kwarg.example + ).format(calculation=calculation, method=method.name) + else: + extra_examples = "" + + return f""" + Examples + --------{created} + >>> {self.datastructure.example_var_name} + + >>> {calculation}(){extra_examples}""" + + +class GroupByAggregationGenerator(AggregationGenerator): + _dim_docstring = _DIM_DOCSTRING_GROUPBY + _template_signature = TEMPLATE_REDUCTION_SIGNATURE_GROUPBY + + def generate_code(self, method, has_keep_attrs): + extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] + + if self.datastructure.numeric_only: + extra_kwargs.append(f"numeric_only={method.numeric_only},") + + # median isn't enabled yet, because it would break if a single group was present in multiple + # chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median + method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod") + if method_is_not_flox_supported: + indent = 12 + else: + indent = 16 + + if extra_kwargs: + extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), indent * " ") + else: + extra_kwargs = "" + + if method_is_not_flox_supported: + return f"""\ + return self._reduce_without_squeeze_warn( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + min_version_check = f""" + and module_available("flox", minversion="{method.min_flox_version}")""" + + return ( + """\ + if ( + flox_available + and OPTIONS["use_flox"]""" + + (min_version_check if method.min_flox_version is not None else "") + + f""" + and contains_only_chunked_or_numpy(self._obj) + ): + return self._flox_reduce( + func="{method.name}", + dim=dim,{extra_kwargs} + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self._reduce_without_squeeze_warn( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + ) + + +class GenericAggregationGenerator(AggregationGenerator): + def generate_code(self, method, has_keep_attrs): + extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] + + if self.datastructure.numeric_only: + extra_kwargs.append(f"numeric_only={method.numeric_only},") + + if extra_kwargs: + extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), 12 * " ") + else: + extra_kwargs = "" + keep_attrs = ( + "\n" + 12 * " " + "keep_attrs=keep_attrs," if has_keep_attrs else "" + ) + return f"""\ + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs}{keep_attrs} + **kwargs, + )""" + + +AGGREGATION_METHODS = ( + # Reductions: + Method("count", see_also_modules=("pandas.DataFrame", "dask.dataframe.DataFrame")), + Method("all", bool_reduce=True), + Method("any", bool_reduce=True), + Method("max", extra_kwargs=(skipna,)), + Method("min", extra_kwargs=(skipna,)), + Method("mean", extra_kwargs=(skipna,), numeric_only=True), + Method("prod", extra_kwargs=(skipna, min_count), numeric_only=True), + Method("sum", extra_kwargs=(skipna, min_count), numeric_only=True), + Method("std", extra_kwargs=(skipna, ddof), numeric_only=True), + Method("var", extra_kwargs=(skipna, ddof), numeric_only=True), + Method( + "median", extra_kwargs=(skipna,), numeric_only=True, min_flox_version="0.9.2" + ), + # Cumulatives: + Method("cumsum", extra_kwargs=(skipna,), numeric_only=True), + Method("cumprod", extra_kwargs=(skipna,), numeric_only=True), +) + + +DATASET_OBJECT = DataStructure( + name="Dataset", + create_example=""" + >>> da = xr.DataArray({example_array}, + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da))""", + example_var_name="ds", + numeric_only=True, + see_also_modules=("DataArray",), +) +DATAARRAY_OBJECT = DataStructure( + name="DataArray", + create_example=""" + >>> da = xr.DataArray({example_array}, + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... )""", + example_var_name="da", + numeric_only=False, + see_also_modules=("Dataset",), +) +DATASET_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=DATASET_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=AGGREGATIONS_PREAMBLE, +) +DATAARRAY_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=DATAARRAY_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=AGGREGATIONS_PREAMBLE, +) +DATAARRAY_GROUPBY_GENERATOR = GroupByAggregationGenerator( + cls="GroupBy", + datastructure=DATAARRAY_OBJECT, + methods=AGGREGATION_METHODS, + docref="groupby", + docref_description="groupby operations", + example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, + notes=_FLOX_GROUPBY_NOTES, +) +DATAARRAY_RESAMPLE_GENERATOR = GroupByAggregationGenerator( + cls="Resample", + datastructure=DATAARRAY_OBJECT, + methods=AGGREGATION_METHODS, + docref="resampling", + docref_description="resampling operations", + example_call_preamble='.resample(time="3ME")', + definition_preamble=RESAMPLE_PREAMBLE, + notes=_FLOX_RESAMPLE_NOTES, +) +DATASET_GROUPBY_GENERATOR = GroupByAggregationGenerator( + cls="GroupBy", + datastructure=DATASET_OBJECT, + methods=AGGREGATION_METHODS, + docref="groupby", + docref_description="groupby operations", + example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, + notes=_FLOX_GROUPBY_NOTES, +) +DATASET_RESAMPLE_GENERATOR = GroupByAggregationGenerator( + cls="Resample", + datastructure=DATASET_OBJECT, + methods=AGGREGATION_METHODS, + docref="resampling", + docref_description="resampling operations", + example_call_preamble='.resample(time="3ME")', + definition_preamble=RESAMPLE_PREAMBLE, + notes=_FLOX_RESAMPLE_NOTES, +) + +NAMED_ARRAY_OBJECT = DataStructure( + name="NamedArray", + create_example=""" + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x",{example_array}, + ... )""", + example_var_name="na", + numeric_only=False, + see_also_modules=("Dataset", "DataArray"), +) + +NAMED_ARRAY_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=NAMED_ARRAY_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=NAMED_ARRAY_AGGREGATIONS_PREAMBLE, + has_keep_attrs=False, +) + + +def write_methods(filepath, generators, preamble): + with open(filepath, mode="w", encoding="utf-8") as f: + f.write(preamble) + for gen in generators: + for lines in gen.generate_methods(): + for line in lines: + f.write(line + "\n") + + +if __name__ == "__main__": + import os + from pathlib import Path + + p = Path(os.getcwd()) + write_methods( + filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", + generators=[ + DATASET_GENERATOR, + DATAARRAY_GENERATOR, + DATASET_GROUPBY_GENERATOR, + DATASET_RESAMPLE_GENERATOR, + DATAARRAY_GROUPBY_GENERATOR, + DATAARRAY_RESAMPLE_GENERATOR, + ], + preamble=MODULE_PREAMBLE, + ) + write_methods( + filepath=p.parent / "xarray" / "xarray" / "namedarray" / "_aggregations.py", + generators=[NAMED_ARRAY_GENERATOR], + preamble=NAMED_ARRAY_MODULE_PREAMBLE, + ) + # filepath = p.parent / "core" / "_aggregations.py" # Run from script location diff --git a/test/fixtures/whole_applications/xarray/xarray/util/generate_ops.py b/test/fixtures/whole_applications/xarray/xarray/util/generate_ops.py new file mode 100644 index 0000000..ee4dd68 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/util/generate_ops.py @@ -0,0 +1,294 @@ +"""Generate module and stub file for arithmetic operators of various xarray classes. + +For internal xarray development use only. + +Usage: + python xarray/util/generate_ops.py > xarray/core/_typed_ops.py + +""" + +# Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some +# background to some of the design choices made here. + +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import Optional + +BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) +BINOPS_CMP = ( + ("__lt__", "operator.lt"), + ("__le__", "operator.le"), + ("__gt__", "operator.gt"), + ("__ge__", "operator.ge"), +) +BINOPS_NUM = ( + ("__add__", "operator.add"), + ("__sub__", "operator.sub"), + ("__mul__", "operator.mul"), + ("__pow__", "operator.pow"), + ("__truediv__", "operator.truediv"), + ("__floordiv__", "operator.floordiv"), + ("__mod__", "operator.mod"), + ("__and__", "operator.and_"), + ("__xor__", "operator.xor"), + ("__or__", "operator.or_"), + ("__lshift__", "operator.lshift"), + ("__rshift__", "operator.rshift"), +) +BINOPS_REFLEXIVE = ( + ("__radd__", "operator.add"), + ("__rsub__", "operator.sub"), + ("__rmul__", "operator.mul"), + ("__rpow__", "operator.pow"), + ("__rtruediv__", "operator.truediv"), + ("__rfloordiv__", "operator.floordiv"), + ("__rmod__", "operator.mod"), + ("__rand__", "operator.and_"), + ("__rxor__", "operator.xor"), + ("__ror__", "operator.or_"), +) +BINOPS_INPLACE = ( + ("__iadd__", "operator.iadd"), + ("__isub__", "operator.isub"), + ("__imul__", "operator.imul"), + ("__ipow__", "operator.ipow"), + ("__itruediv__", "operator.itruediv"), + ("__ifloordiv__", "operator.ifloordiv"), + ("__imod__", "operator.imod"), + ("__iand__", "operator.iand"), + ("__ixor__", "operator.ixor"), + ("__ior__", "operator.ior"), + ("__ilshift__", "operator.ilshift"), + ("__irshift__", "operator.irshift"), +) +UNARY_OPS = ( + ("__neg__", "operator.neg"), + ("__pos__", "operator.pos"), + ("__abs__", "operator.abs"), + ("__invert__", "operator.invert"), +) +# round method and numpy/pandas unary methods which don't modify the data shape, +# so the result should still be wrapped in an Variable/DataArray/Dataset +OTHER_UNARY_METHODS = ( + ("round", "ops.round_"), + ("argsort", "ops.argsort"), + ("conj", "ops.conj"), + ("conjugate", "ops.conjugate"), +) + + +required_method_binary = """ + def _binary_op( + self, other: {other_type}, f: Callable, reflexive: bool = False + ) -> {return_type}: + raise NotImplementedError""" +template_binop = """ + def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} + return self._binary_op(other, {func})""" +template_binop_overload = """ + @overload{overload_type_ignore} + def {method}(self, other: {overload_type}) -> {overload_type}: + ... + + @overload + def {method}(self, other: {other_type}) -> {return_type}: + ... + + def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} + return self._binary_op(other, {func})""" +template_reflexive = """ + def {method}(self, other: {other_type}) -> {return_type}: + return self._binary_op(other, {func}, reflexive=True)""" + +required_method_inplace = """ + def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self: + raise NotImplementedError""" +template_inplace = """ + def {method}(self, other: {other_type}) -> Self:{type_ignore} + return self._inplace_binary_op(other, {func})""" + +required_method_unary = """ + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError""" +template_unary = """ + def {method}(self) -> Self: + return self._unary_op({func})""" +template_other_unary = """ + def {method}(self, *args: Any, **kwargs: Any) -> Self: + return self._unary_op({func}, *args, **kwargs)""" +unhashable = """ + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment]""" + +# For some methods we override return type `bool` defined by base class `object`. +# We need to add "# type: ignore[override]" +# Keep an eye out for: +# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240 +# The type ignores might not be necessary anymore at some point. +# +# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray +# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) +# TODO: change once python 3.10 is the minimum. +# +# Mypy seems to require that __iadd__ and __add__ have the same signature. +# This requires some extra type: ignores[misc] in the inplace methods :/ + + +def _type_ignore(ignore: str) -> str: + return f" # type:ignore[{ignore}]" if ignore else "" + + +FuncType = Sequence[tuple[Optional[str], Optional[str]]] +OpsType = tuple[FuncType, str, dict[str, str]] + + +def binops( + other_type: str, return_type: str = "Self", type_ignore_eq: str = "override" +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} + return [ + ([(None, None)], required_method_binary, extras), + (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}), + ( + BINOPS_EQNE, + template_binop, + extras | {"type_ignore": _type_ignore(type_ignore_eq)}, + ), + ([(None, None)], unhashable, extras), + (BINOPS_REFLEXIVE, template_reflexive, extras), + ] + + +def binops_overload( + other_type: str, + overload_type: str, + return_type: str = "Self", + type_ignore_eq: str = "override", +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} + return [ + ([(None, None)], required_method_binary, extras), + ( + BINOPS_NUM + BINOPS_CMP, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": "", + }, + ), + ( + BINOPS_EQNE, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": _type_ignore(type_ignore_eq), + }, + ), + ([(None, None)], unhashable, extras), + (BINOPS_REFLEXIVE, template_reflexive, extras), + ] + + +def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]: + extras = {"other_type": other_type} + return [ + ([(None, None)], required_method_inplace, extras), + ( + BINOPS_INPLACE, + template_inplace, + extras | {"type_ignore": _type_ignore(type_ignore)}, + ), + ] + + +def unops() -> list[OpsType]: + return [ + ([(None, None)], required_method_unary, {}), + (UNARY_OPS, template_unary, {}), + (OTHER_UNARY_METHODS, template_other_unary, {}), + ] + + +ops_info = {} +ops_info["DatasetOpsMixin"] = ( + binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() +) +ops_info["DataArrayOpsMixin"] = ( + binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() +) +ops_info["VariableOpsMixin"] = ( + binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + + inplace(other_type="VarCompatible", type_ignore="misc") + + unops() +) +ops_info["DatasetGroupByOpsMixin"] = binops( + other_type="GroupByCompatible", return_type="Dataset" +) +ops_info["DataArrayGroupByOpsMixin"] = binops( + other_type="T_Xarray", return_type="T_Xarray" +) + +MODULE_PREAMBLE = '''\ +"""Mixin classes with arithmetic operators.""" +# This file was generated using xarray.util.generate_ops. Do not edit manually. + +from __future__ import annotations + +import operator +from typing import TYPE_CHECKING, Any, Callable, overload + +from xarray.core import nputils, ops +from xarray.core.types import ( + DaCompatible, + DsCompatible, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, + VarCompatible, +) + +if TYPE_CHECKING: + from xarray.core.dataset import Dataset''' + + +CLASS_PREAMBLE = """{newline} +class {cls_name}: + __slots__ = ()""" + +COPY_DOCSTRING = """\ + {method}.__doc__ = {func}.__doc__""" + + +def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]: + """Render the module or stub file.""" + yield MODULE_PREAMBLE + + for cls_name, method_blocks in ops_info.items(): + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n") + yield from _render_classbody(method_blocks) + + +def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]: + for method_func_pairs, template, extra in method_blocks: + if template: + for method, func in method_func_pairs: + yield template.format(method=method, func=func, **extra) + + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) + + +if __name__ == "__main__": + for line in render(ops_info): + print(line) diff --git a/test/fixtures/whole_applications/xarray/xarray/util/print_versions.py b/test/fixtures/whole_applications/xarray/xarray/util/print_versions.py new file mode 100755 index 0000000..0b2e2b0 --- /dev/null +++ b/test/fixtures/whole_applications/xarray/xarray/util/print_versions.py @@ -0,0 +1,163 @@ +"""Utility functions for printing version information.""" + +import importlib +import locale +import os +import platform +import struct +import subprocess +import sys + + +def get_sys_info(): + """Returns system information as a dict""" + + blob = [] + + # get full commit hash + commit = None + if os.path.isdir(".git") and os.path.isdir("xarray"): + try: + pipe = subprocess.Popen( + 'git log --format="%H" -n 1'.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + so, _ = pipe.communicate() + except Exception: + pass + else: + if pipe.returncode == 0: + commit = so + try: + commit = so.decode("utf-8") + except ValueError: + pass + commit = commit.strip().strip('"') + + blob.append(("commit", commit)) + + try: + (sysname, _nodename, release, _version, machine, processor) = platform.uname() + blob.extend( + [ + ("python", sys.version), + ("python-bits", struct.calcsize("P") * 8), + ("OS", f"{sysname}"), + ("OS-release", f"{release}"), + # ("Version", f"{version}"), + ("machine", f"{machine}"), + ("processor", f"{processor}"), + ("byteorder", f"{sys.byteorder}"), + ("LC_ALL", f'{os.environ.get("LC_ALL", "None")}'), + ("LANG", f'{os.environ.get("LANG", "None")}'), + ("LOCALE", f"{locale.getlocale()}"), + ] + ) + except Exception: + pass + + return blob + + +def netcdf_and_hdf5_versions(): + libhdf5_version = None + libnetcdf_version = None + try: + import netCDF4 + + libhdf5_version = netCDF4.__hdf5libversion__ + libnetcdf_version = netCDF4.__netcdf4libversion__ + except ImportError: + try: + import h5py + + libhdf5_version = h5py.version.hdf5_version + except ImportError: + pass + return [("libhdf5", libhdf5_version), ("libnetcdf", libnetcdf_version)] + + +def show_versions(file=sys.stdout): + """print the versions of xarray and its dependencies + + Parameters + ---------- + file : file-like, optional + print to the given file-like object. Defaults to sys.stdout. + """ + sys_info = get_sys_info() + + try: + sys_info.extend(netcdf_and_hdf5_versions()) + except Exception as e: + print(f"Error collecting netcdf / hdf5 version: {e}") + + deps = [ + # (MODULE_NAME, f(mod) -> mod version) + ("xarray", lambda mod: mod.__version__), + ("pandas", lambda mod: mod.__version__), + ("numpy", lambda mod: mod.__version__), + ("scipy", lambda mod: mod.__version__), + # xarray optionals + ("netCDF4", lambda mod: mod.__version__), + ("pydap", lambda mod: mod.__version__), + ("h5netcdf", lambda mod: mod.__version__), + ("h5py", lambda mod: mod.__version__), + ("zarr", lambda mod: mod.__version__), + ("cftime", lambda mod: mod.__version__), + ("nc_time_axis", lambda mod: mod.__version__), + ("iris", lambda mod: mod.__version__), + ("bottleneck", lambda mod: mod.__version__), + ("dask", lambda mod: mod.__version__), + ("distributed", lambda mod: mod.__version__), + ("matplotlib", lambda mod: mod.__version__), + ("cartopy", lambda mod: mod.__version__), + ("seaborn", lambda mod: mod.__version__), + ("numbagg", lambda mod: mod.__version__), + ("fsspec", lambda mod: mod.__version__), + ("cupy", lambda mod: mod.__version__), + ("pint", lambda mod: mod.__version__), + ("sparse", lambda mod: mod.__version__), + ("flox", lambda mod: mod.__version__), + ("numpy_groupies", lambda mod: mod.__version__), + # xarray setup/test + ("setuptools", lambda mod: mod.__version__), + ("pip", lambda mod: mod.__version__), + ("conda", lambda mod: mod.__version__), + ("pytest", lambda mod: mod.__version__), + ("mypy", lambda mod: importlib.metadata.version(mod.__name__)), + # Misc. + ("IPython", lambda mod: mod.__version__), + ("sphinx", lambda mod: mod.__version__), + ] + + deps_blob = [] + for modname, ver_f in deps: + try: + if modname in sys.modules: + mod = sys.modules[modname] + else: + mod = importlib.import_module(modname) + except Exception: + deps_blob.append((modname, None)) + else: + try: + ver = ver_f(mod) + deps_blob.append((modname, ver)) + except Exception: + deps_blob.append((modname, "installed")) + + print("\nINSTALLED VERSIONS", file=file) + print("------------------", file=file) + + for k, stat in sys_info: + print(f"{k}: {stat}", file=file) + + print("", file=file) + for k, stat in deps_blob: + print(f"{k}: {stat}", file=file) + + +if __name__ == "__main__": + show_versions() diff --git a/test/test_cli.py b/test/test_cli.py index 278aae0..d341cfb 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -9,23 +9,24 @@ def test_cli_help(cli_runner): result = cli_runner.invoke(app, ["--help"], env={"NO_COLOR": "1", "TERM": "dumb"}) assert result.exit_code == 0 -def test_cli_call_symbol_table_with_json(cli_runner, project_root): +def test_cli_call_symbol_table_with_json(cli_runner, whole_applications__xarray): """Must be able to run the CLI with symbol table analysis.""" - output_dir = project_root.joinpath("test", ".output") + output_dir = whole_applications__xarray.joinpath("test", ".output") output_dir.mkdir(parents=True, exist_ok=True) result = cli_runner.invoke( app, [ "--input", - str(project_root), + str(whole_applications__xarray), "--output", str(output_dir), "--analysis-level", "1", + "--ray", "--no-codeql", "--cache-dir", - str(project_root.joinpath("test", ".cache")), - "--keep-cache", + str(whole_applications__xarray.joinpath("test", ".cache")), + "--clear-cache", "--format=json", ], env={"NO_COLOR": "1", "TERM": "dumb"}, @@ -37,3 +38,40 @@ def test_cli_call_symbol_table_with_json(cli_runner, project_root): assert isinstance(json_obj, dict), "JSON output should be a dictionary" assert "symbol_table" in json_obj.keys(), "Symbol table should be present in the output" assert len(json_obj["symbol_table"]) > 0, "Symbol table should not be empty" + + +def test_single_file(cli_runner, single_functionalities__stuff_nested_in_functions): + """Must be able to run the CLI with single file analysis using --file-name flag.""" + output_dir = single_functionalities__stuff_nested_in_functions.joinpath(".output") + output_dir.mkdir(parents=True, exist_ok=True) + + # Path to the specific test file + test_file = single_functionalities__stuff_nested_in_functions.joinpath("main.py") + + result = cli_runner.invoke( + app, + [ + "--input", + str(single_functionalities__stuff_nested_in_functions), + "--file-name", + str(test_file), + "--no-ray", + "--clear-cache", + "-vv", + "--skip-tests", + "--output", + str(output_dir), + "--eager", + "--format=json", + ], + env={"NO_COLOR": "1", "TERM": "dumb"}, + ) + + assert result.exit_code == 0, f"CLI command should succeed. Output: {result.output}" + assert Path(output_dir).joinpath("analysis.json").exists(), "Output JSON file should be created" + + # Load and validate the JSON output + json_obj = json.loads(Path(output_dir).joinpath("analysis.json").read_text()) + assert json_obj is not None, "JSON output should not be None" + assert isinstance(json_obj, dict), "JSON output should be a dictionary" + assert "symbol_table" in json_obj.keys(), "Symbol table should be present in the output" \ No newline at end of file