Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion beets/autotag/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def current_metadata(items):
'mb_albumid', 'label', 'catalognum', 'country', 'media',
'albumdisambig']
for key in fields:
values = [getattr(item, key) for item in items if item]
values = [item.get(key) for item in items if item]
likelies[key], freq = plurality(values)
consensus[key] = (freq == len(values))

Expand Down
200 changes: 109 additions & 91 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@

import beets
from beets.util.functemplate import Template
from .query import MatchQuery, build_sql
from .types import BASE_TYPE
from .query import MatchQuery, TrueQuery, build_sql
from .types import STRING


COLUMN_DEFAULT = object()


class FormattedMapping(collections.Mapping):
Expand All @@ -43,34 +46,31 @@ def __init__(self, model, for_path=False):
self.model = model
self.model_keys = model.keys(True)

def __getitem__(self, key):
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)

def __iter__(self):
return iter(self.model_keys)

def __len__(self):
return len(self.model_keys)

def get(self, key, default=None):
if default is None:
default = self.model._type(key).format(None)
return super(FormattedMapping, self).get(key, default)
def __getitem__(self, key):
value = self.model[key]
if value is None:
raise KeyError(key)

def _get_formatted(self, model, key):
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
value = self.model._type(key).format(value)
return self._path_replace(value)

def get(self, key, default=COLUMN_DEFAULT):
value = self.model.get(key, default)
value = self.model._type(key).format(value)
return self._path_replace(value)

def _path_replace(self, value):
if self.for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)

return value


Expand Down Expand Up @@ -115,11 +115,6 @@ class Model(object):
keys are field names and the values are `Type` objects.
"""

_bytes_keys = ()
"""Keys whose values should be stored as raw bytes blobs rather than
strings.
"""

_search_fields = ()
"""The fields that should be queried by default by unqualified query
terms.
Expand Down Expand Up @@ -152,29 +147,27 @@ def __init__(self, db=None, **values):
"""
self._db = db
self._dirty = set()
self._values_fixed = {}
self._values_flex = {}
self._values_fixed = {}
for key in self._fields:
self._values_fixed[key] = None

# Initial contents.
self.update(values)
self.clear_dirty()

@classmethod
def _awaken(cls, db=None, fixed_values=None, flex_values=None):
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
"""Create an object with values drawn from the database.

This is a performance optimization: the checks involved with
ordinary construction are bypassed.
"""
obj = cls(db)
if fixed_values:
for key, value in fixed_values.items():
obj._values_fixed[key] = cls._fields[key].normalize(value)
if flex_values:
for key, value in flex_values.items():
if key in cls._types:
value = cls._types[key].normalize(value)
obj._values_flex[key] = value
for key, value in fixed_values.iteritems():
obj._values_fixed[key] = cls._from_sql(key, value)
for key, value in flex_values.iteritems():
obj._values_flex[key] = cls._from_sql(key, value)
return obj

def __repr__(self):
Expand All @@ -199,33 +192,52 @@ def _check_db(self, need_id=True):
if need_id and not self.id:
raise ValueError('{0} has no id'.format(type(self).__name__))

# Essential field accessors.

@classmethod
def _type(self, key):
"""Get the type of a field, a `Type` instance.

If the field has no explicit type, it is given the base `Type`,
which does no conversion.
"""
return self._fields.get(key) or self._types.get(key) or BASE_TYPE
return self._fields.get(key) or self._types.get(key) or STRING

@classmethod
def _from_sql(cls, key, value):
if value is not None:
return cls._type(key).from_sql(value)

@classmethod
def _to_sql(cls, key, value):
if value is not None:
return cls._type(key).to_sql(value)

@classmethod
def parse(cls, key, string):
if not isinstance(string, unicode):
raise TypeError(u'{0!r} must be a unicode object')
return cls._type(key).parse(string)

# Essential field accessors.

def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
"""Get the value for a field.

Raise a KeyError if flex field is not available.
"""
getters = self._getters()
if key in getters: # Computed.
if key in getters:
return getters[key](self)
elif key in self._fields: # Fixed.
return self._values_fixed.get(key)
elif key in self._values_flex: # Flexible.
elif key in self._values_fixed:
return self._values_fixed[key]
elif key in self._values_flex:
return self._values_flex[key]
else:
raise KeyError(key)

def __setitem__(self, key, value):
"""Assign the value for a field.

Uses `Type.normalize()` to convert values other than `None`.
"""
# Choose where to place the value.
if key in self._fields:
Expand All @@ -234,7 +246,8 @@ def __setitem__(self, key, value):
source = self._values_flex

# If the field has a type, filter the value.
value = self._type(key).normalize(value)
if value is not None:
value = self._type(key).normalize(value)

# Assign value and possibly mark as dirty.
old_value = source.get(key)
Expand All @@ -243,22 +256,25 @@ def __setitem__(self, key, value):
self._dirty.add(key)

def __delitem__(self, key):
"""Remove a flexible attribute from the model.
"""Remove a flexible attribute from the model and set fixed
values to `None`.
"""
if key in self._values_flex: # Flexible.
del self._values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._fields: # Fixed.
self._values_fixed[key] = None
self._dirty.add(key) # Mark for dropping on store.
elif key in self._getters(): # Computed.
raise KeyError('computed field {0} cannot be deleted'.format(key))
elif key in self._fields: # Fixed.
raise KeyError('fixed field {0} cannot be deleted'.format(key))
else:
raise KeyError('no such field {0}'.format(key))

def keys(self, computed=False):
"""Get a list of available field names for this object. The
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
"""Get a list of available field names for this object.

The `computed` parameter controls whether computed
(plugin-provided) fields are included in the key list.
"""
base_keys = list(self._fields) + self._values_flex.keys()
if computed:
Expand All @@ -276,20 +292,32 @@ def update(self, values):

def items(self):
"""Iterate over (key, value) pairs that this object contains.

Computed fields are not included.
"""
for key in self:
yield key, self[key]

def get(self, key, default=None):
def get(self, key, default=COLUMN_DEFAULT):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
if default == COLUMN_DEFAULT:
default = self._type(key).default
try:
value = self[key]
if value is None:
return default
else:
return value
except KeyError:
return default

def set(self, key, string):
"""Parse a string as a value for the given key.
"""
self[key] = self.parse(key, string)

def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
Expand Down Expand Up @@ -326,39 +354,45 @@ def __delattr__(self, key):

# Database interaction (CRUD methods).

def store(self):
def store(self, all=False):
"""Save the object's metadata into the library database.

By default, updates only dirty fields. If `all` is `True`,
updates all fields.
"""
self._check_db()

# Build assignments for query.
assignments = ''
assignments = []
subvars = []
for key in self._fields:
if key != 'id' and key in self._dirty:
self._dirty.remove(key)
assignments += key + '=?,'
value = self[key]
# Wrap path strings in buffers so they get stored
# "in the raw".
if key in self._bytes_keys and isinstance(value, str):
value = buffer(value)
for key, value in self._values_fixed.iteritems():
if key == 'id':
continue
if key in self._dirty or all:
if key in self._dirty:
self._dirty.remove(key)

# None values are stored as NULL
if value is not None:
value = self._to_sql(key, value)
subvars.append(value)
assignments = assignments[:-1] # Knock off last ,
assignments.append(key + '=?')

with self._db.transaction() as tx:
# Main table update.
if assignments:
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
self._table, assignments
self._table, ','.join(assignments)
)
subvars.append(self.id)
tx.mutate(query, subvars)

# Modified/added flexible attributes.
for key, value in self._values_flex.items():
if key in self._dirty:
self._dirty.remove(key)
if key in self._dirty or all:
if key in self._dirty:
self._dirty.remove(key)
value = self._to_sql(key, value)
tx.mutate(
'INSERT INTO {0} '
'(entity_id, key, value) '
Expand All @@ -380,11 +414,9 @@ def load(self):
"""Refresh the object's metadata from the library database.
"""
self._check_db()
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, "object {0} not in DB".format(self.id)
self._values_fixed = {}
self._values_flex = {}
self.update(dict(stored_obj))
stored = self._db._get(type(self), self.id)
self._values_fixed = stored._values_fixed
self._values_flex = stored._values_flex
self.clear_dirty()

def remove(self):
Expand Down Expand Up @@ -413,18 +445,15 @@ def add(self, db=None):
self._db = db
self._check_db(False)

if 'added' in self._fields and not self['added']:
self.added = time.time()

with self._db.transaction() as tx:
new_id = tx.mutate(
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
)
self.id = new_id
self.added = time.time()

# Mark every non-null field as dirty and store.
for key in self:
if self[key] is not None:
self._dirty.add(key)
self.store()
self.store(all=True)

# Formatting and templating.

Expand All @@ -447,17 +476,6 @@ def evaluate_template(self, template, for_path=False):
return template.substitute(self.formatted(for_path),
self._template_funcs())

# Parsing.

@classmethod
def _parse(cls, key, string):
"""Parse a string as a value for the given key.
"""
if not isinstance(string, basestring):
raise TypeError("_parse() argument must be a string")

return cls._type(key).parse(string)


# Database controller and supporting interfaces.

Expand Down Expand Up @@ -745,7 +763,7 @@ def _make_attribute_table(self, flex_table):

# Querying.

def _fetch(self, model_cls, query, sort_order=None):
def _fetch(self, model_cls, query=TrueQuery(), sort_order=None):
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). If provided,
Expand Down
5 changes: 4 additions & 1 deletion beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ def __init__(self, field, pattern, fast=True):
self.rangemax = self._convert(parts[1])

def match(self, item):
value = getattr(item, self.field)
value = item.get(self.field, default=None)
if value is None:
return False

if isinstance(value, basestring):
value = self._convert(value)

Expand Down
Loading