-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
apps/ai_reports: connect with external API
- add a celery task which sends a comment to the XAI server and saves the response as AiReport - connect comment post_save signal with celery task - rename category to label to be in line with the XAI response - change AiReport fields to JSONField for now - add tests - **BREAKING CHANGE** Reset migrations for the ai_reports app (see changelog)
- Loading branch information
Showing
17 changed files
with
398 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -670,3 +670,4 @@ | |
}, | ||
}, | ||
} | ||
XAI_API_URL = "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Generated by Django 3.2.19 on 2024-06-03 14:43 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
dependencies = [ | ||
("ai_reports", "0002_aireport_show_in_discussion"), | ||
] | ||
|
||
operations = [ | ||
migrations.AddField( | ||
model_name="aireport", | ||
name="confidence_json", | ||
field=models.JSONField(default=[]), | ||
preserve_default=False, | ||
), | ||
migrations.AddField( | ||
model_name="aireport", | ||
name="explanation_json", | ||
field=models.JSONField(default={}), | ||
preserve_default=False, | ||
), | ||
migrations.AddField( | ||
model_name="aireport", | ||
name="label", | ||
field=models.JSONField(default=[]), | ||
preserve_default=False, | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Generated by Django 3.2.19 on 2024-06-03 14:18 | ||
import json | ||
|
||
from django.db import migrations | ||
|
||
|
||
def migrate_to_json(apps, schema_editor): | ||
AiReport = apps.get_model("ai_reports", "AiReport") | ||
|
||
for report in AiReport.objects.all(): | ||
report.label = [report.category] | ||
report.confidence_json = [report.confidence] | ||
report.explanation_json = {report.category: report.explanation} | ||
report.save() | ||
|
||
|
||
class Migration(migrations.Migration): | ||
dependencies = [ | ||
("ai_reports", "0003_auto_20240603_1643"), | ||
] | ||
|
||
operations = [ | ||
migrations.RunPython( | ||
code=migrate_to_json, reverse_code=migrations.RunPython.noop | ||
), | ||
migrations.RemoveField( | ||
model_name="aireport", | ||
name="category", | ||
), | ||
migrations.RemoveField( | ||
model_name="aireport", | ||
name="explanation", | ||
), | ||
migrations.RemoveField( | ||
model_name="aireport", | ||
name="confidence", | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Generated by Django 3.2.19 on 2024-06-03 14:47 | ||
|
||
from django.db import migrations | ||
|
||
|
||
class Migration(migrations.Migration): | ||
dependencies = [ | ||
("ai_reports", "0004_migrate_fields_to_json"), | ||
] | ||
|
||
operations = [ | ||
migrations.RenameField( | ||
model_name="aireport", | ||
old_name="confidence_json", | ||
new_name="confidence", | ||
), | ||
migrations.RenameField( | ||
model_name="aireport", | ||
old_name="explanation_json", | ||
new_name="explanation", | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,24 @@ | ||
import json | ||
|
||
from rest_framework import serializers | ||
|
||
from apps.ai_reports.models import AiReport | ||
|
||
|
||
class AiReportSerializer(serializers.ModelSerializer): | ||
explanation = serializers.SerializerMethodField() | ||
|
||
class Meta: | ||
model = AiReport | ||
fields = ( | ||
"category", | ||
"label", | ||
"confidence", | ||
"explanation", | ||
"is_pending", | ||
"comment", | ||
"show_in_discussion", | ||
) | ||
|
||
# FIXME: remove once frontend knows what to do with this | ||
def get_explanation(self, ai_report: AiReport) -> str: | ||
return json.dumps(ai_report.explanation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import httpx | ||
from django.conf import settings | ||
from django.db.models import signals | ||
from django.dispatch import receiver | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps.ai_reports.tasks import get_classification_for_comment | ||
|
||
client = httpx.Client() | ||
|
||
|
||
@receiver(signals.post_save, sender=Comment) | ||
def get_ai_classification(sender, instance, created, update_fields, **kwargs): | ||
if getattr(settings, "XAI_API_URL"): | ||
comment_text_changed = getattr(instance, "_former_comment") != getattr( | ||
instance, "comment" | ||
) | ||
if created or comment_text_changed: | ||
# FIXME: use delay_on_commit() once updated to celery 5.x | ||
get_classification_for_comment.delay(instance.pk) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import ast | ||
|
||
import backoff | ||
import httpx | ||
from celery import shared_task | ||
from django.conf import settings | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps import logger | ||
from apps.ai_reports.models import AiReport | ||
|
||
client = httpx.Client() | ||
|
||
|
||
@shared_task | ||
def get_classification_for_comment(comment_pk: int) -> None: | ||
try: | ||
comment = Comment.objects.get(pk=comment_pk) | ||
response = call_ai_api(comment=comment.comment) | ||
if response.status_code == 200: | ||
extract_and_save_ai_classifications(comment=comment, report=response.json()) | ||
else: | ||
logger.error("Error: XAI server returned %s", response.status_code) | ||
except httpx.HTTPError as e: | ||
logger.error("Error connecting to %s: %s", settings.XAI_API_URL, str(e)) | ||
|
||
|
||
def skip_retry(e: Exception) -> bool: | ||
if isinstance(e, httpx.HTTPStatusError): | ||
return 400 <= e.response.status_code < 500 | ||
return False | ||
|
||
|
||
@backoff.on_exception( | ||
backoff.expo, httpx.HTTPError, max_tries=4, factor=2, giveup=skip_retry | ||
) | ||
def call_ai_api(comment: str) -> httpx.Response: | ||
response = client.post( | ||
settings.XAI_API_URL, | ||
json={"comment": comment}, | ||
headers={"Accept": "application/json", "Content-Type": "application/json"}, | ||
timeout=25.0, | ||
) | ||
response.raise_for_status() | ||
return response | ||
|
||
|
||
def extract_and_save_ai_classifications(comment: Comment, report: dict) -> None: | ||
# FIXME: the data returned from the api is not actually valid json, so we need | ||
# to use ast to explicitly convert it. This should be fixed on their side. | ||
confidence = ast.literal_eval((report["confidence"])) | ||
label = ast.literal_eval(report["label"]) | ||
explanation = ast.literal_eval(report["explanation"]) | ||
|
||
ai_report = AiReport( | ||
comment=comment, confidence=confidence, label=label, explanation=explanation | ||
) | ||
ai_report.save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
### Added | ||
|
||
- add celery task which sends a comment to the xai and stores the response as | ||
AiReport | ||
- add comment post_save signal which connects comment creation / editing with | ||
the xai celery task | ||
- added tests | ||
- add tests | ||
- add pytest-mock package to dev requirements | ||
- added backup and httpx to fork requirements | ||
|
||
### Changed | ||
|
||
- rename category to label to be in line with the XAI response | ||
- change AiReport fields to JSONField for now | ||
- add manual migrations for the ai_reports app to migrate fields which can't | ||
automatically be converted to JSONFields |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
# requirements needed in this fork, but not a+ | ||
backoff==2.2.1 | ||
httpx==0.27.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
from pytest_factoryboy import register | ||
|
||
from tests.ideas import factories as idea_factories | ||
|
||
from .factories import AiReportFactory | ||
|
||
register(AiReportFactory) | ||
register(idea_factories.IdeaFactory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.