In [4]:
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv

load_dotenv()

graph = GraphDatabase.driver(
	os.environ['NEO4J_URI'],
	auth=(os.environ['NEO4J_USER'], os.environ['NEO4J_PWD'])
)
database = os.environ['NEO4J_DATABASE']

# Peek at 20 random (bird, fact) pairs. The result keys are the raw expressions
# 'b.name' / 'd.text'; the extraction cell below reads record['d.text'].
# ORDER BY / LIMIT are attached to RETURN (standard Cypher) rather than used as a
# standalone clause, which older Neo4j versions do not accept.
with graph.session(database=database) as session:
	records = session.execute_read(
		lambda tx: tx.run(
			'MATCH (b:Bird)-[:HAS_FACT]->(d:Fact) '
			'RETURN b.name, d.text '
			'ORDER BY rand() '
			'LIMIT 20'
		).data()
	)
records

[{'b.name': 'Yellow-tailed Black-Cockatoo',
  'd.text': 'Breeds April–July in northern Australia; January–May in northern New South Wales; December–February in southern New South Wales; and October–March in South Australia, Victoria and Tasmania. Nest is bed of woodchips in large tree-hollow. Two eggs  ; incubation 28–29 days, by female only; chick has long, dense yellow down; usually only one nestling survives; fed by both parents, and leaves the hollow after three months, remaining with parents until at least next breeding season.'},
 {'b.name': "Chapman's Swift",
  'd.text': 'Panama to Colombia, Venezuela, Guianas, ne Brazil; Trinidad'},
 {'b.name': 'Spotted Buttonquail',
  'd.text': "Editor's Note: Additional distribution information for this taxon can be found in the 'Subspecies' article above. In the future we will develop a range-wide distribution article."},
 {'b.name': 'White-throated Robin',
  'd.text': 'Invertebrates and fruit. Of 58 invertebrates from stomachs of breeding b

In [None]:
from typing import (
	List
)
import ollama
from pydantic import BaseModel, Field, ValidationError

class Location(BaseModel):
	# One location mentioned in a fact text.
	# The Field descriptions below end up in the JSON schema that is passed to the LLM
	# (format=LocationEntries.model_json_schema()), so they act as extraction instructions.
	# Kept as comments, not a class docstring: pydantic would copy a docstring into the schema.
	name: str = Field(
		description='the name of the location as described in the text. Please expand any abbreviations, e.g. "SW Africa" should become "Southwest Africa".'
	)
	continents: List[str] = Field(
		description='all the continents that the location belongs to.'
	)
	countries: List[str] = Field(
		description='all the countries that the location could exist within.'
	)
	regions: List[str] = Field(
		description='all the regions that could potentially be a part of the location. These can be more casual delineations, and do not need to be formal.'
	)
	geographic_landmarks: List[str] = Field(
		description='any physical landmarks such as islands, peninsulas, mountain ranges, etc. that are a part of this location. Unlike continents/countries/regions, do not include these unless they are mentioned explicitly in the text.'
	)
	

class LocationEntries(BaseModel):
	# Top-level object the extraction model must return: {"mentioned_locations": [...]}.
	# (Comment instead of docstring: pydantic would put a docstring into the JSON schema.)
	mentioned_locations: List[Location]


class ValidationEntry(BaseModel):
	# Verdict returned by the validation prompt: {"satisfactory": true|false}.
	# (Comment instead of docstring: pydantic would put a docstring into the JSON schema.)
	satisfactory: bool


def parse_locations(
	location_text: str,
	model: str = 'llama3.1:8b',
	max_attempts: int = 10,
	temperature_step: float = 0.1,
) -> LocationEntries:
	"""Extract the locations mentioned in ``location_text`` with a local Ollama model.

	The reply is constrained to ``LocationEntries``' JSON schema. If it still fails
	validation, the request is retried with a higher temperature so the next sample
	differs: attempt ``i`` (0-based) uses ``temperature = i * temperature_step``.

	Args:
		location_text: Free text (e.g. a bird distribution fact) to extract locations from.
		model: Ollama model tag to query.
		max_attempts: Number of requests to make before giving up.
		temperature_step: Temperature increase applied per failed attempt.

	Returns:
		The validated ``LocationEntries``. The prompt asks the model for an empty
		``mentioned_locations`` list when the text contains no locations.

	Raises:
		RuntimeError: if no attempt produced a reply that validates against the schema
			(the last ``ValidationError`` is chained as the cause). Previously this case
			silently returned ``None`` and callers failed later with ``AttributeError``.
	"""
	# The prompt does not depend on the attempt, so build it once.
	prompt = (
		'You are a location extraction engine. '
		'Your goal is to return a JSON list of all the locations you are able to extract from a given piece of text. '
		'A "location" can be a range (e.g. SW Africa), a region, a physical landmark, or some other distinct location entity. '
		'This is the format you should use in your response:\n'
		'{\n'
		'\t"mentioned_locations": [\n'
		'\t\t{\n'
		'\t\t\t"name": str - the name of the location as described in the text. Please expand any abbreviations, e.g. "SW Africa" should become "Southwest Africa".\n'
		'\t\t\t"continents": list[str] - all the continents that the location belongs to.\n'
		'\t\t\t"countries": list[str] - all the countries that the location could exist within.\n'
		'\t\t\t"regions": list[str] - all the regions that could potentially be a part of the location. These can be more casual delineations, and do not need to be formal.\n'
		'\t\t\t"geographic_landmarks": list[str] - any physical landmarks such as islands, peninsulas, mountain ranges, etc. that are a part of this location. Unlike continents/countries/regions, do not include these unless they are mentioned explicitly in the text.\n'
		'\t\t}\n'
		'\t\t...more locations...\n'
		'\t]\n'
		'}\n'
		'If there are NO locations in the text, return "{ "mentioned_locations": [] }". '
		'ONLY return locations that are EXPLICITLY MENTIONED, no entries like "global" or "worldwide".\n'
		'This is the text to extract locations from:\n\n'
		f'"{location_text}"\n\n'
		'ONLY return the JSON, DO NOT EXPLAIN.'
	)

	last_error = None
	# Integer attempt counter: accumulating 0.1 in a float drifts (ten steps give
	# 0.9999999999999999 < 1), which silently added an extra attempt.
	for attempt in range(max_attempts):
		response = ollama.chat(
			model=model,
			messages=[{'role': 'user', 'content': prompt}],
			# Constrain decoding to the pydantic schema (Ollama structured outputs).
			format=LocationEntries.model_json_schema(),
			options={'temperature': attempt * temperature_step},
		)['message']['content']

		try:
			return LocationEntries.model_validate_json(response)
		except ValidationError as err:
			last_error = err

	raise RuntimeError(
		f'parse_locations: no valid response from {model!r} after {max_attempts} attempts'
	) from last_error

#'- Refers to an actual verifiable location.\n'
#'- The attributes in the JSON are related to each other. E.g. the "continent" being "Africa" and the "country" being the "United States" would fail this criteria.\n\n'

def validate_location(
	location: Location,
	model: str = 'llama3.1:8b',
	max_attempts: int = 10,
	temperature_step: float = 0.1,
) -> bool:
	"""Ask the LLM whether an extracted location is a specific geographic place.

	The location is serialized as JSON and checked against two criteria: it is a
	specific place (not "global"/"worldwide"), and it is not a man-made building.
	The reply is constrained to ``ValidationEntry``'s schema. When it fails
	validation, the request is retried with a higher temperature: attempt ``i``
	(0-based) uses ``i * temperature_step``. Previously the temperature was pinned
	at 0, so every retry was the same request.

	Args:
		location: The extracted location to judge.
		model: Ollama model tag to query.
		max_attempts: Number of requests to make before giving up.
		temperature_step: Temperature increase applied per failed attempt.

	Returns:
		The model's ``satisfactory`` verdict, or ``False`` if no attempt produced a
		reply that validates (previously this fell through and returned ``None``).
	"""
	prompt = (
		'Evaluate the following JSON that corresponds to a single location entry. Your goal is to answer if the location data meets the following criteria:\n'
		'- Location is SPECIFIC. Not "global", "worldwide", or anything that is NOT a SPECIFIC LOCATION ON THE MAP.\n'
		'- Does not refer to a library or other man-made building. Only geographic features.\n\n'
		# Real JSON; the previous str(dict(...)) produced a Python repr.
		f'JSON to evaluate:\n{location.model_dump_json()}'
	)

	for attempt in range(max_attempts):
		val_response = ollama.chat(
			model=model,
			messages=[{'role': 'user', 'content': prompt}],
			# Constrain decoding to the pydantic schema (Ollama structured outputs).
			format=ValidationEntry.model_json_schema(),
			options={'temperature': attempt * temperature_step},
		)['message']['content']

		try:
			return ValidationEntry.model_validate_json(val_response).satisfactory
		except ValidationError:
			continue

	# Treat an undecidable location as invalid rather than returning None.
	return False

 
# Run extraction + validation on the sampled facts from the first cell and print the results.
for record in records:
	fact_text = record['d.text']
	print(f'Location Text: "{fact_text}"')

	extracted = parse_locations(fact_text)
	found = extracted.mentioned_locations
	print(f'Found {len(found)} locations\n')

	for number, location in enumerate(found, start=1):
		print(f'Loc {number}: {location}')
		print(f'IS VALID: {validate_location(location)}')

In [None]:
from typing import (
	Union,
    Dict
)
import os
from neo4j import GraphDatabase
from tqdm.notebook import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

# Root logging for the batch job. force=True drops handlers installed by earlier runs
# of this cell, so re-running it does not duplicate every log line.
logging.basicConfig(
	force=True,
	level=logging.INFO,
	datefmt="%Y-%m-%d %H:%M:%S",
	format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# Mute INFO-level chatter from the HTTP client libraries (presumably triggered by the
# ollama calls); warnings and errors still get through.
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger("notebook")


from dotenv import load_dotenv

load_dotenv()  # pull NEO4J_* settings from a local .env file into os.environ

# NOTE(review): this opens a second driver with the same settings as the first cell;
# the earlier `graph` is rebound here without being closed.
graph = GraphDatabase.driver(
	os.environ['NEO4J_URI'],
	auth=(
		os.environ['NEO4J_USER'],
		os.environ['NEO4J_PWD'],
	),
)
database = os.environ['NEO4J_DATABASE']

# Resume point for the loop below: only Fact nodes with id(f) > idx are fetched.
# Set to -1 to process every Fact from the start.
idx = 8286
batch_size = 200   # Fact nodes fetched per batch
num_processed = 0  # running count, used for progress logging


def get_batch(
	start_idx: int,
	db_session=None,
	limit=None,
) -> List[Dict[str, Union[int, str]]]:
	"""Fetch the next page of Fact nodes whose internal id is greater than ``start_idx``.

	Keyset pagination on ``id(f)``: pass the largest ``f_id`` already processed to get
	the following page.

	Args:
		start_idx: Exclusive lower bound on ``id(f)`` (use -1 to start from the beginning).
		db_session: Neo4j session to read with. Defaults to the module-level ``session``,
			which preserves the original behavior.
		limit: Maximum number of rows. Defaults to the module-level ``batch_size``.

	Returns:
		A list of ``{'f_id': ..., 'f_text': ...}`` dicts in ascending ``f_id`` order
		(the old ``Dict[str, str]`` annotation was wrong). The list is empty when no
		facts remain.
	"""
	if db_session is None:
		db_session = session
	if limit is None:
		limit = batch_size

	# ORDER BY / LIMIT are attached to RETURN (standard Cypher) rather than used as a
	# standalone clause, which older Neo4j versions reject.
	# NOTE(review): id() is deprecated in Neo4j 5 in favour of elementId(), but this
	# pagination relies on id() being an orderable integer.
	return db_session.execute_read(
		lambda tx: tx.run('''
			MATCH (f:Fact)
			WHERE id(f) > $idx
			RETURN id(f) AS f_id, f.text AS f_text
			ORDER BY f_id
			LIMIT $batch_size
		''', idx=start_idx, batch_size=limit).data()
	)


def parse_locations_from_record(record):
	"""Extract and validate the locations mentioned in one Fact record.

	Args:
		record: A row from ``get_batch`` with ``'f_id'`` and ``'f_text'`` keys.

	Returns:
		One dict per location accepted by ``validate_location``: the location's fields
		plus the originating ``'f_id'``, in the order the extractor returned them.
	"""
	extracted = parse_locations(record['f_text'])
	return [
		{"f_id": record["f_id"], **dict(candidate)}
		for candidate in extracted.mentioned_locations
		if validate_location(candidate)
	]


# First page of Fact nodes after `idx`. get_batch reads the module-level `session`
# bound by this `with`, so the call has to stay inside it.
with graph.session(database=database) as session:
	records = get_batch(idx)

# Process pages until get_batch returns an empty list. `idx` is advanced only after a
# page's locations have been written, so after a failure it still holds the last id
# whose page was fully written.
while records:
	with graph.session(database=database) as session:
		logger.info(f'BATCH {num_processed}-{num_processed+batch_size}')

		# LLM extraction + validation for every fact in this page. This reuses the
		# helper defined above instead of duplicating its logic inline.
		batch_locations = []
		for record in tqdm(records):
			batch_locations.extend(parse_locations_from_record(record))

		# Skip the write round-trip when nothing valid was extracted.
		if batch_locations:
			logger.info('WRITING TO DB...')
			# One Location node per mention, linked to its Fact and to Continent/Country/
			# Region/GeoLandmark nodes that are MERGEd by name.
			# NOTE(review): landmarks are described as *parts of* the location, yet the
			# edge points Location -[:IS_IN]-> GeoLandmark; confirm the intended direction.
			# .consume() drains the result inside the transaction function so that no
			# live Result escapes it.
			session.execute_write(
				lambda tx: tx.run('''
					UNWIND $batch_locations AS row

					MATCH (f:Fact) WHERE id(f) = row.f_id

					CREATE (l:Location {name: row.name})
					CREATE (l)-[:MENTIONED_IN]->(f)

					FOREACH (continent_name IN row.continents |
						MERGE (cn:Continent {name: continent_name})
						CREATE (l)-[:IS_IN]->(cn)
					)

					FOREACH (country_name IN row.countries |
						MERGE (c:Country {name: country_name})
						CREATE (l)-[:IS_IN]->(c)
					)

					FOREACH (region_name IN row.regions |
						MERGE (r:Region {name: region_name})
						CREATE (l)-[:IS_IN]->(r)
					)

					FOREACH (geo_land_name in row.geographic_landmarks|
						MERGE (g:GeoLandmark {name: geo_land_name})
						CREATE (l)-[:IS_IN]->(g)
					)
				''', batch_locations=batch_locations).consume()
			)

		num_processed += len(records)
		idx = max(records, key=lambda r: r['f_id'])['f_id']
		logger.info(f'Processed {num_processed} Entries, MAX IDX {idx}')

		records = get_batch(idx)

In [None]:
from pydantic import BaseModel
from enum import Enum

import instructor
import ollama

# Structured-output client for the llama3.1:8b model served through Ollama.
# If an explicit host or JSON-schema mode is needed, the alternative is:
#   instructor.from_ollama(ollama.Client(host='http://localhost:11434'),
#                          mode=instructor.Mode.JSON_SCHEMA)
client = instructor.from_provider(
	'ollama/llama3.1:8b'
)

<instructor.core.client.Instructor at 0x7fe5fa1b3050>