Skip to content

Commit

Permalink
when converting inserts, wrap fields in proxy to handle setting values
Browse files Browse the repository at this point in the history
  • Loading branch information
aburgel committed Mar 19, 2012
1 parent fe7d046 commit 657331c
Showing 1 changed file with 79 additions and 31 deletions.
110 changes: 79 additions & 31 deletions dbindexer/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,66 @@

OR = 'OR'

def get_target_value(field, start_model, field_chain, pk):
fields = field_chain.split('__')
foreign_key = start_model._meta.get_field(fields[0])

if not foreign_key.rel:
# field isn't a related one, so return the value itself
return pk

target_model = foreign_key.rel.to

foreignkey = target_model.objects.all().get(pk=pk)
for value in fields[1:-1]:
foreignkey = getattr(foreignkey, value)

if isinstance(foreignkey._meta.get_field(fields[-1]), models.ForeignKey):
return getattr(foreignkey, '%s_id' % fields[-1])
else:
return getattr(foreignkey, fields[-1])

class FieldProxy(object):
def __init__(self, target, lookup, column, follow_target=False):
object.__setattr__(self, '_target', target)
object.__setattr__(self, '_lookup', lookup)
object.__setattr__(self, '_column', column)
object.__setattr__(self, '_follow_target', follow_target)

def _proxy_get_db_prep_save(self, value, connection):
target = object.__getattribute__(self, '_target')
lookup = object.__getattribute__(self, '_lookup')
follow_target = object.__getattribute__(self, '_follow_target')

if follow_target:
real_value = target.get_db_prep_save(value, connection)

real_value = get_target_value(target, lookup.model, lookup.field_name, real_value)
else:
real_value = target.get_db_prep_save(value, connection)

return lookup.convert_value(real_value)

def __getattr__(self, name):
if name == 'column':
return object.__getattribute__(self, '_column')
if name == 'get_db_prep_save':
return self._proxy_get_db_prep_save

target = object.__getattribute__(self, '_target')
return getattr(target, name)

def __setattr__(self, name, value):
target = object.__getattribute__(self, '_target')
setattr(target, name, value)

def __repr__(self):
target = object.__getattribute__(self, '_target')
lookup = object.__getattribute__(self, '_lookup')
column = object.__getattribute__(self, '_column')
return "<FieldProxy: %s %s %s>" % (target, lookup, column)


# TODO: optimize code
class BaseResolver(object):
def __init__(self):
Expand Down Expand Up @@ -60,9 +120,13 @@ def _convert_insert_query(self, query, lookup):
if position is None:
return

value = self.get_value(lookup.model, lookup.field_name, query)
value = lookup.convert_value(value)
query.values[position] = (self.get_index(lookup), value)
field = self.get_field(lookup.model, lookup.field_name, query)
query.fields[position] = self.create_proxy_field(field, lookup, query.fields[position].column)

def create_proxy_field(self, field, lookup, column):
if isinstance(field, FieldProxy):
return field
return FieldProxy(field, lookup, column)

def convert_filters(self, query):
self._convert_filters(query, query.where)
Expand Down Expand Up @@ -113,12 +177,12 @@ def get_field_to_index(self, model, field_name):
except:
return None

def get_value(self, model, field_name, query):
def get_field(self, model, field_name, query):
field_to_index = self.get_field_to_index(model, field_name)
for query_field, value in query.values[:]:
for query_field in query.fields[:]:
if field_to_index == query_field:
return value
raise FieldDoesNotExist('Cannot find field in query.')
return query_field
raise FieldDoesNotExist('Cannot find field %s in query.' % field_name)

def add_column_to_name(self, model, field_name):
column_name = model._meta.get_field(field_name).column
Expand All @@ -128,7 +192,7 @@ def get_index(self, lookup):
return self.index_map[lookup]

def get_query_position(self, query, lookup):
for index, (field, query_value) in enumerate(query.values[:]):
for index, field in enumerate(query.fields[:]):
if field is self.get_index(lookup):
return index
return None
Expand Down Expand Up @@ -222,13 +286,10 @@ def get_field_to_index(self, model, field_name):
return super(ConstantFieldJOINResolver, self).get_field_to_index(model,
field_name)

def get_value(self, model, field_name, query):
value = super(ConstantFieldJOINResolver, self).get_value(model,
def get_field(self, model, field_name, query):
return super(ConstantFieldJOINResolver, self).get_field(model,
field_name.split('__')[0],
query)
if value is not None:
value = self.get_target_value(model, field_name, value)
return value

def get_field_chain(self, query, constraint):
if constraint.field is None:
Expand All @@ -237,31 +298,18 @@ def get_field_chain(self, query, constraint):
column_index = self.get_column_index(query, constraint)
return self.column_to_name.get(column_index)

def create_proxy_field(self, field, lookup, column):
if isinstance(field, FieldProxy):
return field
return FieldProxy(field, lookup, column, follow_target=True)

def get_model_chain(self, model, field_chain):
model_chain = [model, ]
for value in field_chain.split('__')[:-1]:
model = model._meta.get_field(value).rel.to
model_chain.append(model)
return model_chain

def get_target_value(self, start_model, field_chain, pk):
fields = field_chain.split('__')
foreign_key = start_model._meta.get_field(fields[0])

if not foreign_key.rel:
# field isn't a related one, so return the value itself
return pk

target_model = foreign_key.rel.to
foreignkey = target_model.objects.all().get(pk=pk)
for value in fields[1:-1]:
foreignkey = getattr(foreignkey, value)

if isinstance(foreignkey._meta.get_field(fields[-1]), models.ForeignKey):
return getattr(foreignkey, '%s_id' % fields[-1])
else:
return getattr(foreignkey, fields[-1])

def add_column_to_name(self, model, field_name):
model_chain = self.get_model_chain(model, field_name)
column_chain = ''
Expand Down

0 comments on commit 657331c

Please sign in to comment.