Skip to content

Commit

Permalink
Added tests and development for concateanation
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielelanaro committed Aug 17, 2015
1 parent 671d7dc commit 2e46b9e
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 18 deletions.
98 changes: 85 additions & 13 deletions chemlab/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
import numpy as np
from pandas.hashtable import Int64HashTable
import collections
from itertools import islice

# BASE CLASS
class EntityProperty(object):
Expand Down Expand Up @@ -54,17 +56,20 @@ class InstanceProperty(object):

class InstanceArray(InstanceProperty):

def empty(self, size):
def empty(self, size, inplace=True):
# If it is its own dimension, we need shape 1
if size == 0:
self.value = None
value = None
else:
if size > 0:
shape = (size,) + self.shape if self.shape else (size,)
self.value = np.zeros(shape, dtype=self.dtype)
else:
print 'Setting', self.name, 'to', None
self.value = None
shape = (size,) + self.shape if self.shape else (size,)
value = np.zeros(shape, dtype=self.dtype)

if inplace:
self.value = value
else:
obj = self.copy()
obj.value = value
return obj

def copy(self):
raise NotImplementedError()
Expand Down Expand Up @@ -124,14 +129,14 @@ def __init__(self, name, shape=None, dtype=None, dim=None, alias=None):

def copy(self):
obj = type(self)(self.name, self.shape, self.dtype, self.dim, self.alias)
obj.value = self.value.copy()
obj.value = self.value
return obj



class InstanceRelation(InstanceArray):

def __init__(self, name, map=None, dim=None, shape=None, alias=None):
def __init__(self, name, map=None, index=None, dim=None, shape=None, alias=None):
if not isinstance(dim, str):
raise ValueError('dim parameter is required and should be a string.')

Expand All @@ -141,28 +146,36 @@ def __init__(self, name, map=None, dim=None, shape=None, alias=None):
if not isinstance(name, str):
raise ValueError('name parameter should be a string')

if not isinstance(index, (list, np.ndarray)):
raise ValueError('index parameter should be an array-like object')

if shape is not None and not isinstance(shape, tuple):
raise ValueError('shape parameter should be a tuple')


self.name = name
self.map = map
self.index = index
self.dtype = 'int'
self.dim = dim
self.alias = alias
self.shape = shape
self.value = None

def copy(self):
obj = type(self)(self.name, self.map, self.dim, self.shape, self.alias)
obj = type(self)(self.name, self.map, self.index, self.dim, self.shape, self.alias)
obj.value = self.value
return obj

def remap(self, from_map, to_map, inplace=True):
if not isinstance(from_map, (list, np.ndarray)) or not isinstance(to_map, list):
raise ValueError('from_map and to_map should be either lists or arrays')

if self.value is None:
return #nothing to remap
# Nothing to remap
if inplace:
return
else:
return self.copy()

# Remap columns
hashtable = Int64HashTable()
Expand All @@ -187,6 +200,7 @@ def __init__(self, name, dtype=None, shape=None, alias=None):
self.alias = alias
self.shape = shape
self.value = None
self.empty()

def empty(self):
if self.shape is None:
Expand Down Expand Up @@ -274,6 +288,60 @@ def _from_entities(self, entities, newdim):
entity_dimensions=[e.dimensions for e in entities],
final_dimensions=self.dimensions)



def concatenate_relations(relations):
tpl = relations[0]

rel = tpl.copy()
rel.index = range(sum(len(r.index) for r in relations))

arrays = []
iterindex = iter(rel.index)
for r in relations:
# For a molecule e.index['atom'] = [0, 1, 2]
from_map = r.index
# we remap this to [3, 4, 5]
to_map = consume(iterindex, len(r.index))
if r.size == 0:
continue
arrays.append(r.remap(from_map, to_map, inplace=False).value)

if len(arrays) == 0:
rel.value = None
else:
rel.value = np.concatenate(arrays, axis=0)

return rel

def concatenate_attributes(attributes):
'''Concatenate InstanceAttribute to return a bigger one.'''
# We get a template
tpl = attributes[0]
attr = InstanceAttribute(tpl.name, tpl.shape,
tpl.dtype, tpl.dim, tpl.alias)
# Special case, not a single array has size bigger than 0
if all(a.size == 0 for a in attributes):
return attr
else:
attr.value = np.concatenate([a.value for a in attributes if a.size > 0], axis=0)
return attr

def concatenate_fields(fields, dim):
'Create an INstanceAttribute from a list of InstnaceFields'
if len(fields) == 0:
raise ValueError('fields cannot be an empty list')

if len(set((f.name, f.shape, f.dtype) for f in fields)) != 1:
raise ValueError('fields should have homogeneous name, shape and dtype')
tpl = fields[0]
attr = InstanceAttribute(tpl.name, shape=tpl.shape, dtype=tpl.dtype,
dim=dim, alias=tpl.alias)

attr.value = np.array([f.value for f in fields], dtype=tpl.dtype)
return attr


#TODO: move the utilities
def merge_dicts(*dict_args):
'''
Expand All @@ -284,3 +352,7 @@ def merge_dicts(*dict_args):
for dictionary in dict_args:
result.update(dictionary)
return result

def consume(iterator, n):
"Advance the iterator n-steps ahead. If n is none, consume entirely."
return list(islice(iterator, 0, n))
68 changes: 63 additions & 5 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from chemlab.core.base import (Attribute, InstanceAttribute,
Field, InstanceField,
Relation, InstanceRelation,
ChemicalEntity)
ChemicalEntity,
concatenate_fields,
concatenate_relations,
concatenate_attributes)
from nose.tools import eq_, ok_, assert_raises
import numpy as np
from .testtools import assert_npequal
Expand Down Expand Up @@ -47,8 +50,9 @@ def test_instance_attribute():


def test_instance_relation():
irel = InstanceRelation('bonds', map='atoms', dim='bonds', shape=(2,))
irel = InstanceRelation('bonds', map='atoms', index=range(3), dim='bonds', shape=(2,))
eq_(irel.size, 0)
assert_npequal(irel.index, [0, 1, 2])

# initialize as empty
irel.empty(2)
Expand All @@ -68,13 +72,67 @@ def test_instance_relation():

def test_instance_field():
ifield = InstanceField('mass')
ifield.empty()
eq_(ifield.value, 0.0)

ifield = InstanceField('r_array', shape=(3,))
ifield.empty()
assert_npequal(ifield.value, [0.0, 0.0, 0.0])

ifield = InstanceField('symbol', dtype='str')
ifield.empty()
eq_(ifield.value, '')

def test_concatenate_fields():
# Uninitialized fields
attr = concatenate_fields([InstanceField('mass'),
InstanceField('mass'),
InstanceField('mass')], 'atom')
eq_(attr.dim, 'atom')
eq_(attr.size, 3)
assert_npequal(attr.value, [0.0, 0.0, 0.0])

# Non uniform fields
assert_raises(ValueError, concatenate_fields, [InstanceField('mass'),
InstanceField('impostor'),
InstanceField('mass')],
'atom')

# Shape parameter
r_array = InstanceField('r_array', shape=(3,), dtype='f')
r_array.value = [0, 1, 2]

attr = concatenate_fields([r_array, r_array, r_array], 'atom')
assert_npequal(attr.value, [[0, 1, 2],[0, 1, 2],[0, 1, 2]])

box = InstanceField('box', shape=(3, 3), dtype='f')
box.value = np.eye(3)
attr = concatenate_fields([box, box, box], 'atom')
assert_npequal(attr.value, [np.eye(3),np.eye(3),np.eye(3)])

def test_concatenate_attributes():
a1 = InstanceAttribute('type_array', dim='atom', dtype='str')
newattr = concatenate_attributes([a1, a1, a1])
eq_(newattr.size, 0)

a2 = a1.empty(10, inplace=False)
newattr = concatenate_attributes([a1, a2, a2])
eq_(newattr.size, 20)

# Shape parameter
r_array = InstanceAttribute('r_array', shape=(3,), dtype='f', dim='atom')
r_array.value = [[0, 1, 2]]
newattr = concatenate_attributes([r_array, r_array, r_array])
eq_(newattr.size, 3)
assert_npequal(newattr.value, [[0, 1, 2], [0, 1, 2], [0, 1, 2]])


def test_concatenate_relations():
a1 = InstanceRelation('bonds', map='atom', index=range(3), shape=(2,), dim='bond')
newattr = concatenate_relations([a1, a1, a1])
eq_(newattr.size, 0)

a2 = a1.empty(2, inplace=False)
a3 = a2.copy()

newattr = concatenate_relations([a1, a2, a3])
eq_(newattr.size, 4)
assert_npequal(newattr.value, [[3, 3], [3, 3],
[6, 6], [6, 6]])

0 comments on commit 2e46b9e

Please sign in to comment.