Skip to content

Commit

Permalink
Add support for foreignkey relations
Browse files Browse the repository at this point in the history
  • Loading branch information
bhch committed Oct 12, 2021
1 parent cf4484d commit 788363b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 9 deletions.
7 changes: 6 additions & 1 deletion tornadmin/backends/forms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from wtforms import Form, fields
from wtforms.fields.html5 import DateField
from tornadmin.backends.widgets import Select


class BaseModelForm(Form):
Expand Down Expand Up @@ -27,4 +28,8 @@ class NullDateField(DateField):
def process_data(self, value):
super().process_data(value)
if self.data == '':
self.data = None
self.data = None


class SelectField(fields.SelectField):
widget = Select()
17 changes: 16 additions & 1 deletion tornadmin/backends/tortoise/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, model, **kwargs):
self.app_slug = slugify(self.app)

def get_list_headers(self):
# :TODO: cache this function
headers = []
for header in self.list_headers:
if isinstance(header, tuple) or isinstance(header, list):
Expand All @@ -30,7 +31,21 @@ async def get_list(self, request_handler, page_num):

paginator = Paginator(queryset, per_page=self.items_per_page, count=count)
page = paginator.get_page(page_num)
return (await page.objects.order_by('-id'), page)
page_queryset = page.objects

# fetch related fields which are also shown on list page table
related_fields = []
for header in self.get_list_headers():
field = header[0]
if field in self.model._meta.fk_fields:
related_fields.append(field)

if related_fields:
page_list = await page_queryset.prefetch_related(*related_fields)
else:
page_list = await page_queryset

return (page_list, page)

async def get_object(self, request_handler, id):
return await self.model.get(id=id)
Expand Down
49 changes: 44 additions & 5 deletions tornadmin/backends/tortoise/forms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
from wtforms import validators
from wtforms import fields
from tornadmin.backends.forms import BaseModelForm, NullDateField, NullDateTimeField
from tornadmin.backends.forms import (
BaseModelForm, NullDateField, NullDateTimeField, SelectField,
)
from tornadmin.utils.text import get_display_name


class ModelForm(BaseModelForm):
pass
async def set_field_choices(self, request_handler):
"""Sets choices foreignkey and manytomany fields"""
for field_name in self.Meta.model._meta.fk_fields:
form_field = getattr(self, '%s_id' % field_name)
choices = await self._get_field_choices(request_handler, field_name)
form_field.choices = choices

async def _get_field_choices(self, request_handler, field_name):
"""Returns choices for the given field name."""

if hasattr(self, 'get_%s_choices' % field_name):
return await getattr(self, 'get_%s_choices' % field_name)(request_handler)

model_field = self.Meta.model._meta.fields_map[field_name]
objects = await model_field.related_model.all()
return [(obj.id, str(obj)) for obj in objects]


TORTOISE_TO_WTF_MAP = {
Expand All @@ -18,10 +35,14 @@ class ModelForm(BaseModelForm):
'IntField': fields.IntegerField,
'SmallIntField': fields.IntegerField,
'TextField': fields.TextAreaField,
'ForeignKeyField': SelectField,
}


def tortoise_to_wtf(tortoise_field):
def tortoise_to_wtf(tortoise_field, is_fk=False):
if is_fk:
return TORTOISE_TO_WTF_MAP['ForeignKeyField']

return TORTOISE_TO_WTF_MAP.get(
type(tortoise_field).__name__,
fields.StringField
Expand All @@ -30,6 +51,10 @@ def tortoise_to_wtf(tortoise_field):

def modelform_factory(admin, model):
fields = {}
fk_id_fields = []

for field_name in model._meta.fk_fields:
fk_id_fields.append('%s_id' % field_name)

for field_name, model_field in model._meta.fields_map.items():

Expand All @@ -43,6 +68,12 @@ def modelform_factory(admin, model):
if getattr(model_field, 'auto_now_add', False):
continue

if field_name in fk_id_fields:
continue

if field_name in model._meta.backward_fk_fields:
continue

name = get_display_name(field_name)

validators_list = []
Expand All @@ -58,16 +89,24 @@ def modelform_factory(admin, model):
if field_name in admin.readonly_fields:
attrs['readonly'] = True

form_field = tortoise_to_wtf(model_field)
is_fk = field_name in model._meta.fk_fields

form_field = tortoise_to_wtf(model_field, is_fk=is_fk)

if is_fk:
# For foreignkeys, we'll render a select input
# with the "_id" appended to the name
field_name = '%s_id' % field_name

fields[field_name] = form_field(
name,
validators_list,
render_kw=attrs
render_kw=attrs,
)

fields['_fields'] = list(fields.keys())

form = type('%sForm' % model.__name__, (ModelForm,), fields)
form.Meta.model = model

return form
23 changes: 23 additions & 0 deletions tornadmin/backends/widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from wtforms import widgets
from wtforms.widgets import html_params
from markupsafe import escape, Markup


class Select(widgets.Select):
"""Overrides WTForm's Select widget to insert placeholder options"""

def __call__(self, field, **kwargs):
kwargs.setdefault('id', field.id)
if self.multiple:
kwargs['multiple'] = True
if 'required' not in kwargs and 'required' in getattr(field, 'flags', []):
kwargs['required'] = True
html = ['<select %s>' % html_params(name=field.name, **kwargs)]

html.append(self.render_option('', 'Select...', field.data in [None, ''], disabled=True))
html.append(self.render_option('', '----------', False, disabled=True))

for val, label, selected in field.iter_choices():
html.append(self.render_option(val, label, selected))
html.append('</select>')
return Markup(''.join(html))
10 changes: 8 additions & 2 deletions tornadmin/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,15 @@ async def get(self, app_slug, model_slug):


class CreateHandler(BaseHandler):
def get(self, app_slug, model_slug):
async def get(self, app_slug, model_slug):
admin = self.admin_site.get_registered(app_slug, model_slug)
form_class = admin.get_form(self)
form = form_class()
await form.set_field_choices(self)
namespace = {
'obj': None,
'admin': admin,
'form': form_class(),
'form': form,
}
self.render('create.html', **namespace)

Expand All @@ -209,6 +211,8 @@ async def post(self, app_slug, model_slug,):

form = form_class(data=data)

await form.set_field_choices(self)

if form.validate():
obj = await admin.save_model(self, form)
self.redirect('admin:detail', app_slug, model_slug, obj.id)
Expand All @@ -231,6 +235,7 @@ async def get(self, app_slug, model_slug, id):
data = await admin.get_form_data(obj)

form = form_class(data=data)
await form.set_field_choices(self)

namespace = {
'obj': obj,
Expand All @@ -254,6 +259,7 @@ async def post(self, app_slug, model_slug, id):
data[field_name] = self.get_body_argument(field_name, None)

form = form_class(data=data)
await form.set_field_choices(self)

if form.validate():
await admin.save_model(self, form, obj)
Expand Down

0 comments on commit 788363b

Please sign in to comment.