Skip to content

Commit

Permalink
Convert to xarray accessor
Browse files Browse the repository at this point in the history
- add new open_dataset() method
- override xr.Dataset.__getitem__
- minimum changes to pass tests
  • Loading branch information
khaeru committed Sep 27, 2016
1 parent 7484b72 commit ac8b3d2
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 107 deletions.
184 changes: 104 additions & 80 deletions gdx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,29 @@
unicode_literals)
from itertools import cycle
from logging import debug, info
# commented: for debugging
# # commented: for debugging
# import logging
# logging.basicConfig(level=logging.DEBUG)

import numpy
import numpy as np
import pandas
import xarray as xr

from .pycompat import install_aliases, filter, raise_from, range, super, zip
install_aliases()
from xarray.core.utils import is_dict_like, hashable

from .api import GDX, gdxcc, type_str, vartype_str
from .pycompat import install_aliases, filter, range, zip
install_aliases()


__version__ = '2'
__version__ = '4-dev'


__all__ = [
'File',
'open_dataset',
]


class File(xr.Dataset):
def open_dataset(filename, lazy=True, implicit=True, skip=set()):
"""Load the file at *filename* into memory.
If *lazy* is ``True`` (default), then the data for GDX Parameters is not
Expand All @@ -47,28 +47,61 @@ class File(xr.Dataset):
otherwise, loading ``foo`` as declared raises :py:class:`MemoryError`.
"""
# For the benefit of xr.Dataset.__getattr__
_api = None
_index = []
_state = {}
_alias = {}
_implicit = False

def __init__(self, filename='', lazy=True, implicit=True, skip=set()):
"""Constructor."""
super(File, self).__init__() # Invoke Dataset constructor

# load the GDX API
ds = xr.Dataset()
ds.gdx._initialize(filename, lazy, implicit, skip)

return ds


# Override xarray.Dataset.__getitem__ to add GDX lazy-loading
def _dataset_getitem(self, key):
"""DERP Access variables or coordinates this dataset as a
:py:class:`~xarray.DataArray`.
Indexing with a list of names will return a new ``Dataset`` object.
"""
if is_dict_like(key):
return self.isel(**key)

# GDX lazy-loading
self.gdx._lazy_load(key)

if hashable(key):
return self._construct_dataarray(key)
else:
return self._copy_listed(np.asarray(key))


xr.Dataset.__getitem__ = _dataset_getitem


@xr.register_dataset_accessor('gdx')
class GDXAccessor(object):
def __init__(self, xarray_obj):
self._obj = xarray_obj
self._initialized = False

def _lazy_load(self, key):
if not self._initialized:
return
keys = [key] if hashable(key) else key

for k in keys:
if k in self._state and isinstance(self._state[k], dict):
debug('Lazy-loading {}'.format(k))
self._load_symbol_data(k)

def _initialize(self, filename, lazy, implicit, skip):
self._api = GDX()
self._api.open_read(str(filename))

# Basic information about the GDX file
v, p = self._api.file_version()
sc, ec = self._api.system_info()
self.attrs['version'] = v.strip()
self.attrs['producer'] = p.strip()
self.attrs['symbol_count'] = sc
self.attrs['element_count'] = ec
self._obj.attrs['version'] = v.strip()
self._obj.attrs['producer'] = p.strip()
self._obj.attrs['symbol_count'] = sc
self._obj.attrs['element_count'] = ec

# Initialize private variables
self._index = [None for _ in range(sc + 1)]
Expand All @@ -87,6 +120,8 @@ def __init__(self, filename='', lazy=True, implicit=True, skip=set()):
if name not in skip:
self._load_symbol_data(name)

self._initialized = True

def _load_symbol(self, index):
"""Load the *index*-th Symbol in the GDX file."""
# Load basic information
Expand Down Expand Up @@ -129,11 +164,12 @@ def _load_symbol(self, index):
elif type_code == gdxcc.GMS_DT_ALIAS:
parent = desc.replace('Aliased with ', '')
self._alias[name] = parent
assert self[parent].attrs['_gdx_type_code'] == gdxcc.GMS_DT_SET
assert (self._obj[parent].attrs['_gdx_type_code'] ==
gdxcc.GMS_DT_SET)
# Duplicate the variable
self._variables[name] = self._variables[parent]
self._obj._variables[name] = self._obj._variables[parent]
self._state[name] = True
super(File, self).set_coords(name, inplace=True)
self._obj.set_coords(name, inplace=True)
return name, type_code

# The Symbol is either a Set, Parameter or Variable
Expand Down Expand Up @@ -233,7 +269,7 @@ def _infer_domain(self, name, domain, elements):
debug('guessing a better domain for {}: {}'.format(name, domain))

# Domain as a list of references to Variables in the File/xr.Dataset
domain_ = [self[d] for d in domain]
domain_ = [self._obj[d] for d in domain]

for i, d in enumerate(domain_): # Iterate over dimensions
e = set(elements[i])
Expand All @@ -246,13 +282,13 @@ def _infer_domain(self, name, domain, elements):
d = '_{}_{}'.format(name, i)
debug(('Constructing implicit set {} for dimension {} of {}\n'
' {} instead of {} elements')
.format(d, name, i, len(e), len(self['*'])))
self.coords[d] = elements[i]
d = self[d]
.format(d, name, i, len(e), len(self._obj['*'])))
self._obj.coords[d] = elements[i]
d = self._obj[d]
else:
# try to find a smaller domain for this dimension
# Iterate over every Set/Coordinate
for s in self.coords.values():
for s in self._obj.coords.values():
if s.ndim == 1 and set(s.values).issuperset(e) and \
len(s) < len(d):
d = s # Found a smaller Set; use this instead
Expand All @@ -272,19 +308,19 @@ def _infer_domain(self, name, domain, elements):

def _root_dim(self, dim):
"""Return the ultimate ancestor of the 1-D Set *dim*."""
parent = self[dim].dims[0]
parent = self._obj[dim].dims[0]
return dim if parent == dim else self._root_dim(parent)

def _empty(self, *dims, **kwargs):
"""Return an empty numpy.ndarray for a GAMS Set or Parameter."""
size = []
dtypes = []
for d in dims:
size.append(len(self[d]))
dtypes.append(self[d].dtype)
dtype = kwargs.pop('dtype', numpy.result_type(*dtypes))
size.append(len(self._obj[d]))
dtypes.append(self._obj[d].dtype)
dtype = kwargs.pop('dtype', np.result_type(*dtypes))
fv = kwargs.pop('fill_value')
return numpy.full(size, fill_value=fv, dtype=dtype)
return np.full(size, fill_value=fv, dtype=dtype)

def _add_symbol(self, name, dim, domain, attrs):
"""Add a xray.DataArray with the data from Symbol *name*."""
Expand All @@ -300,14 +336,13 @@ def _add_symbol(self, name, dim, domain, attrs):
kwargs = {} # Arguments to xr.Dataset.__setitem__()
if dim == 0:
# 0-D Variable or scalar Parameter
super(File, self).__setitem__(name, ([], data.popitem()[1],
gdx_attrs))
self._obj.__setitem__(name, ([], data.popitem()[1], gdx_attrs))
return
elif attrs['type_code'] == gdxcc.GMS_DT_SET: # GAMS Set
if dim == 1:
# One-dimensional Set
self.coords[name] = elements[0]
self.coords[name].attrs = gdx_attrs
self._obj.coords[name] = elements[0]
self._obj.coords[name].attrs = gdx_attrs
else:
# Multi-dimensional Sets are mappings indexed by other Sets;
# elements are either 'on'/True or 'off'/False
Expand All @@ -319,47 +354,48 @@ def _add_symbol(self, name, dim, domain, attrs):
dims = [self._root_dim(d) for d in domain]

# Update coords
self.coords.__setitem__(name, (dims, self._empty(*domain,
**kwargs),
gdx_attrs))
self._obj.coords.__setitem__(name, (dims, self._empty(*domain,
**kwargs), gdx_attrs))

# Store the elements
for k in data.keys():
self[name].loc[k] = k if dim == 1 else True
self._obj[name].loc[k] = k if dim == 1 else True
else: # 1+-dimensional GAMS Parameters
kwargs['dtype'] = float
kwargs['fill_value'] = numpy.nan
kwargs['fill_value'] = np.nan

dims = [self._root_dim(d) for d in domain] # Same as above

# Create an empty xr.DataArray; this ensures that the data
# read in below has the proper form and indices
super(File, self).__setitem__(name, (dims, self._empty(*domain,
**kwargs),
gdx_attrs))
self._obj.__setitem__(name, (dims, self._empty(*domain, **kwargs),
gdx_attrs))

# Fill in extra keys
longest = numpy.argmax(self[name].values.shape)
longest = np.argmax(self._obj[name].values.shape)
iters = []
for i, d in enumerate(dims):
if i == longest:
iters.append(self[d].to_index())
iters.append(self._obj[d].to_index())
else:
iters.append(cycle(self[d].to_index()))
data.update({k: numpy.nan for k in set(zip(*iters)) -
iters.append(cycle(self._obj[d].to_index()))
data.update({k: np.nan for k in set(zip(*iters)) -
set(data.keys())})

# Use pandas and xarray IO methods to convert data, a dict, to a
# xr.DataArray of the correct shape, then extract its values
tmp = pandas.Series(data)
tmp.index.names = dims
tmp = xr.DataArray.from_series(tmp).reindex_like(self[name])
self[name].values = tmp.values
tmp = xr.DataArray.from_series(tmp).reindex_like(self._obj[name])
self._obj[name].values = tmp.values

def dealias(self, name):
"""Identify the GDX Symbol that *name* refers to, and return the
corresponding :py:class:`xarray.DataArray`."""
return self[self._alias[name]] if name in self._alias else self[name]
if name in self._alias:
return self._obj[self._alias[name]]
else:
return self._obj[name]

def extract(self, name):
"""Extract the GAMS Symbol *name* from the dataset.
Expand All @@ -374,7 +410,7 @@ def extract(self, name):
dimensions), which does not make reference to the :class:`File`.
"""
# Copy the Symbol, triggering lazy-loading if needed
result = self[name].copy()
result = self._obj[name].copy()

# Declared dimensions of the Symbol, and their parents
try:
Expand All @@ -397,7 +433,7 @@ def extract(self, name):
# Dimension is indexed by 'p', but declared 'c'. First drop
# the elements which do not appear in the sub-Set c;, then
# rename 'p' to 'c'
drop = set(self[p].values) - set(self[c].values) - set('')
drop = set(self._obj[p].values) - set(self._obj[c].values)
result = result.drop(drop, dim=p).swap_dims({p: c})
# Add the old coord to the set of coords to drop
drop_coords.add(p)
Expand All @@ -412,14 +448,14 @@ def info(self, name):
attrs['type_str'], name, ','.join(attrs['domain']),
attrs['records'], attrs['description'])
else:
return repr(self[name])
return repr(self._obj[name])

def _loaded_and_cached(self, type_code):
"""Return a list of loaded and not-loaded Symbols of *type_code*."""
names = set()
for name, state in self._state.items():
if state is True:
tc = self._variables[name].attrs['_gdx_type_code']
tc = self._obj._variables[name].attrs['_gdx_type_code']
elif isinstance(state, dict):
tc = state['attrs']['type_code']
else: # pragma: no cover
Expand All @@ -437,19 +473,19 @@ def set(self, name, as_dict=False):
:func:`set()` returns the elements without these placeholders.
"""
assert self[name].attrs['_gdx_type_code'] == gdxcc.GMS_DT_SET, \
assert self._obj[name].attrs['_gdx_type_code'] == gdxcc.GMS_DT_SET, \
'Variable {} is not a GAMS Set'.format(name)
if len(self[name].dims) > 1:
return self[name]
if len(self._obj[name].dims) > 1:
return self._obj[name]
elif as_dict:
from collections import OrderedDict
result = OrderedDict()
parent = self[name].attrs['_gdx_domain'][0]
for label in self[parent].values:
result[label] = label in self[name].values
parent = self._obj[name].attrs['_gdx_domain'][0]
for label in self._obj[parent].values:
result._obj[label] = label in self[name].values
return result
else:
return list(self[name].values)
return list(self._obj[name].values)

def sets(self):
"""Return a list of all GDX Sets."""
Expand All @@ -462,16 +498,4 @@ def parameters(self):
def get_symbol_by_index(self, index):
"""Retrieve the GAMS Symbol from the *index*-th position of the
:class:`File`."""
return self[self._index[index]]

def __getitem__(self, key):
"""Set element access."""
try:
return super(File, self).__getitem__(key)
except KeyError as e:
if isinstance(self._state[key], dict):
debug('Lazy-loading {}'.format(key))
self._load_symbol_data(key)
return super(File, self).__getitem__(key)
else:
raise raise_from(KeyError(key), e)
return self._obj[self._index[index]]
3 changes: 1 addition & 2 deletions gdx/pycompat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import sys

from builtins import filter, range, object, super, zip
from builtins import filter, range, object, zip
from future.standard_library import install_aliases
from future.utils import raise_from

PY3 = sys.version_info[0] >= 3

Expand Down
Loading

0 comments on commit ac8b3d2

Please sign in to comment.