Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for postgres ArrayField #472

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion import_export/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from django.db.models.related import RelatedObject
ForeignObjectRel = RelatedObject
else:
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields.related import ForeignObjectRel
RelatedObject = None

Expand Down Expand Up @@ -706,7 +707,7 @@ def widget_from_django_field(cls, f, default=widgets.Widget):
Django type.
"""
result = default
internal_type = f.get_internal_type()
internal_type = f.get_internal_type() if callable(getattr(f, "get_internal_type", None)) else ""
if internal_type in ('ManyToManyField', ):
result = functools.partial(widgets.ManyToManyWidget,
model=f.rel.to)
Expand All @@ -729,6 +730,9 @@ def widget_from_django_field(cls, f, default=widgets.Widget):
result = widgets.IntegerWidget
elif internal_type in ('BooleanField', 'NullBooleanField'):
result = widgets.BooleanWidget
elif VERSION >= (1, 8):
if type(f) == ArrayField:
return widgets.SimpleArrayWidget
return result

@classmethod
Expand Down
16 changes: 15 additions & 1 deletion import_export/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from decimal import Decimal
from datetime import datetime
from django.utils import datetime_safe, timezone
from django.utils import datetime_safe, timezone, six
from django.utils.encoding import smart_text
from django.conf import settings

Expand Down Expand Up @@ -227,6 +227,20 @@ def render(self, value):
return value.strftime(self.formats[0])


class SimpleArrayWidget(Widget):
def __init__(self, separator=None):
if separator is None:
separator = ','
self.separator = separator
super(SimpleArrayWidget, self).__init__()

def clean(self, value):
return value.split(self.separator) if value else []

def render(self, value):
return self.separator.join(six.text_type(v) for v in value)


class ForeignKeyWidget(Widget):
"""
Widget for a ``ForeignKey`` field which looks up a related model using
Expand Down
36 changes: 36 additions & 0 deletions tests/core/migrations/0004_bookwithchapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.5 on 2016-06-09 10:26
from __future__ import unicode_literals

import django.contrib.postgres.fields
from django.db import migrations, models


class PostgresOnlyCreateModel(migrations.CreateModel):
def database_forwards(self, app_label, schema_editor, from_state, to_state):
if schema_editor.connection.vendor.startswith("postgres"):
super(PostgresOnlyCreateModel, self).database_forwards(app_label, schema_editor, from_state, to_state)

def database_backwards(self, app_label, schema_editor, from_state, to_state):
if schema_editor.connection.vendor.startswith("postgres"):
super(PostgresOnlyCreateModel, self).database_backwards(app_label, schema_editor, from_state, to_state)


class Migration(migrations.Migration):

dependencies = [
('core', '0003_withfloatfield'),
]

operations = [
PostgresOnlyCreateModel(
name='BookWithChapters',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100, verbose_name='Book name')),
('chapters',
django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), default=list,
size=None)),
],
),
]
28 changes: 28 additions & 0 deletions tests/core/tests/resources_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from decimal import Decimal
from unittest import skip, skipUnless

from django import VERSION
from django.conf import settings
from django.contrib.auth.models import User
from django.db import IntegrityError
Expand Down Expand Up @@ -715,6 +716,33 @@ def test_create_object_after_importing_dataset_with_id(self):
self.fail('IntegrityError was raised.')


if VERSION >= (1, 8) and 'postgresql' in settings.DATABASES['default']['ENGINE']:
from django.contrib.postgres.fields import ArrayField
from django.db import models

class BookWithChapters(models.Model):
name = models.CharField('Book name', max_length=100)
chapters = ArrayField(models.CharField(max_length=100), default=list)

class ArrayFieldTest(TestCase):
fixtures = []

def setUp(self):
pass

def test_arrayfield(self):
dataset_headers = ["id", "name", "chapters"]
chapters = ["Introduction", "Middle Chapter", "Ending"]
dataset_row = ["1", "Book With Chapters", ",".join(chapters)]
dataset = tablib.Dataset(headers=dataset_headers)
dataset.append(dataset_row)
book_with_chapters_resource = resources.modelresource_factory(model=BookWithChapters)()
result = book_with_chapters_resource.import_data(dataset, dry_run=False)
self.assertFalse(result.has_errors())
book_with_chapters = list(BookWithChapters.objects.all())[0]
self.assertListEqual(book_with_chapters.chapters, chapters)


class ManyRelatedManagerDiffTest(TestCase):
fixtures = ["category"]

Expand Down