Skip to content

Commit

Permalink
BUG: Fix closing files manually
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Nov 11, 2022
1 parent 49f2815 commit adb8d3b
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 180 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,37 @@ jobs:
docker_tests:
needs: linting
runs-on: ubuntu-latest
name: Docker | python=${{ matrix.python-version }} | rasterio${{ matrix.rasterio-version }} | scipy ${{ matrix.run-with-scipy }}
container: osgeo/gdal:ubuntu-full-3.4.2
name: Docker | GDAL=${{ matrix.gdal-version }} | python=${{ matrix.python-version }} | rasterio${{ matrix.rasterio-version }} | scipy ${{ matrix.run-with-scipy }}
container: osgeo/gdal:ubuntu-full-${{ matrix.gdal-version }}
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10']
rasterio-version: ['']
xarray-version: ['']
run-with-scipy: ['YES']
gdal-version: ['3.5.3']
include:
- python-version: '3.8'
rasterio-version: ''
xarray-version: '==0.17'
run-with-scipy: 'YES'
gdal-version: '3.4.3'
- python-version: '3.8'
rasterio-version: '==1.1'
xarray-version: ''
run-with-scipy: 'YES'
gdal-version: '3.4.3'
- python-version: '3.8'
rasterio-version: '==1.2.1'
xarray-version: ''
run-with-scipy: 'YES'
gdal-version: '3.5.3'
- python-version: '3.9'
rasterio-version: ''
xarray-version: ''
run-with-scipy: 'NO'
gdal-version: '3.5.3'
steps:
- uses: actions/checkout@v3

Expand Down
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ History

Latest
------
- BUG: Fix closing files manually (pull #607)

0.13.0
-------
Expand Down
71 changes: 46 additions & 25 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
import threading
import warnings
from collections import defaultdict
from typing import Any, Dict, Hashable, Iterable, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -737,6 +738,29 @@ def _pop_global_netcdf_attrs_from_vars(dataset_to_clean: Dataset) -> Dataset:
return dataset_to_clean


def _subdataset_groups_to_dataset(
dim_groups: Dict[Hashable, DataArray], global_tags: Dict
) -> Union[Dataset, List[Dataset]]:
if dim_groups:
dataset = []
for dim_group in dim_groups.values():
dataset_group = _pop_global_netcdf_attrs_from_vars(
Dataset(dim_group, attrs=global_tags)
)

def _ds_close():
# pylint: disable=cell-var-from-loop
for data_var in dim_group.values():
data_var.close()

dataset_group.set_close(_ds_close)
dataset.append(dataset_group)
if isinstance(dataset, list) and len(dataset) == 1:
dataset = dataset.pop()
else:
dataset = Dataset(attrs=global_tags)


def _load_subdatasets(
riods: RasterioReader,
group: Optional[Union[str, List[str], Tuple[str, ...]]],
Expand All @@ -754,8 +778,7 @@ def _load_subdatasets(
"""
Load in rasterio subdatasets
"""
global_tags = _parse_tags(riods.tags())
dim_groups = {}
dim_groups = defaultdict(dict)
subdataset_filter = None
if any((group, variable)):
subdataset_filter = build_subdataset_filter(group, variable)
Expand All @@ -777,23 +800,10 @@ def _load_subdatasets(
decode_timedelta=decode_timedelta,
**open_kwargs,
)
if shape not in dim_groups:
dim_groups[shape] = {rioda.name: rioda}
else:
dim_groups[shape][rioda.name] = rioda

if len(dim_groups) > 1:
dataset: Union[Dataset, List[Dataset]] = [
_pop_global_netcdf_attrs_from_vars(Dataset(dim_group, attrs=global_tags))
for dim_group in dim_groups.values()
]
elif not dim_groups:
dataset = Dataset(attrs=global_tags)
else:
dataset = _pop_global_netcdf_attrs_from_vars(
Dataset(list(dim_groups.values())[0], attrs=global_tags)
)
return dataset
dim_groups[shape][rioda.name] = rioda
return _subdataset_groups_to_dataset(
dim_groups=dim_groups, global_tags=_parse_tags(riods.tags())
)


def _load_bands_as_variables(
Expand Down Expand Up @@ -836,7 +846,14 @@ def _load_bands_as_variables(
.squeeze() # type: ignore
.drop("band") # type: ignore
)
return Dataset(data_vars, attrs=global_tags)
dataset = Dataset(data_vars, attrs=global_tags)

def _ds_close():
for data_var in data_vars.values():
data_var.close()

dataset.set_close(_ds_close)
return dataset


def _prepare_dask(
Expand Down Expand Up @@ -1070,7 +1087,7 @@ def open_rasterio(
captured_warnings = rio_warnings.copy()

if band_as_variable:
return _load_bands_as_variables(
result = _load_bands_as_variables(
riods=riods,
parse_coordinates=parse_coordinates,
chunks=chunks,
Expand All @@ -1082,6 +1099,8 @@ def open_rasterio(
decode_timedelta=decode_timedelta,
**open_kwargs,
)
manager.close()
return result

# raise the NotGeoreferencedWarning if applicable
for rio_warning in captured_warnings:
Expand All @@ -1092,7 +1111,7 @@ def open_rasterio(

# open the subdatasets if they exist
if riods.subdatasets:
return _load_subdatasets(
result = _load_subdatasets(
riods=riods,
group=group,
variable=variable,
Expand All @@ -1106,6 +1125,8 @@ def open_rasterio(
decode_timedelta=decode_timedelta,
**open_kwargs,
)
manager.close()
return result

if vrt_params is not None:
riods = WarpedVRT(riods, **vrt_params)
Expand Down Expand Up @@ -1204,9 +1225,6 @@ def open_rasterio(
if chunks is not None:
result = _prepare_dask(result, riods, filename, chunks)

# Make the file closeable
result.set_close(manager.close)
result.rio._manager = manager
# add file path to encoding
result.encoding["source"] = riods.name
result.encoding["rasterio_dtype"] = str(riods.dtypes[0])
Expand All @@ -1224,4 +1242,7 @@ def open_rasterio(
for attr, value in result.attrs.items()
if not attr.startswith(f"{result.name}#")
}
# Make the file closeable
result.set_close(manager.close)
result.rio._manager = manager
return result
4 changes: 3 additions & 1 deletion rioxarray/xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def open_dataset(
**open_kwargs,
)
if isinstance(rds, xr.DataArray):
rds = rds.to_dataset()
dataset = rds.to_dataset()
dataset.set_close(rds._close)
rds = dataset
if not isinstance(rds, xr.Dataset):
raise RioXarrayError(
"Multiple resolution sets found. "
Expand Down
3 changes: 3 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import os

import pytest
import xarray
from numpy.testing import assert_almost_equal, assert_array_equal
from packaging import version

import rioxarray
from rioxarray.raster_array import UNWANTED_RIO_ATTRS

xarray.set_options(warn_for_unclosed_files=True)

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data")
TEST_INPUT_DATA_DIR = os.path.join(TEST_DATA_DIR, "input")
TEST_COMPARE_DATA_DIR = os.path.join(TEST_DATA_DIR, "compare")
Expand Down
Loading

0 comments on commit adb8d3b

Please sign in to comment.