@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render an HTTPException in the API's uniform error envelope.

    Args:
        request: The incoming request (unused; required by the handler
            signature).
        exc: The raised HTTP exception; its ``status_code`` and ``detail``
            are echoed back to the client.

    Returns:
        A ``JSONResponse`` carrying ``exc.status_code`` and a body of the form
        ``{"error": {"type": "HTTPException", "message": ..., "details": {}}}``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "type": "HTTPException",
                "message": exc.detail,
                "details": {},
            }
        },
        # Forward any headers attached to the exception (for example
        # WWW-Authenticate on a 401 or Allow on a 405). Without this they are
        # silently dropped and clients never see the challenge. getattr keeps
        # the handler safe for exception objects that carry no headers.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Turn any otherwise-unhandled exception into a generic 500 response.

    The exception text is deliberately NOT sent to the client: ``str(exc)``
    can contain SQL fragments, file paths, hostnames or other internals.
    The full traceback is logged server-side instead.

    Args:
        request: The incoming request; used for log context only.
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status 500 and a body of the form
        ``{"error": {"type": "InternalServerError", "message": ..., "details": {}}}``.
    """
    # Function-scope import so this handler does not depend on a module-level
    # import being added elsewhere in the file.
    import logging

    # request_id is only present when a middleware stored it on request.state;
    # fall back to None rather than raising a second error while handling one.
    request_id = getattr(request.state, "request_id", None)

    logging.getLogger(__name__).error(
        "Unhandled exception (request_id=%s) on %s %s",
        request_id,
        request.method,
        request.url.path,
        exc_info=exc,
    )

    # NOTE(review): responses produced by an ``Exception`` handler may not pass
    # back through user middleware, so the correlation id is attached here
    # explicitly when one is available -- confirm against the Starlette docs.
    headers = {"X-Request-ID": request_id} if request_id else None

    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "type": "InternalServerError",
                "message": "Internal server error",
                "details": {},
            }
        },
        headers=headers,
    )
# Answers (compared case-insensitively) that mean "the model found nothing".
_NULL_MARKERS = frozenset({"-1", "none", "null", "not found"})


def _strip_label_prefix(text):
    """Drop a leading ``label:`` prefix, e.g. ``Name: John`` -> `` John``.

    Only the text before the FIRST colon is considered, and only when it looks
    like a label (letters, spaces, ``_`` or ``-``, optionally quoted). This
    keeps colons that belong to the value itself: times (``10:30 PM``), URLs
    (``http://...``) and values such as ``Main St: Building 5``.
    """
    label, sep, rest = text.partition(":")
    if not sep or rest.startswith("//"):
        return text
    label = label.strip().strip("\"'").strip()
    if label and all(ch.isalpha() or ch in " _-" for ch in label):
        return rest
    return text


def safe_extract_value(response: str, max_length: int = 200):
    """Reduce a raw LLM answer to a single clean field value.

    Steps, in order:
      1. Take the first meaningful line (blank lines, code fences and lines
         made only of JSON punctuation are skipped), so trailing explanations
         can never replace the actual answer.
      2. Remove a leading ``label:`` prefix and wrapping braces/quotes.
      3. Map empty results and "not found" sentinels to ``None``.
      4. Reject implausibly long answers.

    Args:
        response: Raw text returned by the model; may be ``None`` or empty.
        max_length: Longest value accepted; longer answers are treated as
            unusable and yield ``None``.

    Returns:
        The cleaned value as a string, or ``None`` when no usable value was
        found.
    """
    if not response:
        return None

    candidate = ""
    for raw_line in response.splitlines():
        line = raw_line.strip()
        # Skip blanks, markdown fences and purely structural JSON lines.
        if not line.strip("{}[],") or line.startswith("```"):
            continue
        # Unwrap a one-line JSON object and drop a trailing list/dict comma.
        candidate = line.strip("{}").strip().rstrip(",").strip()
        break

    candidate = _strip_label_prefix(candidate).strip()
    # Remove only WRAPPING quotes; apostrophes inside a value (O'Brien) stay.
    candidate = candidate.strip("\"'").strip()

    if not candidate or candidate.lower() in _NULL_MARKERS:
        return None

    if len(candidate) > max_length:
        return None

    return candidate
"http://localhost:11434").rstrip("/") ollama_url = f"{ollama_host}/api/generate" + base_prompt = build_extraction_prompt(self._transcript_text) + + prompt = f""" + {base_prompt} + + Focus specifically on extracting the value for this field: + {field} + + Return only the extracted value as a plain string. Do not return JSON. + """ + payload = { "model": "mistral", "prompt": prompt, - "stream": False, # don't really know why --> look into this later. + "stream": False, # streaming disabled; using single response mode } json_data = None @@ -70,7 +109,7 @@ def main_loop(self): try: response = requests.post(ollama_url, json=payload, timeout=timeout) response.raise_for_status() - json_data = response.json() + json_data = response.json() break except Timeout: print(f"Ollama request timed out (attempt {attempt+1})") @@ -84,14 +123,16 @@ def main_loop(self): except requests.exceptions.HTTPError as e: raise RuntimeError(f"Ollama returned an error: {e}") + # parse response if json_data is None: raise RuntimeError("Failed to get response from Ollama after retries.") - else: - # parse response - parsed_response = json_data["response"] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) - print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.") + + raw_response = json_data.get("response", "") + parsed_response = safe_extract_value(raw_response) + + self.add_response_to_json(field, parsed_response) + + print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.") print("----------------------------------") print("\t[LOG] Resulting JSON created from the input text:") @@ -101,17 +142,18 @@ def main_loop(self): return self def add_response_to_json(self, field, value): - """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict - """ - value = value.strip().replace('"', "") + value = value.strip().replace('"', "") if value else None 
def handle_plural_values(self, plural_value):
    """Split a ``;``-separated string into a list of clean values.

    Takes a string of the form ``'value1; value2; ...; valueN'`` and returns
    ``[value1, value2, ..., valueN]``. Each value is stripped of surrounding
    whitespace, and empty segments (from a trailing or doubled ``;``) are
    dropped so the result never contains ``""`` entries.

    Args:
        plural_value: The raw string; must contain at least one ``;``.

    Returns:
        The list of non-empty, whitespace-trimmed values.

    Raises:
        ValueError: If ``plural_value`` contains no ``;`` separator.
    """
    if ";" not in plural_value:
        raise ValueError(
            f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
        )

    trimmed = (part.strip() for part in plural_value.split(";"))
    values = [part for part in trimmed if part]

    print(f"\t[LOG]: Resulting formatted list of values: {values}")

    return values