Skip to content

Commit

Permalink
Start Dask branch
Browse files Browse the repository at this point in the history
  • Loading branch information
dmentipl committed Apr 13, 2020
1 parent 9a7851a commit 6566f8e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 47 deletions.
9 changes: 5 additions & 4 deletions plonk/analysis/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,8 +846,9 @@ def temperature(
if isinstance(cs, Quantity):
T = cs ** 2 / (gamma * specific_gas_constant)
else:
T = (cs * snap.units['velocity']) ** 2 / (gamma * specific_gas_constant)
T = (
cs ** 2
* (snap.units['velocity'] ** 2 / (gamma * specific_gas_constant)).magnitude
)

if isinstance(cs, Quantity):
return T
return T.magnitude
return T
6 changes: 4 additions & 2 deletions plonk/snap/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
_data_sources = ('Phantom',)


def load_snap(filename: Union[str, Path], data_source: str = 'Phantom') -> Snap:
def load_snap(
filename: Union[str, Path], data_source: str = 'Phantom', dask_arrays: bool = False
) -> Snap:
"""Load a snapshot from file.
Parameters
Expand All @@ -30,5 +32,5 @@ def load_snap(filename: Union[str, Path], data_source: str = 'Phantom') -> Snap:
)

if data_source == 'Phantom':
return read_phantom(filename)
return read_phantom(filename, dask_arrays)
raise ValueError('Cannot load snapshot')
82 changes: 52 additions & 30 deletions plonk/snap/readers/phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Callable, Dict, List, Union

import dask.array as da
import h5py
import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -58,7 +59,9 @@
}


def generate_snap_from_file(filename: Union[str, Path]) -> Snap:
def generate_snap_from_file(
filename: Union[str, Path], dask_arrays: bool = False
) -> Snap:
"""Generate a Snap object from a Phantom HDF5 file.
Parameters
Expand All @@ -80,6 +83,7 @@ def generate_snap_from_file(filename: Union[str, Path]) -> Snap:
snap.data_source = 'Phantom'
snap.file_path = file_path
snap._file_pointer = file_handle
snap.dask_arrays = dask_arrays

header = {key: val[()] for key, val in file_handle['header'].items()}
snap.properties, snap.units = _header_to_properties(header)
Expand All @@ -90,13 +94,16 @@ def generate_snap_from_file(filename: Union[str, Path]) -> Snap:
array_registry = _populate_particle_array_registry(
arrays=arrays,
name_map=_particle_array_name_map,
dask_arrays=dask_arrays,
ndustsmall=ndustsmall,
ndustlarge=ndustlarge,
)
snap._array_registry.update(array_registry)

if header['nptmass'] > 0:
sink_registry = _populate_sink_array_registry(name_map=_sink_array_name_map)
sink_registry = _populate_sink_array_registry(
name_map=_sink_array_name_map, dask_arrays=dask_arrays
)
snap._sink_registry.update(sink_registry)

return snap
Expand Down Expand Up @@ -153,6 +160,7 @@ def _header_to_properties(header: dict):
def _populate_particle_array_registry(
arrays: List[str],
name_map: Dict[str, str],
dask_arrays: bool = False,
ndustsmall: int = 0,
ndustlarge: int = 0,
):
Expand All @@ -161,8 +169,8 @@ def _populate_particle_array_registry(

# Always read itype, xyz, h
array_registry['type'] = _particle_type
array_registry['position'] = _get_dataset('xyz', 'particles')
array_registry['smoothing_length'] = _get_dataset('h', 'particles')
array_registry['position'] = _get_dataset('xyz', 'particles', dask_arrays)
array_registry['smoothing_length'] = _get_dataset('h', 'particles', dask_arrays)
arrays.remove('itype')
arrays.remove('xyz')
arrays.remove('h')
Expand All @@ -174,15 +182,16 @@ def _populate_particle_array_registry(

elif ndustlarge > 0:
# Read dust type if there are dust particles
array_registry['sub_type'] = _dust_particle_type
array_registry['sub_type'] = _dust_sub_type
array_registry['stopping_time'] = _stopping_time
array_registry['dust_to_gas_ratio'] = _dust_to_gas_ratio
arrays.remove('dustfrac')
arrays.remove('tstop')

# Read arrays if available
for name_on_file, name in name_map.items():
if name_on_file in arrays:
array_registry[name] = _get_dataset(name_on_file, 'particles')
array_registry[name] = _get_dataset(name_on_file, 'particles', dask_arrays)
arrays.remove(name_on_file)

# Derived arrays not stored on file
Expand All @@ -193,56 +202,66 @@ def _populate_particle_array_registry(

# Read *any* extra arrays
for array in arrays:
array_registry[array] = _get_dataset(array, 'particles')
array_registry[array] = _get_dataset(array, 'particles', dask_arrays)

return array_registry


def _populate_sink_array_registry(name_map: Dict[str, str]):
def _populate_sink_array_registry(name_map: Dict[str, str], dask_arrays: bool = False):

sink_registry = dict()

for name_on_file, name in name_map.items():
sink_registry[name] = _get_dataset(name_on_file, 'sinks')
sink_registry[name] = _get_dataset(name_on_file, 'sinks', dask_arrays)

return sink_registry


def _get_dataset(dataset: str, group: str) -> Callable:
def func(snap: Snap) -> ndarray:
return snap._file_pointer[f'{group}/{dataset}'][()]
def _get_dataset(dataset: str, group: str, dask_arrays: bool = False) -> Callable:
if dask_arrays:

def array(_snap: Snap) -> ndarray:
return da.from_array(_snap._file_pointer[f'{group}/{dataset}'])

else:

def array(_snap: Snap) -> ndarray:
return _snap._file_pointer[f'{group}/{dataset}'][()]

return func
return array


def _particle_type(snap: Snap) -> ndarray:
idust = _get_dataset('idust', 'header')(snap)
particle_type = np.abs(_get_dataset('itype', 'particles')(snap))
particle_type = np.abs(_get_dataset('itype', 'particles', snap.dask_arrays)(snap))
particle_type[particle_type >= idust] = 2
return particle_type


def _dust_particle_type(snap: Snap) -> ndarray:
def _dust_sub_type(snap: Snap) -> ndarray:
idust = _get_dataset('idust', 'header')(snap)
particle_type = np.abs(_get_dataset('itype', 'particles')(snap))
sub_type = np.zeros(particle_type.shape, dtype=np.int8)
sub_type[particle_type >= idust] = particle_type[particle_type >= idust] - idust
particle_type = np.abs(_get_dataset('itype', 'particles', snap.dask_arrays)(snap))
if snap.dask_arrays:
sub_type = da.zeros(particle_type.shape, dtype=np.int8)
else:
sub_type = np.zeros(particle_type.shape, dtype=np.int8)
sub_type = particle_type - idust
sub_type[sub_type < 0] = 0
return sub_type


def _mass(snap: Snap) -> ndarray:
massoftype = _get_dataset('massoftype', 'header')(snap)
particle_type = _get_dataset('itype', 'particles')(snap)
massoftype = _get_dataset('massoftype', 'header', snap.dask_arrays)(snap)
particle_type = _get_dataset('itype', 'particles', snap.dask_arrays)(snap)
return massoftype[particle_type - 1]


def _density(snap: Snap) -> ndarray:
m = _mass(snap)
h = _get_dataset('h', 'particles')(snap)
h = _get_dataset('h', 'particles', snap.dask_arrays)(snap)
hfact = snap.properties['smoothing_length_factor']
rho = np.zeros(m.shape)
h = np.abs(h)
rho[h > 0] = m[h > 0] * (hfact / h[h > 0]) ** 3
rho = m * (hfact / h) ** 3
return rho


Expand All @@ -257,7 +276,7 @@ def _pressure(snap: Snap) -> ndarray:
return K * rho ** (gamma - 1)
if ieos == 3:
q = snap.properties['sound_speed_index']
pos: ndarray = _get_dataset('xyz', 'particles')(snap)
pos: ndarray = _get_dataset('xyz', 'particles', snap.dask_arrays)(snap)
r_squared = pos[:, 0] ** 2 + pos[:, 1] ** 2 + pos[:, 2] ** 2
return K * rho * r_squared ** (-q)
raise ValueError('Cannot determine equation of state')
Expand All @@ -268,26 +287,29 @@ def _sound_speed(snap: Snap) -> ndarray:
gamma = snap.properties['adiabatic_index']
rho = _density(snap)
P = _pressure(snap)
cs = np.zeros(P.shape)
if ieos in (1, 3):
cs[rho > 0] = np.sqrt(P[rho > 0] / rho[rho > 0])
cs = np.sqrt(P / rho)
elif ieos == 2:
cs[rho > 0] = np.sqrt(gamma * P[rho > 0] / rho[rho > 0])
cs = np.sqrt(gamma * P / rho)
else:
raise ValueError('Cannot determine equation of state')
if isinstance(cs, ndarray):
cs[cs == np.inf] = 0
elif isinstance(cs, da.Array):
cs = da.nan_to_num(cs)
return cs


def _stopping_time(snap: Snap) -> ndarray:
stopping_time = _get_dataset('tstop', 'particles')(snap)
stopping_time = _get_dataset('tstop', 'particles', snap.dask_arrays)(snap)
stopping_time[stopping_time == _bignumber] = np.inf
return stopping_time


def _dust_fraction(snap: Snap) -> ndarray:
if snap.properties['dust_method'] != 'dust/gas mixture':
raise ValueError('Dust fraction only available for "dust/gas mixture"')
dust_fraction = _get_dataset('dustfrac', 'particles')(snap)
dust_fraction = _get_dataset('dustfrac', 'particles', snap.dask_arrays)(snap)
return dust_fraction


Expand All @@ -296,5 +318,5 @@ def _dust_to_gas_ratio(snap: Snap) -> ndarray:
raise ValueError(
'Dust fraction only available for "dust as separate sets of particles"'
)
dust_to_gas_ratio = _get_dataset('dustfrac', 'particles')(snap)
dust_to_gas_ratio = _get_dataset('dustfrac', 'particles', snap.dask_arrays)(snap)
return dust_to_gas_ratio
32 changes: 25 additions & 7 deletions plonk/snap/snap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast

import dask.array as da
import h5py
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -284,6 +285,7 @@ def __init__(self):
self._extra_quantities = False
self._neighbours = None
self._tree = None
self.dask_arrays = False

def close_file(self):
"""Close access to underlying file."""
Expand Down Expand Up @@ -472,10 +474,16 @@ def physical_units(self) -> Snap:
raise ValueError(
'Physical units already set: snap.unset(units=True) to unset.'
)
for arr in self.loaded_arrays():
self._arrays[arr] = self._arrays[arr] * self.get_array_unit(arr)
for arr in self.loaded_arrays(sinks=True):
self._sinks[arr] = self._sinks[arr] * self.get_array_unit(arr)
if self.dask_arrays:
for arr in self.loaded_arrays():
del self._arrays[arr]
for arr in self.loaded_arrays(sinks=True):
del self._sinks[arr]
else:
for arr in self.loaded_arrays():
self._arrays[arr] = self._arrays[arr] * self.get_array_unit(arr)
for arr in self.loaded_arrays(sinks=True):
self._sinks[arr] = self._sinks[arr] * self.get_array_unit(arr)
self._physical_units = True

return self
Expand Down Expand Up @@ -504,7 +512,7 @@ def rotate(self, rotation: Union[ndarray, Rotation]) -> Snap:
>>> rot = rot * np.pi / 3 * np.linalg.norm(rot)
>>> snap.rotate(rot)
"""
if isinstance(rotation, ndarray):
if isinstance(rotation, (list, tuple, ndarray)):
rotation = Rotation.from_rotvec(rotation)
for arr in self._vector_arrays:
if arr in self.loaded_arrays():
Expand Down Expand Up @@ -578,14 +586,20 @@ def particle_indices(
is True, return a single array.
"""
if particle_type == 'dust' and not squeeze_subtype:
return [
ind = [
np.flatnonzero(
(self['type'] == self.particle_type['dust'])
& (self['sub_type'] == idx)
)
for idx in range(self.num_dust_species)
]
return np.flatnonzero(self['type'] == self.particle_type[particle_type])
if isinstance(ind[0], da.Array):
return [_ind.compute() for _ind in ind]
return ind
ind = np.flatnonzero(self['type'] == self.particle_type[particle_type])
if isinstance(ind, da.Array):
return ind.compute()
return ind

def set_kernel(self, kernel: str):
"""Set kernel.
Expand Down Expand Up @@ -1037,6 +1051,10 @@ def __delitem__(self, name):
"""Delete an array from memory."""
del self._arrays[name]

def _ipython_key_completions_(self):
"""IPython tab completion for __getitem__."""
return self.available_arrays()

def __len__(self):
"""Length as number of particles."""
return self.num_particles
Expand Down
8 changes: 6 additions & 2 deletions plonk/utils/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""NumPy functions with Pint support."""

import dask.array as da
import numpy as np

from .. import Quantity
Expand All @@ -21,8 +22,11 @@ def cross(x, y, **kwargs):
ndarray
The cross product of x and y.
"""
if isinstance(x, Quantity):
return np.cross(x.magnitude, y.magnitude, **kwargs) * x.units * y.units
if isinstance(x, da.Array):
result_x = x[:, 1] * y[:, 2] - x[:, 2] * y[:, 1]
result_y = x[:, 2] * y[:, 0] - x[:, 0] * y[:, 2]
result_z = x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]
return da.stack([result_x, result_y, result_z]).T
return np.cross(x, y, **kwargs)


Expand Down
8 changes: 6 additions & 2 deletions plonk/visualize/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,13 @@ def get_extent_from_percentile(
tuple
The extent of the box as (xmin, xmax, ymin, ymax).
"""
if snap.dask_arrays:
_x, _y = snap[x].compute(), snap[y].compute()
else:
_x, _y = snap[x], snap[y]
pl, pr = (100 - percentile) / 2, percentile + (100 - percentile) / 2
xlim = np.percentile(snap[x], [pl, pr])
ylim = np.percentile(snap[y], [pl, pr])
xlim = np.percentile(_x, [pl, pr])
ylim = np.percentile(_y, [pl, pr])

if x_center_on is not None:
xlim += x_center_on - xlim.mean()
Expand Down

0 comments on commit 6566f8e

Please sign in to comment.