Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 74 additions & 4 deletions api/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from contextlib import asynccontextmanager
import os
import uuid

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware

from fastapi import FastAPI
from api.routes import templates, forms
from api.db.init_db import init_db
from api.errors.handlers import register_exception_handlers
from fastapi.middleware.cors import CORSMiddleware
from api.routes import forms, templates

@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -18,7 +23,6 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)

register_exception_handlers(app)

default_origins = "http://127.0.0.1:5173"
allowed_origins = [
Expand All @@ -35,5 +39,71 @@ async def lifespan(app: FastAPI):
allow_headers=["*"],
)

class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag each request/response pair with a freshly generated UUID4.

    The identifier is stored on ``request.state.request_id`` before the
    downstream handler runs and is echoed back to the client in the
    ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next):
        # Assign the ID before calling downstream so it is already present
        # on request.state while the rest of the stack runs.
        rid = str(uuid.uuid4())
        request.state.request_id = rid

        downstream_response = await call_next(request)
        downstream_response.headers["X-Request-ID"] = rid
        return downstream_response


# Install the request-ID middleware on the application.
app.add_middleware(RequestIDMiddleware)


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render an HTTP exception in the API's uniform error envelope.

    Args:
        request: The request being handled (unused; required by the
            exception-handler signature).
        exc: The raised HTTP exception.

    Returns:
        A ``JSONResponse`` with ``exc.status_code`` and a body of the form
        ``{"error": {"type": ..., "message": ..., "details": {}}}``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "type": "HTTPException",
                "message": exc.detail,
                "details": {},
            }
        },
        # Forward any headers attached to the exception (for example
        # WWW-Authenticate on a 401). Without this the custom handler
        # silently drops them.
        headers=getattr(exc, "headers", None),
    )

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert request-validation failures into the uniform error envelope.

    Each entry from ``exc.errors()`` is reduced to a ``field`` (the last
    element of its ``loc``, or ``"unknown"`` when ``loc`` is empty), an
    ``issue`` (its ``msg``) and an ``expected`` (its ``type``). The response
    always uses status 422.
    """

    def _summarise(error):
        # Only the innermost location component is reported as the field.
        location = error.get("loc", [])
        return {
            "field": location[-1] if location else "unknown",
            "issue": error.get("msg", "Invalid value"),
            "expected": error.get("type", ""),
        }

    body = {
        "error": {
            "type": "ValidationError",
            "message": "Invalid request data",
            "details": [_summarise(error) for error in exc.errors()],
        }
    }
    return JSONResponse(status_code=422, content=body)

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler that turns an unhandled exception into a 500 response.

    The exception is logged with its traceback on the server, while the
    client receives only a generic message: echoing ``str(exc)`` back could
    expose internal details such as file paths, SQL fragments or hostnames.

    Args:
        request: The request whose processing raised the exception.
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status 500 in the uniform error envelope.
    """
    # Function-scope import keeps this block self-contained without touching
    # the module's import section.
    import logging

    logging.getLogger(__name__).error(
        "Unhandled exception while processing %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )

    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "type": "InternalServerError",
                "message": "Internal server error",
                "details": {},
            }
        },
    )


# Attach the routers exposed by the templates and forms route modules.
app.include_router(templates.router)
app.include_router(forms.router)
8 changes: 7 additions & 1 deletion api/schemas/forms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator

class FormFill(BaseModel):
    """Schema holding a template ID and the free-text input to process."""

    template_id: int
    input_text: str

    @field_validator("input_text")
    @classmethod
    def validate_input_text(cls, value: str) -> str:
        """Reject empty or whitespace-only ``input_text``.

        Returns:
            The original value, unchanged (it is not stripped).

        Raises:
            ValueError: If the text is empty or contains only whitespace.
        """
        if not value or not value.strip():
            raise ValueError("Input text cannot be empty")
        return value


class FormFillResponse(BaseModel):
id: int
Expand Down
62 changes: 62 additions & 0 deletions api/services/prompt_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
def build_extraction_prompt(input_text: str) -> str:
    """Build the LLM prompt used to extract incident fields from free text.

    The prompt lists the strict extraction rules, the target fields, the
    required JSON shape, one positive example and one negative example, and
    ends with ``input_text`` inserted verbatim.

    The examples are kept consistent with the "do not infer" rule: the
    positive example's input states the year explicitly, and the negative
    example has its own input so the incorrect (guessed) and correct (empty)
    ``date`` values can be told apart.

    Args:
        input_text: The raw report text to extract from. It is interpolated
            as-is; braces inside it are not interpreted.

    Returns:
        The complete prompt string.
    """
    return f"""
You are an AI system that extracts structured information from incident reports.
Your task is to extract ONLY information explicitly present in the input text.

STRICT RULES:
- Do NOT infer or guess missing information
- If a field is not clearly mentioned, return an empty string ""
- Do NOT add any extra fields beyond those specified
- Do NOT modify or reinterpret values

Extract the following fields:
- name
- location
- date (YYYY-MM-DD if possible)
- incident_type
- description

Return ONLY valid JSON. Do not include any extra text, explanation, or formatting outside JSON.
The output MUST be a valid JSON object and parsable by json.loads().
Format:
{{
    "name": "",
    "location": "",
    "date": "",
    "incident_type": "",
    "description": ""
}}

Example:

Input:
Fire reported near Central Park on Jan 5, 2024 involving a vehicle.

Output:
{{
    "name": "",
    "location": "Central Park",
    "date": "2024-01-05",
    "incident_type": "fire",
    "description": "Fire involving a vehicle"
}}

Negative Example (DO NOT DO THIS):

Input:
An incident was reported near Central Park.

Incorrect Output:
(This output is incorrect because it includes inferred/assumed values)
{{
    "location": "Central Park (assumed)",
    "date": "2024-01-05"
}}

Correct Output:
{{
    "location": "Central Park",
    "date": ""
}}

Now extract strictly from the following input (follow all rules above):

{input_text}
"""
87 changes: 64 additions & 23 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,34 @@
import json
import os
import requests
from api.services.prompt_builder import build_extraction_prompt
from requests.exceptions import Timeout, RequestException

def safe_extract_value(response: str):
    """Normalise a raw single-field LLM reply into a clean value or ``None``.

    Processing steps:
      1. Keep only the first non-blank line; anything after it is treated as
         commentary.
      2. Strip quotes that wrap the line.
      3. Remove a leading ``label:`` prefix, but only when the text before the
         first colon is purely alphabetic (spaces/underscores allowed) and the
         colon does not start a ``://`` URL scheme. Values such as ``14:30``
         or ``https://example.com`` are therefore left intact.
      4. Strip wrapping quotes again (covers ``name: "John"``).

    Quotes and apostrophes inside the value (``O'Brien``) are preserved.

    Args:
        response: Raw text returned by the model; may be empty or ``None``.

    Returns:
        The cleaned value, or ``None`` when the reply is empty, is one of the
        "no answer" markers (``-1``, ``none``, ``null``, ``not found``,
        compared case-insensitively), or is longer than 200 characters.
    """
    if not response:
        return None

    # Select the answer line first so a later "Note: ..." line cannot
    # replace the real value.
    first_line = next(
        (line.strip() for line in response.splitlines() if line.strip()),
        "",
    )
    candidate = first_line.strip("\"'").strip()

    # Drop only a genuine "field name:" prefix, never part of the value.
    label, colon, remainder = candidate.partition(":")
    label_letters = label.replace("_", "").replace(" ", "")
    if colon and label_letters.isalpha() and not remainder.startswith("//"):
        candidate = remainder.strip().strip("\"'").strip()

    if not candidate:
        return None
    if candidate.lower() in ("-1", "none", "null", "not found"):
        return None
    # Overly long replies are treated as the model rambling, not a value.
    if len(candidate) > 200:
        return None

    return candidate

class LLM:
def __init__(self, transcript_text=None, target_fields=None, json=None):
Expand Down Expand Up @@ -50,18 +76,31 @@ def main_loop(self):
max_retries = 3

# self.type_check_all()

total_fields = len(self._target_fields)
for i, field in enumerate(self._target_fields.keys(), 1):
prompt = self.build_prompt(field)

# print(prompt)
# ollama_url = "http://localhost:11434/api/generate"
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

base_prompt = build_extraction_prompt(self._transcript_text)

prompt = f"""
{base_prompt}

Focus specifically on extracting the value for this field:
{field}

Return only the extracted value as a plain string. Do not return JSON.
"""

payload = {
"model": "mistral",
"prompt": prompt,
"stream": False, # don't really know why --> look into this later.
"stream": False, # streaming disabled; using single response mode
}

json_data = None
Expand All @@ -70,7 +109,7 @@ def main_loop(self):
try:
response = requests.post(ollama_url, json=payload, timeout=timeout)
response.raise_for_status()
json_data = response.json()
json_data = response.json()
break
except Timeout:
print(f"Ollama request timed out (attempt {attempt+1})")
Expand All @@ -84,14 +123,16 @@ def main_loop(self):
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

# parse response
if json_data is None:
raise RuntimeError("Failed to get response from Ollama after retries.")
else:
# parse response
parsed_response = json_data["response"]
# print(parsed_response)
self.add_response_to_json(field, parsed_response)
print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.")

raw_response = json_data.get("response", "")
parsed_response = safe_extract_value(raw_response)

self.add_response_to_json(field, parsed_response)

print(f"[{i}/{total_fields}] Extracted data for field '{field}' successfully.")

print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
Expand All @@ -101,17 +142,18 @@ def main_loop(self):
return self

def add_response_to_json(self, field, value):
"""
this method adds the following value under the specified field,
or under a new field if the field doesn't exist, to the json dict
"""
value = value.strip().replace('"', "")
value = value.strip().replace('"', "") if value else None
parsed_value = None

if value != "-1":
if value:
parsed_value = value
else:
parsed_value = {
"value": None,
"requires_review": True
}

if ";" in value:
if value and ";" in value:
parsed_value = self.handle_plural_values(value)

if field in self._json.keys():
Expand All @@ -121,28 +163,27 @@ def add_response_to_json(self, field, value):

return


def handle_plural_values(self, plural_value):
    """Split a ';'-separated string into a list of whitespace-trimmed values.

    Takes a string of the form 'value1; value2; ...; valueN' and returns
    [value1, value2, ..., valueN]. Progress is reported via ``print``.

    Raises:
        ValueError: If ``plural_value`` contains no ';' separator.
    """
    separator = ";"
    if separator not in plural_value:
        raise ValueError(
            f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
        )

    print(
        f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
    )

    # Trim surrounding whitespace from every piece in a single pass.
    cleaned = [piece.strip() for piece in plural_value.split(separator)]

    print(f"\t[LOG]: Resulting formatted list of values: {cleaned}")

    return cleaned

def get_data(self):
Expand Down