<a href="https://colab.research.google.com/github/dumindus/aviation-safety-kg-rag/blob/main/phase-01/KG_based_RAG_for_Querying_Aviation_Safety_Data_%7C_Phase_01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Phase 01: Preprocessing with Real NTSB Data


## Step 1: Load and Explore NTSB Data

In [None]:
import pandas as pd
# Raw NTSB export; the next cell derives `df` from this frame.
ntsb_df=pd.read_csv("NTSB_Data_New.csv")

# Only the aircraft workbook is loaded; the other workbook reads are left disabled.
#events = pd.read_excel("events.xlsx")
aircraft = pd.read_excel("aircraft.xlsx")
#engines = pd.read_excel("engines.xlsx")
#sequence = pd.read_excel("Events_Sequence.xlsx")

# Cell output: (rows, columns) of the raw export.
ntsb_df.shape

(68681, 38)

In [None]:
# Drop the docket columns and 'Unnamed: 37'; drop() returns a new frame, so ntsb_df keeps all columns.
df=ntsb_df.drop(columns=['DocketUrl','DocketPublishDate','Unnamed: 37'])
# Prints the columns of the original ntsb_df, so the dropped columns still appear in this list.
print(ntsb_df.columns.tolist())

# Cell output: rich display of the trimmed frame.
df

['NtsbNo', 'EventType', 'Mkey', 'EventDate', 'City', 'State', 'Country', 'ReportNo', 'N', 'HasSafetyRec', 'ReportType', 'OriginalPublishDate', 'HighestInjuryLevel', 'FatalInjuryCount', 'SeriousInjuryCount', 'MinorInjuryCount', 'ProbableCause', 'EventID', 'Latitude', 'Longitude', 'Make', 'Model', 'AirCraftCategory', 'AirportID', 'AirportName', 'AmateurBuilt', 'NumberOfEngines', 'Scheduled', 'PurposeOfFlight', 'FAR', 'AirCraftDamage', 'WeatherCondition', 'Operator', 'ReportStatus', 'RepGenFlag', 'DocketUrl', 'DocketPublishDate', 'Unnamed: 37']


Unnamed: 0,NtsbNo,EventType,Mkey,EventDate,City,State,Country,ReportNo,N,HasSafetyRec,...,AmateurBuilt,NumberOfEngines,Scheduled,PurposeOfFlight,FAR,AirCraftDamage,WeatherCondition,Operator,ReportStatus,RepGenFlag
0,ERA26LA059,ACC,202089,2025-11-30T19:18:00Z,Chatham,Massachusetts,United States,,N1294P,False,...,FALSE,1,,PERS,091,Substantial,VMC,,In work,False
1,GAA26WA034,ACC,202081,2025-11-30T14:09:00Z,New South Wales.,,Austria,,"VH-EWS, VH-NMG",False,...,"FALSE,FALSE",",",,"PERS,PERS",091091,"Substantial,Minor",,",",,False
2,GAA26WA041,ACC,202123,2025-11-29T16:45:00Z,Malawa,,Poland,,F-HEAT,False,...,FALSE,,,,NUSN,,,,,False
3,ERA26LA060,ACC,202094,2025-11-28T15:00:00Z,Okeechobee,Florida,United States,,N711,False,...,FALSE,1,,PERS,091,Substantial,VMC,APACHE CUB SYSTEMS COMPANY LLC,In work,False
4,CEN26LA056,ACC,202080,2025-11-28T14:12:00Z,Downers Grove,Illinois,United States,,N4332V,False,...,FALSE,1,,PERS,091,Substantial,VMC,KENAGA MICHAEL L,In work,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
68676,NYC90LA045,ACC,36594,1990-01-03T15:30:00Z,SLIPPERY ROCK,Pennsylvania,United States,,N4636E,False,...,FALSE,1,,PERS,091,Substantial,VMC,,Completed,False
68677,NYC90DFJ01,ACC,36490,1990-01-03T14:15:00Z,PENN YAN,New York,United States,,N896JC,False,...,TRUE,1,,PERS,091,Substantial,VMC,JAMES FINK,Completed,False
68678,NYC90LA044,ACC,36593,1990-01-02T11:30:00Z,MILLIS,Massachusetts,United States,,N9241U,False,...,FALSE,1,,INST,091,Substantial,VMC,AVIATION EAST,Completed,False
68679,BFO90FA017,ACC,11540,1990-01-02T01:07:00Z,BALTIMORE,Maryland,United States,,N109AA,False,...,FALSE,3,SCHD,UNK,121,,VMC,AMERICAN AIRLINES,Completed,False


In [None]:
import numpy as np

# Standardize column names: trimmed, lower-case, spaces replaced by underscores
df.columns = (
    df.columns
    .str.strip()
    .str.lower()
    .str.replace(" ", "_")
)

# Parse dates & create time features (unparseable dates become NaT)
df["event_date"] = pd.to_datetime(df["eventdate"], errors="coerce")
df["event_year"] = df["event_date"].dt.year
df["event_month"] = df["event_date"].dt.month
df["event_dayofweek"] = df["event_date"].dt.dayofweek

# Clean geographic fields (non-numeric coordinates become NaN)
for col in ["latitude", "longitude"]:
    df[col] = pd.to_numeric(df[col], errors="coerce")

df["state"] = df["state"].str.upper().str.strip()
df["country"] = df["country"].str.upper().str.strip()

# Injury severity features
injury_cols = ["fatalinjurycount", "seriousinjurycount", "minorinjurycount"]

# A missing count is treated as zero injuries
df[injury_cols] = df[injury_cols].fillna(0).astype(int)

df["total_injuries"] = df[injury_cols].sum(axis=1)
df["has_fatality"] = (df["fatalinjurycount"] > 0).astype(int)

# Normalize categorical variables
categorical_cols = [
    "eventtype", "reporttype", "highestinjurylevel",
    "aircraftcategory", "weathercondition",
    "aircraftdamage", "purposeofflight", "far"
]

for col in categorical_cols:
    if col in df.columns:
        df[col] = (
            df[col]
            .astype(str)
            .str.upper()
            .str.strip()
            # astype(str) turns missing values into "nan"; restore them as real NaN
            .replace({"NAN": np.nan})
        )

# Binary flags
binary_cols = [
    "hassafetyrec",
    "amateurbuilt",
    "scheduled",
    "repgenflag"
]

# Text spellings accepted for a flag. The export shown above carries booleans,
# "TRUE"/"FALSE" strings and "SCHD", none of which the previous {"Y", "N"}
# mapping recognised, so every flag collapsed to NaN.
# NOTE(review): "NSCH" is assumed to be the non-scheduled code - confirm against the export.
_FLAG_TEXT_VALUES = {
    "Y": 1, "YES": 1, "TRUE": 1, "1": 1, "SCHD": 1,
    "N": 0, "NO": 0, "FALSE": 0, "0": 0, "NSCH": 0,
}


def _to_binary_flag(value):
    """Map one raw flag value to 1, 0 or NaN (missing or unrecognised)."""
    if pd.isna(value):
        return np.nan
    if not isinstance(value, str):
        # Booleans and 0/1 numbers (also makes re-running this cell idempotent)
        return int(value) if value in (0, 1) else np.nan
    # Multi-aircraft values such as "FALSE,FALSE" are not in the table and stay NaN
    return _FLAG_TEXT_VALUES.get(value.strip().upper(), np.nan)


for col in binary_cols:
    if col in df.columns:
        df[col] = df[col].map(_to_binary_flag)

# Aircraft & engine fields
df["numberofengines"] = pd.to_numeric(
    df["numberofengines"], errors="coerce"
)

df["make"] = df["make"].str.upper().str.strip()
df["model"] = df["model"].str.upper().str.strip()

# Clean free-text fields: collapse whitespace runs, but keep missing causes
# as NaN instead of the literal string "nan"
_cause_raw = df["probablecause"]
df["probablecause"] = (
    _cause_raw
    .astype(str)
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
    .where(_cause_raw.notna())
)

# Select final feature set
final_cols = [
    "eventid", "ntsbno", "event_date", "event_year",
    "event_month", "event_dayofweek",
    "city", "state", "country",
    "latitude", "longitude", "airportid","airportname","n",
    "make", "model", "aircraftcategory",
    "numberofengines", "purposeofflight",
    "weathercondition", "aircraftdamage",
    "fatalinjurycount", "seriousinjurycount",
    "minorinjurycount", "total_injuries",
    "has_fatality", "highestinjurylevel",
    "probablecause"
]

# Explicit copy so later edits to final_df never write through to df
final_df = df[final_cols].copy()

In [None]:
# Cell output: the dtype of each selected column, as a check on the parsing above.
final_df.dtypes

Unnamed: 0,0
eventid,float64
ntsbno,object
event_date,"datetime64[ns, UTC]"
event_year,int32
event_month,int32
event_dayofweek,int32
city,object
state,object
country,object
latitude,float64


In [None]:
# Columns retained from the aircraft workbook.
aircraft_final_cols = [
    "regis_no",
    "ntsb_no",
    "acft_make",
    "acft_model",
    "acft_series",
    "total_seats",
    "owner_acft",
]

# Column subset of the aircraft workbook, in the order listed above.
aircraft_final = aircraft.loc[:, aircraft_final_cols]
aircraft_final

In [None]:
# Left join: every accident row is kept and aircraft details are attached
# wherever the registration in "n" equals "regis_no".
# NOTE(review): the join key is the registration only; if a registration occurs
# more than once in the aircraft table the accident row is repeated - confirm
# whether "ntsbno"/"ntsb_no" should be part of the key.
final_df_merge = pd.merge(
    final_df,
    aircraft_final,
    how="left",
    left_on="n",
    right_on="regis_no",
)

In [None]:
# Cell output: rich display of the merged accident/aircraft frame.
final_df_merge

In [None]:
# Persist the merged table without the index column; Step 2 reads this file back.
final_df_merge.to_csv("ntsb_preprocessed.csv", index=False)

## Step 2: Define Aviation Entity Schema

In [None]:
import pandas as pd
import numpy as np

def load_real_ntsb_data(file_path=None, encoding=None):
    """Load NTSB data from a CSV file and apply light normalisation.

    Args:
        file_path: Path of the CSV to read. When falsy, a hint is printed and
            ``None`` is returned.
        encoding: Text encoding of the file. When ``None`` (default) the file is
            read as UTF-8 first - the encoding pandas uses when writing CSVs -
            and ISO-8859-1 is used only if UTF-8 decoding fails.

    Returns:
        A DataFrame whose column names are lower-cased with spaces replaced by
        underscores, and whose ``make``/``model`` columns (when present) are
        upper-cased with missing values replaced by ``'UNKNOWN'``; or ``None``
        when no file path was given.
    """
    if not file_path:
        # Fallback if no file is provided
        print("No file provided. Please upload your NTSB CSV.")
        return None

    if encoding is not None:
        df = pd.read_csv(file_path, encoding=encoding, low_memory=False)
    else:
        try:
            df = pd.read_csv(file_path, encoding='utf-8', low_memory=False)
        except UnicodeDecodeError:
            # Not valid UTF-8: fall back to the single-byte encoding that was
            # previously applied unconditionally (and garbled UTF-8 input).
            df = pd.read_csv(file_path, encoding='ISO-8859-1', low_memory=False)
    print(f"Loaded real NTSB data with {len(df)} records.")

    # Standardize column names to lowercase for the extractor
    df.columns = [col.lower().replace(' ', '_') for col in df.columns]

    # Clean up make/model; tolerate files that lack one of the columns
    for col in ('make', 'model'):
        if col in df.columns:
            df[col] = df[col].str.upper().fillna('UNKNOWN')

    return df

# Usage: reload the CSV written by the preprocessing step.
# This rebinds ntsb_df, which earlier in the notebook held the raw export.
ntsb_df = load_real_ntsb_data('ntsb_preprocessed.csv')
print(ntsb_df.columns.tolist())


Loaded real NTSB data with 68681 records.
['eventid', 'ntsbno', 'event_date', 'event_year', 'event_month', 'event_dayofweek', 'city', 'state', 'country', 'latitude', 'longitude', 'airportid', 'airportname', 'n', 'make', 'model', 'aircraftcategory', 'numberofengines', 'purposeofflight', 'weathercondition', 'aircraftdamage', 'fatalinjurycount', 'seriousinjurycount', 'minorinjurycount', 'total_injuries', 'has_fatality', 'highestinjurylevel', 'probablecause', 'regis_no', 'ntsb_no', 'acft_make', 'acft_model', 'acft_series', 'total_seats', 'owner_acft']


In [None]:
#['eventid', 'ntsbno', 'event_date', 'event_year', 'event_month', 'event_dayofweek',
# 'city', 'state', 'country', 'latitude', 'longitude', 'airportid', 'airportname',
# 'n', 'make', 'model', 'aircraftcategory', 'numberofengines', 'purposeofflight', 'weathercondition', 'aircraftdamage', 'fatalinjurycount', 'seriousinjurycount', 'minorinjurycount', 'total_injuries', 'has_fatality', 'highestinjurylevel', 'probablecause',
# 'regis_no', 'ntsb_no', 'acft_make', 'acft_model', 'acft_series', 'total_seats', 'owner_acft']


# Maps the logical field names used by the entity extractor (keys) to the
# column names of the preprocessed CSV (values). EntityExtractor._val looks a
# logical name up here before reading a row.
COLUMN_MAP = {
    # ---------------- Accident ----------------
    "event_id": "ntsbno",          # canonical ID
    "event_internal_id": "eventid",
    "event_date": "event_date",
    "event_year": "event_year",
    "event_month": "event_month",
    "event_dayofweek": "event_dayofweek",

    "injury_severity": "highestinjurylevel",
    "aircraft_damage": "aircraftdamage",
    "weather_condition": "weathercondition",
    "probable_cause": "probablecause",

    "fatal_count": "fatalinjurycount",
    "serious_count": "seriousinjurycount",
    "minor_count": "minorinjurycount",
    "total_injuries": "total_injuries",
    "has_fatality": "has_fatality",

    # ---------------- Location ----------------
    "city": "city",
    "state": "state",
    "country": "country",
    "latitude": "latitude",
    "longitude": "longitude",
    "airport_code": "airportid",
    "airport_name": "airportname",

    # ---------------- Aircraft ----------------
    "registration_number": "regis_no",   # primary (from the aircraft workbook)
    "alt_registration_number": "n",      # fallback (from the main NTSB export)

    "make": "make",
    "model": "model",
    "series": "acft_series",

    "aircraft_category": "aircraftcategory",
    "number_of_engines": "numberofengines",
    "purpose_of_flight": "purposeofflight",
    "total_seats": "total_seats",

    # ---------------- Operator / Owner ----------------
    "owner": "owner_acft",
}


## Step 3: LLM-Powered Entity Extraction

In [None]:
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Any
from datetime import datetime

class AviationEntity(BaseModel):
    """Base pydantic model shared by every entity type in this notebook.

    Subclasses give ``entity_type`` a fixed default and add their own typed fields.
    """
    # Label for the kind of entity; required here, defaulted in each subclass.
    entity_type: str = Field(..., description="Type of aviation entity")
    # Identifier used as the node key and as the relationship endpoint.
    entity_id: str = Field(..., description="Unique identifier for the entity")
    # Free-form extra properties; a fresh dict per instance.
    attributes: Dict[str, Any] = Field(default_factory=dict)
    # Name of the dataset the entity came from.
    source_dataset: str = Field(default="NTSB")
    # Constrained to the closed interval [0.0, 1.0].
    confidence_score: float = Field(default=1.0, ge=0.0, le=1.0)
    # Optional relationship records attached to the entity; a fresh list per instance.
    relationships: List[Dict[str, Any]] = Field(default_factory=list)

class Aircraft(AviationEntity):
    """Aircraft entity; registration, make and model are required strings."""
    entity_type: str = "Aircraft"
    # Required identifying fields.
    registration_number: str
    make: str
    model: str
    # Optional descriptive fields, all defaulting to None.
    series: Optional[str] = None
    aircraft_category: Optional[str] = None
    number_of_engines: Optional[int] = None
    engine_type: Optional[str] = None
    purpose_of_flight: Optional[str] = None
    total_seats: Optional[int] = None

class Airport(AviationEntity):
    """Airport/location entity; only the ``location`` label is required."""
    entity_type: str = "Airport"
    airport_code: Optional[str] = None
    airport_name: Optional[str] = None
    # Required human-readable location label.
    location: str
    country: Optional[str] = None
    # Optional coordinates, stored as floats.
    latitude: Optional[float] = None
    longitude: Optional[float] = None

class Airline(AviationEntity):
    """Airline/operator entity; only ``airline_name`` is required."""
    entity_type: str = "Airline"
    airline_name: str
    # Optional descriptive fields, both defaulting to None.
    far_description: Optional[str] = None
    schedule: Optional[str] = None

class Accident(AviationEntity):
    """Accident event entity.

    Only ``event_id`` is required. Every other field is optional and defaults to
    ``None`` so that rows with missing values still validate.
    """
    entity_type: str = "Accident"
    event_id: str
    event_date: Optional[datetime] = None
    injury_severity: Optional[str] = None
    # Was a required bare ``object``; now an optional string like the other
    # text fields, so omitting it no longer raises a validation error.
    aircraft_damage: Optional[str] = None
    weather_condition: Optional[str] = None
    probable_cause: Optional[str] = None
    # Injury figures, all optional integers.
    total_injuries: Optional[int] = None
    has_fatality: Optional[int] = None
    fatal_count: Optional[int] = None
    serious_count: Optional[int] = None
    minor_count: Optional[int] = None


# Entity extraction configuration: the entity and relationship labels planned
# for the graph. Of the entity types listed, only Aircraft, Airport, Airline
# and Accident have model classes in this cell.
entity_config = {
    "entity_types": [
        "Aircraft", "Airport", "Airline", "Accident",
        "Manufacturer", "Model", "Location"
    ],
    "relationship_types": [
        "INVOLVED_IN", "OCCURRED_AT", "OPERATED_BY",
        "MANUFACTURED_BY", "HAS_MODEL", "LOCATED_AT"
    ]
}

# Relationship label -> (source entity type, target entity type).
RELATIONSHIP_TYPES = {
    "INVOLVED_IN": ("Aircraft", "Accident"),
    "OCCURRED_AT": ("Accident", "Location"),
}

In [None]:
import spacy
from sentence_transformers import SentenceTransformer


class EntityExtractor:
    """Rule-based extractor that turns one preprocessed NTSB row into entities.

    Each row yields up to three entities (Aircraft, Airport-typed location,
    Accident) and up to two relationships (INVOLVED_IN, OCCURRED_AT). Columns
    are read through ``COLUMN_MAP`` so the methods work with logical names.

    The spaCy pipeline and sentence-transformer model are loaded in
    ``__init__`` but are not used by the methods of this class.
    """

    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

    # ---------------------------------------------------
    # Helper functions
    # ---------------------------------------------------
    def _val(self, row: pd.Series, logical_name: str):
        """Return the row value for a logical field, or None if unmapped, absent or NaN."""
        col = COLUMN_MAP.get(logical_name)
        if col and col in row and pd.notna(row[col]):
            return row[col]
        return None

    def _get_registration(self, row: pd.Series):
        """Return the stripped registration, preferring regis_no and falling back to n.

        The value is used verbatim: a multi-aircraft string is not split into
        separate registrations. A blank primary value now falls through to the
        alternative column instead of ending the lookup with an empty string.
        """
        for logical_name in ("registration_number", "alt_registration_number"):
            value = self._val(row, logical_name)
            if value is None:
                continue
            text = str(value).strip()
            if text:
                return text
        return None

    def _location_parts(self, row: pd.Series) -> List[str]:
        """Return the non-empty city/state strings of the row, in that order."""
        parts = []
        for logical_name in ("city", "state"):
            value = self._val(row, logical_name)
            if value is None:
                continue
            text = str(value).strip()
            if text:
                parts.append(text)
        return parts

    def _location_id(self, row: pd.Series):
        """Return the location entity id, or None when neither city nor state is known.

        Only the parts that are present are used, so a missing state no longer
        produces ids such as ``LOCATION_MALAWA_NONE``.
        """
        parts = self._location_parts(row)
        if not parts:
            return None
        return ("location_" + "_".join(parts)).replace(" ", "_").upper()

    # ---------------------------------------------------
    # Entity Extraction
    # ---------------------------------------------------
    def extract_entities_from_row(self, row: pd.Series) -> "List[AviationEntity]":
        """Build the Aircraft, location and Accident entities for one row.

        Returns:
            A list with the Aircraft (when a registration exists), the
            Airport-typed location (when city or state exists) and the
            Accident, in that order.

        Raises:
            Whatever the pydantic models raise when a required field is
            missing (for example a row without an event id).
        """
        entities = []

        # ---------------- Aircraft ----------------
        registration = self._get_registration(row)

        if registration:
            aircraft = Aircraft(
                entity_id=f"aircraft_{registration}",
                registration_number=registration,
                make=self._val(row, "make"),
                model=self._val(row, "model"),
                aircraft_category=self._val(row, "aircraft_category"),
                number_of_engines=self._val(row, "number_of_engines"),
                engine_type=None,
                # Kept in `attributes` because the Neo4j export spreads this
                # dict into the node properties.
                attributes={
                    "series": self._val(row, "series"),
                    "purpose_of_flight": self._val(row, "purpose_of_flight"),
                    "total_seats": self._val(row, "total_seats"),
                    "owner": self._val(row, "owner"),
                },
            )
            entities.append(aircraft)

        # ---------------- Location / Airport (event-centric) ----------------
        location_id = self._location_id(row)
        if location_id:
            location = Airport(
                entity_id=location_id,
                airport_code=None,
                airport_name=None,
                # Label built from the parts that exist ("City, State" or just one of them)
                location=", ".join(self._location_parts(row)),
                country=self._val(row, "country"),
                latitude=self._val(row, "latitude"),
                longitude=self._val(row, "longitude"),
            )
            entities.append(location)

        # ---------------- Accident ----------------
        event_id = self._val(row, "event_id")
        accident = Accident(
            entity_id=f"accident_{event_id}",
            event_id=event_id,
            event_date=self._val(row, "event_date"),
            injury_severity=self._val(row, "injury_severity"),
            aircraft_damage=self._val(row, "aircraft_damage"),
            weather_condition=self._val(row, "weather_condition"),
            attributes={
                "event_year": self._val(row, "event_year"),
                "event_month": self._val(row, "event_month"),
                "event_dayofweek": self._val(row, "event_dayofweek"),
                "fatal_count": self._val(row, "fatal_count"),
                "serious_count": self._val(row, "serious_count"),
                "minor_count": self._val(row, "minor_count"),
                "total_injuries": self._val(row, "total_injuries"),
                "has_fatality": self._val(row, "has_fatality"),
                "probable_cause": self._val(row, "probable_cause"),
            },
        )
        entities.append(accident)

        return entities

    # ---------------------------------------------------
    # Relationship Generation
    # ---------------------------------------------------
    def generate_relationships(
        self, entities: "List[AviationEntity]", row: pd.Series
    ) -> List[Dict[str, Any]]:
        """Build the relationship records for one row.

        The ids are recomputed from ``row`` with the same helpers used by
        ``extract_entities_from_row``; ``entities`` is accepted for interface
        compatibility but not inspected.

        Returns:
            A list of dicts with ``from_id``, ``to_id``, ``relationship_type``
            and ``attributes``. Empty when the row has no event id.
        """
        relationships = []

        event_id = self._val(row, "event_id")
        if event_id is None:
            # Without an accident there is nothing to link to
            return relationships
        accident_id = f"accident_{event_id}"

        registration = self._get_registration(row)
        aircraft_id = f"aircraft_{registration}" if registration else None
        location_id = self._location_id(row)

        # Aircraft INVOLVED_IN Accident
        if aircraft_id:
            relationships.append(
                {
                    "from_id": aircraft_id,
                    "to_id": accident_id,
                    "relationship_type": "INVOLVED_IN",
                    "attributes": {},
                }
            )

        # Accident OCCURRED_AT Location
        if location_id:
            relationships.append(
                {
                    "from_id": accident_id,
                    "to_id": location_id,
                    "relationship_type": "OCCURRED_AT",
                    "attributes": {},
                }
            )

        return relationships

extractor = EntityExtractor()

print("Extracting entities from NTSB data...")
all_entities = []
all_relationships = []

# Rows are processed one at a time; a row that raises is reported and skipped
# so the remaining rows still get extracted.
for idx, row in ntsb_df.iterrows():
    try:
        entities = extractor.extract_entities_from_row(row)
        relationships = extractor.generate_relationships(entities, row)

        all_entities += entities
        all_relationships += relationships

        # Only the first three rows are echoed as a demonstration
        if not idx < 3:
            continue

        print(f"\nRow {idx} Entities:")
        for entity in entities:
            print(f"  - {entity.entity_type}: {entity.entity_id}")

        print("Relationships:")
        for rel in relationships:
            print(f"  - {rel['from_id']} --{rel['relationship_type']}--> {rel['to_id']}")

    except Exception as e:
        print(f"Error processing row {idx}: {e}")


print(f"\nTotal entities extracted: {len(all_entities)}")
print(f"Total relationships: {len(all_relationships)}")

Extracting entities from NTSB data...

Row 0 Entities:
  - Aircraft: aircraft_N1294P
  - Airport: LOCATION_CHATHAM_MASSACHUSETTS
  - Accident: accident_ERA26LA059
Relationships:
  - aircraft_N1294P --INVOLVED_IN--> accident_ERA26LA059
  - accident_ERA26LA059 --OCCURRED_AT--> LOCATION_CHATHAM_MASSACHUSETTS

Row 1 Entities:
  - Aircraft: aircraft_VH-EWS, VH-NMG
  - Airport: LOCATION_NEW_SOUTH_WALES._NONE
  - Accident: accident_GAA26WA034
Relationships:
  - aircraft_VH-EWS, VH-NMG --INVOLVED_IN--> accident_GAA26WA034
  - accident_GAA26WA034 --OCCURRED_AT--> LOCATION_NEW_SOUTH_WALES._NONE

Row 2 Entities:
  - Aircraft: aircraft_F-HEAT
  - Airport: LOCATION_MALAWA_NONE
  - Accident: accident_GAA26WA041
Relationships:
  - aircraft_F-HEAT --INVOLVED_IN--> accident_GAA26WA041
  - accident_GAA26WA041 --OCCURRED_AT--> LOCATION_MALAWA_NONE

Total entities extracted: 205922
Total relationships: 137241


## Step 4: Entity Resolution and Disambiguation

This code implements entity resolution and deduplication for the aviation knowledge graph. In simple terms, it answers:
“Are these two entities actually the same thing, even if they are written differently?”

However, it has the following limitations:
- It does not merge entities automatically.
- It does not update Neo4j.
- It does not enforce constraints.

Any merging is assumed to be done manually.

In [None]:
!pip install faiss-cpu

import random

# Index only a subset of the entities to keep embedding time manageable.
# A dedicated seeded generator makes the subset reproducible across runs, and
# the size is capped so a small dataset does not raise ValueError.
SIMILARITY_SAMPLE_SIZE = 2000
SIMILARITY_SAMPLE_SEED = 42
all_ent = random.Random(SIMILARITY_SAMPLE_SEED).sample(
    all_entities, min(SIMILARITY_SAMPLE_SIZE, len(all_entities))
)



In [None]:
import faiss
import re

class EntityResolver:
    """Name normalisation plus embedding-based similarity search over entities.

    ``resolve_*`` methods map spelling variants to a canonical glossary key.
    ``build_similarity_index`` embeds entities into a FAISS inner-product
    index, and ``find_similar_entities`` queries it. Nothing is merged
    automatically; the class only reports candidates.
    """

    def __init__(self):
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.entity_embeddings = {}
        # Position in the FAISS index -> entity id
        self.entity_id_mapping = []
        self.similarity_index = None

    def create_aviation_glossary(self):
        """Create aviation-specific terminology for disambiguation"""
        glossary = {
            'aircraft_models': {
                '737': ['737-800', '737-900', '737 MAX', 'B737'],
                'A320': ['A320-214', 'A320-200', 'A320NEO'],
                '777': ['777-200', '777-300', '777X'],
                'A350': ['A350-900', 'A350-1000']
            },
            'manufacturers': {
                'BOEING': ['Boeing Company', 'Boeing'],
                'AIRBUS': ['Airbus SAS', 'Airbus'],
                'CESSNA': ['Cessna Aircraft Company'],
                'EMBRAER': ['Embraer SA']
            },
            'airlines': {
                'UNITED AIRLINES': ['United', 'UAL'],
                'AMERICAN AIRLINES': ['American', 'AAL'],
                'DELTA AIR LINES': ['Delta', 'DAL'],
                'SOUTHWEST AIRLINES': ['Southwest', 'SWA']
            }
        }
        return glossary

    def normalize_entity_name(self, name: str) -> str:
        """Upper-case the name, drop non-alphanumeric characters and collapse whitespace.

        Returns an empty string for a falsy name.
        """
        if not name:
            return ""

        # Convert to uppercase and remove special characters
        normalized = re.sub(r'[^a-zA-Z0-9\s]', '', str(name).upper())

        # Remove extra whitespace
        normalized = ' '.join(normalized.split())

        return normalized

    def _resolve_with_glossary(self, raw_name: str, category: str) -> str:
        """Return the canonical glossary key for ``raw_name``, else its normalised form.

        Canonical keys and their variations are normalised the same way as the
        input before comparing. The previous code compared the normalised
        (upper-case, punctuation-free) input against raw variations such as
        'Boeing Company' or '737-800', which could never match.
        """
        normalized = self.normalize_entity_name(raw_name)
        for canonical, variations in self.create_aviation_glossary()[category].items():
            candidates = {self.normalize_entity_name(canonical)}
            candidates.update(self.normalize_entity_name(v) for v in variations)
            if normalized in candidates:
                return canonical
        return normalized

    def resolve_aircraft_models(self, model_name: str) -> str:
        """Resolve aircraft model variations to canonical form"""
        return self._resolve_with_glossary(model_name, 'aircraft_models')

    def resolve_manufacturer(self, manufacturer: str) -> str:
        """Resolve manufacturer name variations"""
        return self._resolve_with_glossary(manufacturer, 'manufacturers')

    @staticmethod
    def _entity_text(entity) -> str:
        """Build the text that is embedded for an entity.

        Missing (None or blank) parts are skipped so the literal word "None"
        does not end up in the embedded text. Falls back to
        "<type> <id>" for other entity types or when every part is missing.
        """
        if entity.entity_type == "Aircraft":
            parts = [entity.make, entity.model, entity.registration_number]
        elif entity.entity_type == "Airport":
            parts = [entity.airport_name, entity.airport_code, entity.location]
        elif entity.entity_type == "Airline":
            parts = [entity.airline_name]
        else:
            parts = []

        text = " ".join(str(p).strip() for p in parts if p is not None and str(p).strip())
        return text or f"{entity.entity_type} {entity.entity_id}"

    def build_similarity_index(self, entities: "List[AviationEntity]"):
        """Build FAISS index for entity similarity search.

        Returns:
            The FAISS index, or ``None`` when ``entities`` is empty (in which
            case any previous index is cleared).
        """
        entity_list = list(entities)
        if not entity_list:
            # Nothing to embed: avoid indexing into the shape of an empty array
            self.similarity_index = None
            self.entity_id_mapping = []
            self.entity_embeddings = {}
            return None

        entity_texts = [self._entity_text(entity) for entity in entity_list]
        entity_ids = [entity.entity_id for entity in entity_list]

        # Generate unit-length embeddings so inner product equals cosine similarity
        embeddings = self.sentence_model.encode(entity_texts, normalize_embeddings=True)

        # Create FAISS index
        dimension = embeddings.shape[1]
        self.similarity_index = faiss.IndexFlatIP(dimension)
        self.similarity_index.add(embeddings.astype('float32'))

        # Store mapping (entities sharing an id keep the last embedding in the dict)
        self.entity_id_mapping = entity_ids
        self.entity_embeddings = {eid: emb for eid, emb in zip(entity_ids, embeddings)}

        return self.similarity_index

    def find_similar_entities(
        self, query_entity: "AviationEntity", threshold: float = 0.8, k: int = 10
    ) -> List[tuple]:
        """Find indexed entities whose embedding similarity reaches ``threshold``.

        Args:
            query_entity: Entity to search for.
            threshold: Minimum similarity score for a hit.
            k: Maximum number of nearest neighbours to inspect (default 10).

        Returns:
            A list of ``(entity_id, score)`` tuples in the order FAISS returns
            them; empty when no index has been built.
        """
        if self.similarity_index is None:
            return []

        # Never ask FAISS for more neighbours than the index holds
        neighbours = min(k, self.similarity_index.ntotal)
        if neighbours <= 0:
            return []

        query_embedding = self.sentence_model.encode(
            [self._entity_text(query_entity)], normalize_embeddings=True
        )

        similarities, indices = self.similarity_index.search(
            query_embedding.astype('float32'), neighbours
        )

        similar_entities = []
        for score, idx in zip(similarities[0], indices[0]):
            # FAISS marks "no result" slots with index -1, which would
            # otherwise silently address the last mapped entity
            if idx < 0 or idx >= len(self.entity_id_mapping):
                continue
            if score >= threshold:
                similar_entities.append((self.entity_id_mapping[idx], float(score)))

        return similar_entities

# Create the resolver and index the sampled entities
resolver = EntityResolver()
print("Building entity similarity index...")
resolver.build_similarity_index(all_ent)

# Exercise the resolver on the first five extracted entities
print("\nTesting entity resolution...")
test_entities = all_entities[:5]

for entity in test_entities:
    print(f"\nEntity: {entity.entity_type} - {entity.entity_id}")

    # Manufacturer resolution only applies to aircraft
    if entity.entity_type == "Aircraft":
        resolved_make = resolver.resolve_manufacturer(entity.make)
        if resolved_make != entity.make:
            print(f"  Resolved manufacturer: {entity.make} -> {resolved_make}")

    # Nearest neighbours from the sampled index; nothing is printed without hits
    similar = resolver.find_similar_entities(entity, threshold=0.7)
    if not similar:
        continue

    print(f"  Similar entities found: {len(similar)}")
    for similar_id, score in similar[:3]:  # top 3 only
        print(f"    - {similar_id} (score: {score:.3f})")

Building entity similarity index...

Testing entity resolution...

Entity: Aircraft - aircraft_N1294P
  Similar entities found: 1
    - aircraft_N167ZP (score: 0.706)

Entity: Airport - LOCATION_CHATHAM_MASSACHUSETTS
  Similar entities found: 10
    - LOCATION_PROVINCETOWN_MASSACHUSETTS (score: 0.774)
    - LOCATION_HANCOCK_MICHIGAN (score: 0.747)
    - LOCATION_PLYMOUTH_MASSACHUSETTS (score: 0.746)

Entity: Accident - accident_ERA26LA059
  Similar entities found: 10
    - accident_ERA09LA441 (score: 0.968)
    - accident_ERA24LA031 (score: 0.966)
    - accident_ERA21LA068 (score: 0.962)

Entity: Aircraft - aircraft_VH-EWS, VH-NMG
  Resolved manufacturer: VANS, VANS -> VANS VANS
  Similar entities found: 4
    - aircraft_N836JC (score: 0.813)
    - aircraft_N475AH (score: 0.792)
    - aircraft_N387E (score: 0.715)

Entity: Airport - LOCATION_NEW_SOUTH_WALES._NONE
  Similar entities found: 3
    - LOCATION_WAGGA_WAGGA,_NEW_SOUTH_WALES_NONE (score: 0.861)
    - LOCATION_AUCKLAND_NONE (sc

## Step 5: Export for Neo4j Import

In [None]:
def prepare_neo4j_import(entities: "List[AviationEntity]", relationships: "List[Dict]"):
    """Flatten entities and relationships into plain dicts for Neo4j import.

    Args:
        entities: Entities exposing ``entity_id``, ``entity_type``,
            ``attributes``, ``source_dataset``, ``confidence_score`` and the
            typed fields of their entity type.
        relationships: Dicts with ``from_id``, ``to_id``, ``relationship_type``
            and an optional ``attributes`` dict.

    Returns:
        ``{'nodes': [...], 'relationships': [...]}`` where each item is a flat
        dict suitable for a CSV row.
    """
    # Typed fields exported per entity type. The previous version exported only
    # a few of them, so values such as airport country/latitude/longitude,
    # aircraft number_of_engines and accident aircraft_damage/weather_condition
    # never reached the export.
    typed_fields = {
        "Aircraft": (
            'registration_number', 'make', 'model', 'series',
            'aircraft_category', 'number_of_engines', 'engine_type',
            'purpose_of_flight', 'total_seats',
        ),
        "Airport": (
            'airport_code', 'airport_name', 'location',
            'country', 'latitude', 'longitude',
        ),
        "Airline": ('airline_name', 'far_description', 'schedule'),
        "Accident": (
            'event_id', 'event_date', 'injury_severity', 'aircraft_damage',
            'weather_condition', 'probable_cause', 'total_injuries',
            'has_fatality', 'fatal_count', 'serious_count', 'minor_count',
        ),
    }

    # Nodes data
    nodes_data = []
    for entity in entities:
        node = {
            'entity_id': entity.entity_id,
            'entity_type': entity.entity_type,
            **(entity.attributes or {}),
            'source_dataset': entity.source_dataset,
            'confidence_score': entity.confidence_score
        }

        # Add entity-specific fields. A typed field that is None must not wipe
        # out a value already supplied through `attributes` under the same key.
        for field_name in typed_fields.get(entity.entity_type, ()):
            value = getattr(entity, field_name, None)
            if value is not None or field_name not in node:
                node[field_name] = value

        nodes_data.append(node)

    # Relationships data: endpoint ids and type, plus any attributes flattened in
    rels_data = []
    for rel in relationships:
        rel_data = {
            'from_id': rel['from_id'],
            'to_id': rel['to_id'],
            'relationship_type': rel['relationship_type'],
            **rel.get('attributes', {})
        }
        rels_data.append(rel_data)

    return {
        'nodes': nodes_data,
        'relationships': rels_data
    }

# Prepare data for Neo4j: flatten every extracted entity and relationship
import_data = prepare_neo4j_import(all_entities, all_relationships)

# Totals of what will be exported
print("Neo4j Import Data Summary:")
print(f"Nodes: {len(import_data['nodes'])}")
print(f"Relationships: {len(import_data['relationships'])}")

# Show the first three nodes as a sanity check
print("\nSample Nodes:")
for node in import_data['nodes'][:3]:
    print(f"  - {node['entity_type']}: {node['entity_id']}")

# Show the first three relationships as a sanity check
print("\nSample Relationships:")
for rel in import_data['relationships'][:3]:
    print(f"  - {rel['from_id']} --{rel['relationship_type']}--> {rel['to_id']}")

# Export to files
def export_to_files(import_data, prefix='aviation_kg'):
    """Write the node and relationship records to two CSV files.

    Args:
        import_data: Mapping with 'nodes' and 'relationships' record lists.
        prefix: Leading part of both output file names.

    Returns:
        ``(nodes_file, rels_file)``: the paths written, nodes first.
    """
    targets = {
        'nodes': f'{prefix}_nodes.csv',
        'relationships': f'{prefix}_relationships.csv',
    }

    # Nodes are written first, then relationships; the index is never exported
    for section, path in targets.items():
        pd.DataFrame(import_data[section]).to_csv(path, index=False)

    return targets['nodes'], targets['relationships']

# Export data: write both CSVs with the default 'aviation_kg' prefix and report the paths
nodes_file, rels_file = export_to_files(import_data)
print(f"\nExported files:")
print(f"Nodes: {nodes_file}")
print(f"Relationships: {rels_file}")

# Generate Cypher import script
def generate_cypher_import_script(nodes_file, rels_file):
    """Generate Cypher script for importing data into Neo4j.

    Args:
        nodes_file: File name of the nodes CSV, appended to ``file:///``.
        rels_file: File name of the relationships CSV, appended to ``file:///``.

    Returns:
        The script text: uniqueness constraints, a node load and a
        relationship load. The load statements call APOC procedures.
    """

    # Each constraint must reference the variable bound in its own FOR pattern;
    # four of them previously used `a`, which is only bound in the Aircraft one.
    cypher_script = """
// Create constraints for unique entities
CREATE CONSTRAINT aircraft_id IF NOT EXISTS FOR (a:Aircraft) REQUIRE a.entity_id IS UNIQUE;
CREATE CONSTRAINT airport_id IF NOT EXISTS FOR (ap:Airport) REQUIRE ap.entity_id IS UNIQUE;
CREATE CONSTRAINT airline_id IF NOT EXISTS FOR (al:Airline) REQUIRE al.entity_id IS UNIQUE;
CREATE CONSTRAINT accident_id IF NOT EXISTS FOR (ac:Accident) REQUIRE ac.entity_id IS UNIQUE;
CREATE CONSTRAINT manufacturer_id IF NOT EXISTS FOR (m:Manufacturer) REQUIRE m.entity_id IS UNIQUE;

// Load nodes
LOAD CSV WITH HEADERS FROM 'file:///""" + nodes_file + """' AS row
CALL apoc.create.node([row.entity_type], {
    entity_id: row.entity_id,
    source_dataset: row.source_dataset,
    confidence_score: toFloat(row.confidence_score)
    // Additional properties will be added based on entity_type
}) YIELD node
SET node += apoc.map.clean(row, ['entity_id', 'entity_type', 'source_dataset', 'confidence_score'], []);

// Load relationships
LOAD CSV WITH HEADERS FROM 'file:///""" + rels_file + """' AS row
MATCH (from {entity_id: row.from_id})
MATCH (to {entity_id: row.to_id})
CALL apoc.create.relationship(from, row.relationship_type, apoc.map.clean(row, ['from_id', 'to_id', 'relationship_type'], []), to)
YIELD rel
RETURN count(rel);
"""

    return cypher_script

cypher_script = generate_cypher_import_script(nodes_file, rels_file)
print("\nGenerated Cypher Import Script:")
print("=" * 50)
print(cypher_script)

# Persist the generated script next to the exported CSVs
with open('neo4j_import.cypher', 'w') as f:
    f.write(cypher_script)

print("\nPhase 1 Complete!")
# Closing checklist, one line per completed stage
for done_line in (
    "✓ NTSB data loaded and processed",
    "✓ Entities extracted and normalized",
    "✓ Entity resolution applied",
    "✓ Similarity indexing built",
    "✓ Neo4j import files generated",
    "✓ Cypher import script created",
):
    print(done_line)

Neo4j Import Data Summary:
Nodes: 205922
Relationships: 137241

Sample Nodes:
  - Aircraft: aircraft_N1294P
  - Airport: LOCATION_CHATHAM_MASSACHUSETTS
  - Accident: accident_ERA26LA059

Sample Relationships:
  - aircraft_N1294P --INVOLVED_IN--> accident_ERA26LA059
  - accident_ERA26LA059 --OCCURRED_AT--> LOCATION_CHATHAM_MASSACHUSETTS
  - aircraft_VH-EWS, VH-NMG --INVOLVED_IN--> accident_GAA26WA034

Exported files:
Nodes: aviation_kg_nodes.csv
Relationships: aviation_kg_relationships.csv

Generated Cypher Import Script:

// Create constraints for unique entities
CREATE CONSTRAINT aircraft_id IF NOT EXISTS FOR (a:Aircraft) REQUIRE a.entity_id IS UNIQUE;
CREATE CONSTRAINT airport_id IF NOT EXISTS FOR (ap:Airport) REQUIRE a.entity_id IS UNIQUE;
CREATE CONSTRAINT airline_id IF NOT EXISTS FOR (al:Airline) REQUIRE a.entity_id IS UNIQUE;
CREATE CONSTRAINT accident_id IF NOT EXISTS FOR (ac:Accident) REQUIRE a.entity_id IS UNIQUE;
CREATE CONSTRAINT manufacturer_id IF NOT EXISTS FOR (m:Manufact