Skip to content

Commit

Permalink
Add Adapter.set_segments() and use it when refreshig the segment
Browse files Browse the repository at this point in the history
This makes the internal API a bit more consistent
  • Loading branch information
mvantellingen committed May 31, 2017
1 parent decfc88 commit f2aa887
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
18 changes: 11 additions & 7 deletions src/wagtail_personalisation/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def setup(self):
"""Prepare the adapter for segment storage."""
return None

def get_all_segments(self):
def get_segments(self):
"""Return the segments stored in the adapter storage."""
return None

def get_segment(self):
def get_segment_by_id(self):

This comment has been minimized.

Copy link
@kaedroho

kaedroho May 31, 2017

Shouldn't this take a segment_id argument?

This comment has been minimized.

Copy link
@mvantellingen

mvantellingen May 31, 2017

Author Member

yes, sorry, i'll add test for it too

"""Return a single segment stored in the adapter storage."""
return None

Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self, request):
super(SessionSegmentsAdapter, self).__init__(request)
self.request.session.setdefault('segments', [])

def get_all_segments(self):
def get_segments(self):
"""Return the segments stored in the request session.
:returns: The segments in the request session
Expand All @@ -88,7 +88,11 @@ def get_all_segments(self):
"""
return self.request.session['segments']

def get_segment(self, segment_id):
def set_segments(self, segments):
"""Set the currently active segments"""
self.request.session['segments'] = segments

def get_segment_by_id(self, segment_id):
"""Find and return a single segment from the request session.
:param segment_id: The primary key of the segment
Expand Down Expand Up @@ -175,7 +179,8 @@ def refresh(self):
"""
enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED)
persistent_segments = enabled_segments.filter(persistent=True)
session_segments = self.request.session['segments']

session_segments = self.get_segments()
rules = AbstractBaseRule.__subclasses__()

# Create a list to store the new request session segments and
Expand All @@ -196,8 +201,7 @@ def refresh(self):
if not any(seg['id'] == segdict['id'] for seg in new_segments):
new_segments.append(segdict)

self.request.session['segments'] = new_segments

self.set_segments(new_segments)
self.update_visit_count()


Expand Down
2 changes: 1 addition & 1 deletion src/wagtail_personalisation/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def render(self, value, context=None):
"""
request = context['request']
adapter = get_segment_adapter(request)
user_segments = adapter.get_all_segments()
user_segments = adapter.get_segments()

if value['segment']:
for segment in user_segments:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from django import template
from django.template import TemplateSyntaxError

from django.utils.safestring import mark_safe

from wagtail_personalisation.adapters import get_segment_adapter
from wagtail_personalisation.models import Segment
from wagtail_personalisation.utils import parse_tag

Expand Down Expand Up @@ -48,11 +48,11 @@ def render(self, context):
return ""

# Check if user has segment
user_segment = context['request'].segment_adapter.get_segment(segment_id=segment.pk)
adapter = get_segment_adapter(context['request'])
user_segment = adapter.get_segment_by_id(segment_id=segment.pk)
if not user_segment:
return ""
return ''

content = self.nodelist.render(context)
content = mark_safe(content)

return content
2 changes: 1 addition & 1 deletion src/wagtail_personalisation/wagtail_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def serve_variation(page, request, serve_args, serve_kwargs):
user_segments = []
adapter = get_segment_adapter(request)

for segment in adapter.get_all_segments():
for segment in adapter.get_segments():
try:
user_segment = Segment.objects.get(
pk=segment['id'], status=Segment.STATUS_ENABLED)
Expand Down

0 comments on commit f2aa887

Please sign in to comment.