From 836f62c9dc7d4eb5fc13c810f39fc29c8cf2f0e1 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Tue, 16 Jul 2024 12:28:02 +0000 Subject: [PATCH] Whole page summarization context --- prompting/conversation.py | 7 +++++-- prompting/tools/datasets/wiki.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/prompting/conversation.py b/prompting/conversation.py index 3e13a9a4a..d909cac7f 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -1,6 +1,6 @@ import random from transformers import Pipeline -from prompting.tasks import Task, TASKS, TranslationPipeline, TranslationTask +from prompting.tasks import Task, TASKS, TranslationPipeline, TranslationTask, SummarizationTask from prompting.tools import Selector, DATASETS from prompting.task_registry import TASK_REGISTRY @@ -42,7 +42,10 @@ def create_task( if dataset is None: raise ValueError(f"Dataset {dataset_name} not found") else: - dataset = dataset() + if task_name == SummarizationTask.name: + dataset = dataset(selector='all') + else: + dataset = dataset() if task_name == TranslationTask.name: return task( diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index a125630aa..31ccda122 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -208,13 +208,18 @@ def get( ) if not sections: return None - - key = header, section_title = selector(list(sections.keys())) - content = "\n".join(sections[key]) - section_length = len(content.split()) + if selector == 'all': + content = "\n".join(["\n".join(section) for section in sections.values()]) + section_length = len(content.split()) + topic = "All Sections" + else: + key = header, section_title = selector(list(sections.keys())) + content = "\n".join(sections[key]) + section_length = len(content.split()) + topic = header or section_title context = { "title": name, # title of wiki article - "topic": header or section_title, # title of wiki section + "topic": topic, # title of wiki section "subtopic": section_title, "content": content, "internal_links": list(filter(lambda x: x not in exclude, page.sections)),