In [None]:
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv

# Populate os.environ from a local .env file before reading connection settings.
load_dotenv()

neo4j_uri = os.environ['NEO4J_URI']
neo4j_auth = (os.environ['NEO4J_USER'], os.environ['NEO4J_PWD'])
graph = GraphDatabase.driver(neo4j_uri, auth=neo4j_auth)
database = os.environ['NEO4J_DATABASE']


def _sample_bird_facts(tx):
	"""Return 20 randomly ordered (bird name, fact text) rows as a list of dicts."""
	result = tx.run('MATCH (b:Bird)-[:HAS_FACT]->(d:Fact) ORDER BY RAND() RETURN b.name AS name, d.text AS text LIMIT 20')
	return result.data()


with graph.session(database=database) as session:
	records = session.execute_read(_sample_bird_facts)

# Bare expression so the notebook renders the sampled rows.
records

In [None]:
import re
from typing import List, Annotated
from pydantic import BaseModel, BeforeValidator
from enum import Enum

import instructor
import pandas as pd

# Structured-output client talking to a local Ollama server through its
# OpenAI-compatible endpoint; responses are requested in JSON mode.
client = instructor.from_provider(
	'ollama/gpt-oss:20b',
	base_url='http://localhost:11434/v1',
	mode=instructor.Mode.JSON,
)

# Single source of truth for the sampling seed used in `instruct_params`.
seed = 42

# Keyword arguments splatted into every `client.create(...)` call below.
# NOTE(review): `num_predict: 50` caps generated tokens; long country/region
# lists may be truncated into invalid JSON -- confirm the cap is intended.
# NOTE(review): confirm the /v1 endpoint actually honors `extra_body.options`.
instruct_params = {
	'max_retries': 2,
	'timeout': 15.0,
	'extra_body': {
		'options': {
			'temperature': 0.0,
			'seed': seed,  # was a second hard-coded 42 that ignored `seed`
			'top_k': 1,
			'num_predict': 50
		}
	}
}


# String-valued enum of the continent names accepted in `Continents.continents`.
# Each member's value is the human-readable name that `get_continents` returns.
# NOTE(review): documented with comments rather than a docstring on purpose --
# a class docstring may be emitted into the response schema given to the LLM;
# confirm before adding one.
class Continent(str, Enum):
	NORTH_AMERICA = 'North America'
	SOUTH_AMERICA = 'South America'
	EUROPE = 'Europe'
	ASIA = 'Asia'
	AFRICA = 'Africa'
	OCEANIA = 'Oceania'
	ANTARCTICA = 'Antarctica'

# Response model passed as `response_model` in `get_continents`: a list of
# `Continent` enum members.
# NOTE(review): kept docstring-free; a docstring may become part of the schema
# shown to the LLM -- confirm before adding one.
class Continents(BaseModel):
    # Continents extracted from the passage; each item must be a valid `Continent`.
    continents: List[Continent]

def get_continents(text: str):
	"""Ask the LLM which continents a bird-fact passage places the bird in.

	Args:
		text: The fact passage to analyse.

	Returns:
		A list of continent names (the string values of `Continent`). If the
		request or validation raises, the error is printed and `[]` is returned.
	"""
	system_prompt = (
		'Your goal is to determine which continents a Bird could be located in, given a text passage describing a fact about the bird. '
		'Only return continents that are either directly referenced in the text, or that have countries or regions within that continent that are directly referenced in the text. '
		'Only return continents that that describe where the bird can be found in its natural habitat. '
	)
	try:
		conversation = [
			{'role': 'system', 'content': system_prompt},
			{'role': 'user', 'content': f'Text: *{text}* DO NOT EXPLAIN'},
		]
		extraction = client.create(
			response_model=Continents,
			messages=conversation,
			**instruct_params
		)
		names = []
		for continent in extraction.continents:
			names.append(continent.value)
		return names
	except Exception as e:
		# Best-effort: a failed extraction is reported and treated as "no continents".
		print(f'CONTINENTS ERR: {e}')
		return []


# Country -> continent lookup table; the `Country` column also seeds the
# `Country` enum built below.
c_df = pd.read_csv('../data/countries_by_continent.csv')
countries = c_df['Country'].tolist()

def clean_alpha_only(val: str) -> str:
	"""Normalise a name to ASCII letters separated by single spaces.

	Accented letters are first folded to their base letter (NFKD decomposition
	with combining marks dropped), so "Côte d'Ivoire" becomes "Cote d Ivoire"
	rather than "C te d Ivoire". Every remaining character that is not a-z/A-Z
	is then treated as a separator, runs of separators collapse to one space,
	and leading/trailing whitespace is stripped.

	Letters with no decomposition (e.g. "ø", "ß") are still treated as
	separators.

	Args:
		val: The raw name.

	Returns:
		The cleaned name; an empty string if `val` contains no letters.
	"""
	import unicodedata  # local import keeps this block self-contained

	decomposed = unicodedata.normalize('NFKD', val)
	# Drop the combining accents left behind by the decomposition, otherwise
	# they would be turned into spaces in the middle of a word.
	unaccented = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))
	letters_and_gaps = re.sub(r'[^a-zA-Z]', ' ', unaccented)
	return re.sub(r'\s+', ' ', letters_and_gaps).strip()
    
# String enum with one member per country in the CSV: the member name is the
# upper-cased cleaned country name, the value is the cleaned country name.
Country = Enum(
	'Country',
	dict(
		(cleaned.upper(), cleaned)
		for cleaned in map(clean_alpha_only, countries)
	),
	type=str,
)

# Response model passed as `response_model` in `get_countries`.
# NOTE(review): kept docstring-free; a docstring may become part of the schema
# shown to the LLM -- confirm before adding one.
class Countries(BaseModel):
    # Each raw item is run through `clean_alpha_only` before being validated
    # against the `Country` enum, whose values were built with the same function.
    countries: List[
        Annotated[
            Country,
            BeforeValidator(clean_alpha_only)
        ]
	]

def get_countries(text: str):
	"""Ask the LLM which countries a bird-fact passage places the bird in.

	Args:
		text: The fact passage to analyse.

	Returns:
		A list of country names (the string values of `Country`, i.e. names as
		produced by `clean_alpha_only`). If the request or validation raises,
		the error is printed and `[]` is returned.
	"""
	system_prompt = (
		'Your goal is to determine which countries a Bird could be located in, given a text passage describing a fact about the bird. '
		'Only return countries that are either directly referenced in the text, or that have regions within that country that are directly referenced in the text. '
		'Only return countries that that describe where the bird can be found in its natural habitat. '
	)
	try:
		conversation = [
			{'role': 'system', 'content': system_prompt},
			{'role': 'user', 'content': f'Text: *{text}* DO NOT EXPLAIN'},
		]
		extraction = client.create(
			response_model=Countries,
			messages=conversation,
			**instruct_params
		)
		names = []
		for country in extraction.countries:
			names.append(country.value)
		return names
	except Exception as e:
		# Best-effort: a failed extraction is reported and treated as "no countries".
		print(f'COUNTRIES ERR: {e}')
		return []
		

# Response model passed as `response_model` in `get_regions`.
# NOTE(review): kept docstring-free; a docstring may become part of the schema
# shown to the LLM -- confirm before adding one.
class Regions(BaseModel):
    # Free-text geographic location names; not constrained to an enum.
    regions: List[str]
 
def get_regions(text: str):
	"""Ask the LLM for geographic locations mentioned in a bird-fact passage.

	Args:
		text: The fact passage to analyse.

	Returns:
		A list of region names with whitespace collapsed to single spaces and
		each word passed through `str.capitalize` (first letter upper-cased,
		the rest lower-cased). If the request or validation raises, the error
		is printed and `[]` is returned.
	"""
	system_prompt = (
		'Your goal is to extract all geographic locations (e.g. mountain ranges, deserts, forests, etc.) that a bird could be located in, given a text passage describing a fact about the bird. '
		'Geographic locations do not include countries, full continents, cities, or man-made structures. '
		'Return geographic locations without abbreviations, e.g. "Ural Mountains" instead of "Ural mtns", "Cooke Island" instead of "Cooke I", etc. '
		'Only return geographic locations that are directly referenced in the text and that describe where the bird can be found in its natural habitat. '
	)
	try:
		conversation = [
			{'role': 'system', 'content': system_prompt},
			{'role': 'user', 'content': f'Text: *{text}* DO NOT EXPLAIN'},
		]
		extraction = client.create(
			response_model=Regions,
			messages=conversation,
			**instruct_params
		)
		normalised = []
		for raw_region in extraction.regions:
			normalised.append(' '.join(map(str.capitalize, raw_region.split())))
		return normalised
	except Exception as e:
		# Best-effort: a failed extraction is reported and treated as "no regions".
		print(f'REGIONS ERR: {e}')
		return []


# `get_countries` returns names normalised by `clean_alpha_only` (the `Country`
# enum values), so the continent lookup must compare against the CSV names in
# that same normalised form. Comparing against the raw column silently dropped
# every country whose raw name contains punctuation or accents.
country_keys = c_df.Country.map(clean_alpha_only)

# Manual smoke test: run all three extractors over the sampled records and
# print the results side by side.
# NOTE: `countries` below rebinds the module-level list loaded from the CSV.
for record in records:
#for record in [{'text': 'New Caledonian Buttonquail has traditionally been considered conspecific with Painted Buttonquail (Turnix varius), but recognized as very distinctive in HBW and separated in one recent list (2). In their assessment of the two taxa, del Hoyo and Collar (3), using the Tobias et al. (4) criteria, from which the numbers in brackets are derived, found that they differ in several plumage and morphometric characters. The comparison is based on one specimen of New Caledonian Buttonquail, a male with no specific locality and undated, but registered in the Natural History Museum in London in 1889: it differs from males of Painted Buttonquail in its smaller size (wing 80 mm versus 101.5 in one published sample (5), and bill 14.1 mm, tarsus 20 mm, and tail 39 mm) [allow 3]; dorsal and rump feathers mostly black versus black with rusty barring, without the rust color predominating on the mantle, and with proportionately more black markings in wing-coverts [2]; breast barred blackish-and-buff with some pale grey bases versus pale gray with buff blades with narrow blackish edges [2].'}]:
	# Double quotes inside the f-string keep this valid before Python 3.12 too.
	print(f'TEXT: {record["text"]}\n')
	continents = get_continents(record['text'])
	print(f'Continents: {continents}')
	countries = get_countries(record['text'])
	print(f'Countries: {countries}')
	# Cross-check: continents implied by the extracted countries.
	continents_from_countries = c_df[country_keys.isin(countries)].Continent.unique().tolist()
	print(f'Continents from Countries: {continents_from_countries}')
	regions = get_regions(record['text'])
	print(f'Regions: {regions}')
	print('\n')

In [None]:
from typing import (
	Union,
    Dict
)
import os
import warnings
from neo4j import GraphDatabase
from tqdm.notebook import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

# Root logging: INFO level with timestamped records. `force=True` replaces any
# handlers left over from a previous run of this cell.
logging.basicConfig(
	force=True,
	level=logging.INFO,
	format="%(asctime)s %(levelname)s %(name)s: %(message)s",
	datefmt="%Y-%m-%d %H:%M:%S",
)

# Raise the HTTP client loggers to WARNING so they do not log at INFO.
for noisy_http_logger in ("httpx", "httpcore"):
	logging.getLogger(noisy_http_logger).setLevel(logging.WARNING)

# Logger used by this notebook's own progress messages.
logger = logging.getLogger("notebook")

# Silence the neo4j module: drop its Python warnings and log only CRITICAL.
warnings.filterwarnings("ignore", module="neo4j")
logging.getLogger("neo4j").setLevel(logging.CRITICAL)

from dotenv import load_dotenv

# Populate os.environ from a local .env file before reading connection settings.
load_dotenv()

neo4j_uri = os.environ['NEO4J_URI']
neo4j_auth = (os.environ['NEO4J_USER'], os.environ['NEO4J_PWD'])
graph = GraphDatabase.driver(
	neo4j_uri,
	auth=neo4j_auth,
	warn_notification_severity="OFF"
)
database = os.environ['NEO4J_DATABASE']

# Paging state for the processing loop.
# NOTE(review): 483750 looks like a resume cursor (facts with id(f) > idx are
# processed) and 101292 like a precomputed fact count -- confirm both values.
idx = 483750                # exclusive lower bound on id(f) for the first batch
batch_size = 200            # facts fetched per `get_batch` call
num_processed = 0           # running count of facts handled in this run
total_fact_count = 101292   # progress-bar total


def get_batch(start_idx: int) -> list[dict]:
	"""Fetch the next page of `Fact` nodes with an internal id above `start_idx`.

	Rows are ordered by `id(f)` and limited to the module-level `batch_size`,
	so passing the largest `f_id` of the previous page yields the following one.

	Args:
		start_idx: Exclusive lower bound on `id(f)`.

	Returns:
		A list of row dicts with keys `f_id` (the node id) and `f_text` (the
		fact text); empty when no facts remain. The previous annotation,
		`Dict[str, str]`, did not match this list-of-rows shape.
	"""
	def _read_page(tx):
		return tx.run('''
			MATCH (f:Fact)
			WHERE id(f) > $idx
			ORDER BY id(f)
			RETURN id(f) as f_id, f.text as f_text
			LIMIT $batch_size
		''', idx=start_idx, batch_size=batch_size).data()

	with graph.session(database=database) as session:
		return session.execute_read(_read_page)


# Cypher run once per fact: find the bird(s) owning the fact and link them to
# every extracted continent / country / region, creating location nodes on demand.
# The relationships use MERGE rather than CREATE so that a bird is linked to a
# given location at most once. CREATE added a duplicate edge each time another
# fact of the same bird mentioned the location, and again whenever a partially
# processed batch was re-run after a restart.
# NOTE(review): if downstream queries rely on counting parallel edges as a
# "number of mentions" signal, store that as a relationship property instead.
link_locations_query = '''
	MATCH (b:Bird)-[:HAS_FACT]->(f:Fact) WHERE id(f) = $f_id

	FOREACH (continent IN $continents |
		MERGE (c:Continent {name: continent})
		MERGE (b)-[:IN_CONTINENT]->(c)
	)

	FOREACH (country IN $countries |
		MERGE (cn:Country {name: country})
		MERGE (b)-[:IN_COUNTRY]->(cn)
	)

	FOREACH (region IN $regions |
		MERGE (r:Region {name: region})
		MERGE (b)-[:IN_REGION]->(r)
	)
'''

records = get_batch(idx)

with tqdm(total=total_fact_count, unit="rec", desc="Processing") as pbar:
	# Page through the facts until `get_batch` returns an empty list.
	while records:
		# One session per batch; one write transaction per fact.
		with graph.session(database=database) as session:
			for record in records:
				f_id = record['f_id']
				text = record['f_text']

				# Each extractor returns [] on failure, so a failed LLM call
				# simply results in no links for that category.
				continents = get_continents(text)
				countries = get_countries(text)
				regions = get_regions(text)

				session.execute_write(
					lambda tx: tx.run(
						link_locations_query,
						f_id=f_id,
						continents=continents,
						countries=countries,
						regions=regions
					)
				)

				pbar.update()

		num_processed += len(records)
		# Advance the cursor to the highest id seen so the next page starts after it.
		idx = max(records, key=lambda r: r['f_id'])['f_id']
		logger.info(f'Processed {num_processed} Entries, MAX IDX {idx}')
		records = get_batch(idx)