diff --git a/api/main.py b/api/main.py
index b6b3fdc..6a319ca 100644
--- a/api/main.py
+++ b/api/main.py
@@ -1,12 +1,13 @@
 from contextlib import asynccontextmanager
 import os
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
 from api.routes import templates, forms
 from api.db.init_db import init_db
 from api.errors.handlers import register_exception_handlers
-from fastapi.middleware.cors import CORSMiddleware
-from api.routes import forms, templates
+
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -34,6 +35,7 @@ async def lifespan(app: FastAPI):
     allow_methods=["*"],
     allow_headers=["*"],
 )
+register_exception_handlers(app)
 
 app.include_router(templates.router)
 app.include_router(forms.router)
diff --git a/api/schemas/forms.py b/api/schemas/forms.py
index 3cce650..bf6957e 100644
--- a/api/schemas/forms.py
+++ b/api/schemas/forms.py
@@ -1,9 +1,15 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 
 class FormFill(BaseModel):
     template_id: int
     input_text: str
 
+    @field_validator("input_text")
+    def validate_input_text(cls, value):
+        if not value or not value.strip():
+            raise ValueError("Input text cannot be empty")
+        return value
+
 
 class FormFillResponse(BaseModel):
     id: int
diff --git a/api/services/prompt_builder.py b/api/services/prompt_builder.py
new file mode 100644
index 0000000..d4b20bd
--- /dev/null
+++ b/api/services/prompt_builder.py
@@ -0,0 +1,85 @@
+def build_extraction_prompt(input_text: str) -> str:
+    return f"""
+You are an AI system that extracts structured information from incident reports.
+Your task is to extract ONLY information explicitly present in the input text. 
+ +STRICT RULES: +- Do NOT infer or guess missing information +- If a field is not clearly mentioned, return an empty string "" +- Do NOT add any extra fields beyond those specified +- Do NOT modify or reinterpret values + +Extract the following fields: +- name +- location +- date (YYYY-MM-DD if possible) +- incident_type +- description + +Return ONLY valid JSON. Do not include any extra text, explanation, or formatting outside JSON. +The output MUST be a valid JSON object and parsable by json.loads(). +Format: +{{ + "name": "", + "location": "", + "date": "", + "incident_type": "", + "description": "" +}} + +Example: + +Input: +Fire reported near Central Park on Jan 5 involving a vehicle. + +Output: +{{ + "name": "", + "location": "Central Park", + "date": "2024-01-05", + "incident_type": "fire", + "description": "Fire involving a vehicle" +}} + +Negative Example (DO NOT DO THIS): + +Incorrect Output: +(This output is incorrect because it includes inferred/assumed values) +{{ + "location": "Central Park (assumed)", + "date": "2024-01-05" +}} + +Correct Output: +{{ + "location": "Central Park", + "date": "" +}} + +Now extract strictly from the following input (follow all rules above): + +{input_text} +""" + +def build_field_prompt(transcript_text: str, current_field: str) -> str: + return f""" +SYSTEM PROMPT: +You are an AI assistant designed to help fill out JSON fields with information extracted from transcribed text. + +You will receive: +- a transcript +- a target field name + +Return ONLY the value for that field. 
+
+Rules:
+- If multiple values exist → separate with ";"
+- If no value found → return "-1"
+- Do NOT add explanation
+
+DATA:
+Target JSON field: {current_field}
+
+TEXT:
+{transcript_text}
+"""
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
index e7524c5..a25789d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -24,6 +24,8 @@ services:
       ollama:
         condition: service_healthy
     command: /bin/sh -c "python3 -m api.db.init_db && python3 -m uvicorn api.main:app --host 0.0.0.0 --port 8000"
     volumes:
       - .:/app
+    # Fix #224 — expose port so API is reachable at http://localhost:8000
     ports:
+      - "8000:8000"
diff --git a/src/backend.py b/src/backend.py
new file mode 100644
index 0000000..84d937c
--- /dev/null
+++ b/src/backend.py
@@ -0,0 +1,171 @@
+import json
+import os
+import requests
+from json_manager import JsonManager
+from input_manager import InputManager
+from pdfrw import PdfReader, PdfWriter
+
+
+
+class textToJSON():
+    def __init__(self, transcript_text, target_fields, json=None):
+        self.__transcript_text = transcript_text  # str
+        self.__target_fields = target_fields  # List, contains the template field.
+        self.__json = json if json is not None else {}  # dictionary; avoid shared mutable default
+        self.type_check_all()
+        self.main_loop()
+
+
+    def type_check_all(self):
+        if type(self.__transcript_text) != str:
+            raise TypeError(f"ERROR in textToJSON() ->\
+                Transcript must be text. Input:\n\ttranscript_text: {self.__transcript_text}")
+        elif type(self.__target_fields) != list:
+            raise TypeError(f"ERROR in textToJSON() ->\
+                Target fields must be a list. Input:\n\ttarget_fields: {self.__target_fields}")
+
+
+    def build_prompt(self, current_field):
+        """
+        This method is in charge of the prompt engineering. It creates a specific prompt for each target field. 
+ @params: current_field -> represents the current element of the json that is being prompted. + """ + prompt = f""" + SYSTEM PROMPT: + You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. + You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return + only a single string containing the identified value for the JSON field. + If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". + If you don't identify the value in the provided text, return "-1". + --- + DATA: + Target JSON field to find in text: {current_field} + + TEXT: {self.__transcript_text} + """ + + return prompt + + def main_loop(self): #FUTURE -> Refactor this to its own class + for field in self.__target_fields: + prompt = self.build_prompt(field) + # print(prompt) + # ollama_url = "http://localhost:11434/api/generate" + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + ollama_url = f"{ollama_host}/api/generate" + + payload = { + "model": "mistral", + "prompt": prompt, + "stream": False # don't really know why --> look into this later. 
+ } + + response = requests.post(ollama_url, json=payload) + + # parse response + json_data = response.json() + parsed_response = json_data['response'] + # print(parsed_response) + self.add_response_to_json(field, parsed_response) + + print("----------------------------------") + print("\t[LOG] Resulting JSON created from the input text:") + print(json.dumps(self.__json, indent=2)) + print("--------- extracted data ---------") + + return None + + def add_response_to_json(self, field, value): + """ + this method adds the following value under the specified field, + or under a new field if the field doesn't exist, to the json dict + """ + value = value.strip().replace('"', '') + parsed_value = None + plural = False + + if value != "-1": + parsed_value = value + + if ";" in value: + parsed_value = self.handle_plural_values(value) + plural = True + + + if field in self.__json.keys(): + self.__json[field].append(parsed_value) + else: + self.__json[field] = parsed_value + + return + + def handle_plural_values(self, plural_value): + """ + This method handles plural values. + Takes in strings of the form 'value1; value2; value3; ...; valueN' + returns a list with the respective values -> [value1, value2, value3, ..., valueN] + """ + if ";" not in plural_value: + raise ValueError(f"Value is not plural, doesn't have ; separator, Value: {plural_value}") + + print(f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]...") + values = plural_value.split(";") + + # Remove trailing leading whitespace + for i in range(len(values)): + values[i] = values[i].lstrip() + + print(f"\t[LOG]: Resulting formatted list of values: {values}") + + return values + + + def get_data(self): + return self.__json + +class Fill(): + def __init__(self): + pass + + def fill_form(user_input: str, definitions: list, pdf_form: str): + """ + Fill a PDF form with values from user_input using testToJSON. + Fields are filled in the visual order (top-to-bottom, left-to-right). 
+ """ + + output_pdf = pdf_form[:-4] + "_filled.pdf" + + # Generate dictionary of answers from your original function + t2j = textToJSON(user_input, definitions) + textbox_answers = t2j.get_data() # This is a dictionary + + answers_list = list(textbox_answers.values()) + + # Read PDF + pdf = PdfReader(pdf_form) + + # Loop through pages + for page in pdf.pages: + if page.Annots: + sorted_annots = sorted( + page.Annots, + key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) + ) + + i = 0 + for annot in sorted_annots: + if annot.Subtype == '/Widget' and annot.T: + field_name = annot.T[1:-1] + + if i < len(answers_list): + annot.V = f'{answers_list[i]}' + annot.AP = None + i += 1 + else: + # Stop if we run out of answers + break + + PdfWriter().write(output_pdf, pdf) + + # Your main.py expects this function to return the path + return output_pdf diff --git a/src/file_manipulator.py b/src/file_manipulator.py index e499c89..498c31d 100644 --- a/src/file_manipulator.py +++ b/src/file_manipulator.py @@ -37,7 +37,10 @@ def fill_form(self, user_input: str, fields: list, pdf_form_path: str): print("[3] Starting extraction and PDF filling process...") try: - self.llm._target_fields = fields + if isinstance(fields, dict): + self.llm._target_fields = list(fields.keys()) + else: + self.llm._target_fields = fields self.llm._transcript_text = user_input output_name = self.filler.fill_form(pdf_form=pdf_form_path, llm=self.llm) diff --git a/src/filler.py b/src/filler.py index 7f738c2..21c4b81 100644 --- a/src/filler.py +++ b/src/filler.py @@ -1,52 +1,58 @@ from pdfrw import PdfReader, PdfWriter from src.llm import LLM from datetime import datetime - +from src.validation import validate_extraction class Filler: def __init__(self): pass - def fill_form(self, pdf_form: str, llm: LLM): - """ - Fill a PDF form with values from user_input using LLM. - Fields are filled in the visual order (top-to-bottom, left-to-right). 
- """ - output_pdf = ( - pdf_form[:-4] - + "_" - + datetime.now().strftime("%Y%m%d_%H%M%S") - + "_filled.pdf" - ) - - # Generate dictionary of answers from your original function - t2j = llm.main_loop() - textbox_answers = t2j.get_data() # This is a dictionary - - answers_list = list(textbox_answers.values()) - - # Read PDF - pdf = PdfReader(pdf_form) - - # Loop through pages - i = 0 - for page in pdf.pages: - if page.Annots: - sorted_annots = sorted( - page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) - ) - - for annot in sorted_annots: - if annot.Subtype == "/Widget" and annot.T: - if i < len(answers_list): - annot.V = f"{answers_list[i]}" - annot.AP = None - i += 1 - else: - # Stop if we run out of answers - break - - PdfWriter().write(output_pdf, pdf) - - # Your main.py expects this function to return the path - return output_pdf +def fill_form(self, pdf_form: str, llm: LLM): + """ + Fill a PDF form with values from user_input using LLM. + Fields are filled in the visual order (top-to-bottom, left-to-right). 
+ """ + output_pdf = ( + pdf_form[:-4] + + "_" + + datetime.now().strftime("%Y%m%d_%H%M%S") + + "_filled.pdf" + ) + + # Generate dictionary of answers + t2j = llm.main_loop() + raw_data = t2j.get_data() + + # Validation step (separate concern ✅) + validated_data, errors = validate_extraction(raw_data) + + if errors: + print("[Validation Warning]", errors) + + textbox_answers = validated_data + answers_list = list(textbox_answers.values()) + + # Read PDF + pdf = PdfReader(pdf_form) + + # Loop through pages + i = 0 + for page in pdf.pages: + if page.Annots: + sorted_annots = sorted( + page.Annots, + key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) + ) + + for annot in sorted_annots: + if annot.Subtype == "/Widget" and annot.T: + if i < len(answers_list): + annot.V = f"{answers_list[i]}" + annot.AP = None + i += 1 + else: + break + + PdfWriter().write(output_pdf, pdf) + + return output_pdf diff --git a/src/llm.py b/src/llm.py index 3621187..5bec8ca 100644 --- a/src/llm.py +++ b/src/llm.py @@ -1,8 +1,35 @@ import json import os import requests +from api.services.prompt_builder import build_field_prompt from requests.exceptions import Timeout, RequestException +from src.llm_client import call_llm +def safe_extract_value(response: str): + if not response: + return None + + response = response.strip() + + + response = response.replace('"', '').replace("'", "") + + + if ":" in response: + response = response.split(":")[-1].strip() + + + response = response.split("\n")[0] + + + if response.lower() in ["-1", "none", "null", "not found"]: + return None + + + if len(response) > 200: + return None + + return response class LLM: def __init__(self, transcript_text=None, target_fields=None, json=None): @@ -24,26 +51,7 @@ def type_check_all(self): Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}" ) - def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. 
- @params: current_field -> represents the current element of the json that is being prompted. - """ - prompt = f""" - SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". - --- - DATA: - Target JSON field to find in text: {current_field} - - TEXT: {self._transcript_text} - """ - - return prompt + def main_loop(self): timeout = 30 @@ -51,26 +59,26 @@ def main_loop(self): # self.type_check_all() total_fields = len(self._target_fields) - for i, field in enumerate(self._target_fields.keys(), 1): - prompt = self.build_prompt(field) + for i, field in enumerate(self._target_fields, 1): + prompt = build_field_prompt(self._transcript_text, field) # print(prompt) # ollama_url = "http://localhost:11434/api/generate" ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") ollama_url = f"{ollama_host}/api/generate" + + payload = { "model": "mistral", "prompt": prompt, - "stream": False, # don't really know why --> look into this later. 
+ "stream": False, # streaming disabled; using single response mode } json_data = None try: for attempt in range(max_retries): try: - response = requests.post(ollama_url, json=payload, timeout=timeout) - response.raise_for_status() - json_data = response.json() + json_data = call_llm(prompt, timeout=timeout, retries=max_retries) break except Timeout: print(f"Ollama request timed out (attempt {attempt+1})") @@ -84,6 +92,7 @@ def main_loop(self): except requests.exceptions.HTTPError as e: raise RuntimeError(f"Ollama returned an error: {e}") + if json_data is None: raise RuntimeError("Failed to get response from Ollama after retries.") else: @@ -101,17 +110,18 @@ def main_loop(self): return self def add_response_to_json(self, field, value): - """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict - """ - value = value.strip().replace('"', "") + value = value.strip().replace('"', "") if value else None parsed_value = None - if value != "-1": + if value: parsed_value = value + else: + parsed_value = { + "value": None, + "requires_review": True + } - if ";" in value: + if value and ";" in value: parsed_value = self.handle_plural_values(value) if field in self._json.keys(): @@ -121,23 +131,18 @@ def add_response_to_json(self, field, value): return + def handle_plural_values(self, plural_value): """ - This method handles plural values. - Takes in strings of the form 'value1; value2; value3; ...; valueN' - returns a list with the respective values -> [value1, value2, value3, ..., valueN] + This method handles plural values. """ if ";" not in plural_value: raise ValueError( f"Value is not plural, doesn't have ; separator, Value: {plural_value}" ) - print( - f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..." 
-        )
         values = plural_value.split(";")
 
-        # Remove trailing leading whitespace
         for i in range(len(values)):
             values[i] = values[i].lstrip()
diff --git a/src/llm_client.py b/src/llm_client.py
new file mode 100644
index 0000000..66a0d2d
--- /dev/null
+++ b/src/llm_client.py
@@ -0,0 +1,33 @@
+import os
+
+import requests
+from requests.exceptions import Timeout, RequestException
+
+
+def call_llm(prompt: str, timeout: int = 30, retries: int = 2):
+    payload = {
+        "model": "mistral",
+        "prompt": prompt,
+        "stream": False
+    }
+
+    # Respect OLLAMA_HOST (as src/llm.py already does) so the client works when
+    # Ollama runs in another container; a hard-coded localhost URL would not.
+    ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
+    url = f"{ollama_host}/api/generate"
+
+    for attempt in range(retries + 1):
+        try:
+            response = requests.post(url, json=payload, timeout=timeout)
+            response.raise_for_status()
+            return response.json()
+
+        except Timeout:
+            if attempt == retries:
+                raise RuntimeError("LLM request timed out")
+
+        except RequestException as e:
+            if attempt == retries:
+                raise RuntimeError(f"LLM request failed: {e}")
+
+    return None
diff --git a/src/validation.py b/src/validation.py
new file mode 100644
index 0000000..894ac33
--- /dev/null
+++ b/src/validation.py
@@ -0,0 +1,29 @@
+from pydantic import BaseModel, ValidationError
+from typing import Optional
+
+
+class ExtractionSchema(BaseModel):
+    name: Optional[str] = None
+    location: Optional[str] = None
+    date: Optional[str] = None
+    incident_type: Optional[str] = None
+    description: Optional[str] = None
+
+    class Config:
+        extra = "allow"
+
+
+def validate_extraction(data: dict):
+    try:
+        validated = ExtractionSchema(**data)
+        return validated.model_dump(), None
+    except ValidationError as e:
+        formatted_errors = []
+
+        for err in e.errors():
+            formatted_errors.append({
+                "field": err.get("loc", ["unknown"])[0],
+                "issue": err.get("msg", "Invalid value")
+            })
+
+        return data, formatted_errors