In [None]:
import os
from typing import Any, Dict, List

# Third-party imports keep their original relative order so the load order of
# native libraries (torch / faiss) is unchanged.
import pandas as pd
import numpy as np
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import torch

import faiss
import spacy
from transformers import AutoTokenizer, AutoModel, pipeline
from neo4j import GraphDatabase
from langchain.vectorstores import FAISS
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_community.llms import Ollama

from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from pydantic import BaseModel, validator

In [None]:
import logging

# Configure logging: INFO-level root configuration plus a module-level logger
# that the rest of the notebook uses for progress and error reporting.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load SpaCy model. `nlp` is used further down for named-entity recognition,
# sentence splitting and query tokenisation.
nlp = spacy.load("en_core_web_sm")

# Load HuggingFace pipeline for relationship extraction. `rel_extractor` is a
# zero-shot classifier; extract_entities_and_relationships uses it to pick a
# relation label for each pair of entities.
rel_extractor = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

In [None]:
# Helpers and entry point for loading and preprocessing data from CSV, text and PDF sources
def _extract_text_from_pdf(pdf_path):
    """Return the text of every page in ``pdf_path`` joined by newlines."""
    with open(pdf_path, "rb") as file:
        reader = PdfReader(file)
        # Defensive: substitute "" if a page yields no text object so the join
        # below cannot fail on a falsy/None page result.
        return "\n".join((page.extract_text() or "") for page in reader.pages)


def _load_airline_corpus(data_dir, prefix):
    """Build the flat list of text entries for one airline file prefix."""
    # CSV: one entry per non-missing value of the 'summary' column. Missing
    # values are dropped because downstream steps (tokenizer, SpaCy, Neo4j
    # node names) expect strings.
    csv_path = os.path.join(data_dir, f"{prefix}_dates_locations.csv")
    summaries = pd.read_csv(csv_path)["summary"].dropna().astype(str).tolist()

    # Text file: entries are separated by blank lines; empty chunks (e.g. from
    # trailing blank lines) are skipped so they are not embedded or stored.
    txt_path = os.path.join(data_dir, f"{prefix}_aircraft_details.txt")
    with open(txt_path, "r", encoding="utf-8") as file:
        details = [chunk.strip() for chunk in file.read().split("\n\n") if chunk.strip()]

    # PDF: the whole document becomes a single entry.
    pdf_path = os.path.join(data_dir, f"{prefix}_accident_outcomes.pdf")
    pdf_text = _extract_text_from_pdf(pdf_path)

    return summaries + details + [pdf_text]


def load_data(data_dir="data"):
    """Load the text corpus for each airline.

    For each of the prefixes ``united``, ``alliance`` and ``air_canada`` this
    reads ``<prefix>_dates_locations.csv`` (column ``summary``),
    ``<prefix>_aircraft_details.txt`` (blank-line separated entries) and
    ``<prefix>_accident_outcomes.pdf`` from ``data_dir``.

    Args:
        data_dir: Directory containing the source files. Defaults to
            ``"data"``, the location used by the original notebook.

    Returns:
        A 3-tuple ``(united_data, alliance_data, air_canada_data)``; each item
        is a list of strings: CSV summaries, then text-file entries, then the
        full PDF text as the last element.

    Raises:
        FileNotFoundError: If any expected file is missing.
        KeyError: If a CSV file has no ``summary`` column.
    """
    united_data = _load_airline_corpus(data_dir, "united")
    alliance_data = _load_airline_corpus(data_dir, "alliance")
    air_canada_data = _load_airline_corpus(data_dir, "air_canada")
    return united_data, alliance_data, air_canada_data

# Load the data: one flat list of text entries per airline. These lists are
# reused below both for embedding/indexing and for building the knowledge graph.
united_data, alliance_data, air_canada_data = load_data()

In [None]:
# Load the tokenizer and model for embedding. Both are module-level globals
# read by embed_text; the same checkpoint name is passed again further down
# when the retrievers are created.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def embed_text(text):
    """Embed ``text`` with the module-level ``tokenizer`` and ``model``.

    The token embeddings are mean-pooled using the attention mask (so padding
    tokens do not contribute) and then L2-normalised. This mirrors the pooling
    and normalisation published for the sentence-transformers MiniLM
    checkpoint, so vectors stored in the index are comparable with query
    vectors produced through the sentence-transformers wrapper at search time.

    Args:
        text: A string, or a list of strings, to embed.

    Returns:
        A NumPy array with one row per input text.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    token_embeddings = outputs.last_hidden_state
    # Broadcast the mask over the hidden dimension so padded positions are zeroed.
    mask = inputs['attention_mask'].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # Clamp avoids division by zero for a (degenerate) all-padding row.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    sentence_embeddings = torch.nn.functional.normalize(summed / token_counts, p=2, dim=1)
    return sentence_embeddings.numpy()

def create_faiss_retriever(index_path, model_name, data_list):
    """Load a FAISS index from disk and wrap it in a LangChain ``FAISS`` store.

    Vector position ``i`` in the index is mapped to document id ``str(i)``,
    whose content is the ``i``-th item of ``data_list``. The index must
    therefore have been built from the same texts in the same order.

    Args:
        index_path: Path of the index file to read with ``faiss.read_index``.
        model_name: Embedding model name passed to ``HuggingFaceEmbeddings``.
        data_list: Iterable of texts that were indexed, in index order.

    Returns:
        A ``FAISS`` vector store backed by an in-memory docstore.

    Raises:
        ValueError: If the number of vectors in the index differs from the
            number of texts, which would make the position mapping wrong.
        Exception: Any loading error is logged and re-raised unchanged.
    """
    try:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        index = faiss.read_index(index_path)

        # Materialise once so generators are supported and the size is known.
        texts = list(data_list)
        if index.ntotal != len(texts):
            raise ValueError(
                f"Index at {index_path} holds {index.ntotal} vectors but "
                f"{len(texts)} documents were supplied"
            )

        # Initialize the document store and the position -> id mapping
        docs = {str(i): Document(page_content=text) for i, text in enumerate(texts)}
        index_to_docstore_id = {i: str(i) for i in range(len(texts))}
        docstore = InMemoryDocstore(docs)

        return FAISS(embedding_function=embeddings, index=index, docstore=docstore, index_to_docstore_id=index_to_docstore_id)
    except Exception as e:
        logger.error(f"Error creating FAISS retriever: {e}")
        raise

# Embedding + indexing entry point
def vectorize_and_index(data_list, index_path):
    """Embed every text in ``data_list`` and persist the vectors as a FAISS index.

    Each text is embedded with ``embed_text``; the per-text arrays are stacked
    row-wise and handed to ``create_faiss_index`` together with ``index_path``.
    Any failure is logged and re-raised.
    """
    try:
        per_text_vectors = []
        for entry in data_list:
            per_text_vectors.append(embed_text(entry))
        stacked = np.vstack(per_text_vectors)
        create_faiss_index(stacked, index_path)
    except Exception as err:
        logger.error(f"Error in vectorize_and_index: {err}")
        raise

def create_faiss_index(data, index_path):
    """Build a flat L2 FAISS index from ``data`` and write it to ``index_path``.

    Args:
        data: 2-D array-like of shape ``(n_vectors, dim)``.
        index_path: Destination file. Its parent directory is created when it
            does not exist yet, so a fresh checkout can run the notebook
            without preparing the folder by hand.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        # FAISS works on C-contiguous float32 buffers; convert once up front.
        vectors = np.ascontiguousarray(data, dtype='float32')
        d = vectors.shape[1]
        index = faiss.IndexFlatL2(d)
        index.add(vectors)

        index_dir = os.path.dirname(index_path)
        if index_dir:
            os.makedirs(index_dir, exist_ok=True)

        faiss.write_index(index, index_path)
        logger.info(f"FAISS index created at {index_path}")
    except Exception as e:
        logger.error(f"Error creating FAISS index: {e}")
        raise

# Generate embeddings and create FAISS indexes: one index file per airline,
# written under faiss_indexes/. This runs on every execution of the cell.
vectorize_and_index(united_data, "faiss_indexes/united_faiss.index")
vectorize_and_index(alliance_data, "faiss_indexes/alliance_faiss.index")
vectorize_and_index(air_canada_data, "faiss_indexes/air_canada_faiss.index")

# Load FAISS indexes and create retrievers. Each call re-reads the index file
# written above and is given the same data list, in the same order, because
# create_faiss_retriever maps vector position i to document i.
united_faiss = create_faiss_retriever("faiss_indexes/united_faiss.index", "sentence-transformers/all-MiniLM-L6-v2", united_data)
alliance_faiss = create_faiss_retriever("faiss_indexes/alliance_faiss.index", "sentence-transformers/all-MiniLM-L6-v2", alliance_data)
air_canada_faiss = create_faiss_retriever("faiss_indexes/air_canada_faiss.index", "sentence-transformers/all-MiniLM-L6-v2", air_canada_data)

In [None]:
def extract_entities_and_relationships(text):
    """Run NER on ``text`` and label relations between co-occurring entities.

    Returns a pair ``(entities, relationships)``:

    * ``entities`` - ``(entity_text, entity_label)`` for every entity in the text.
    * ``relationships`` - ``(head_text, relation, tail_text)`` for every ordered
      pair of distinct entities that share a sentence. The relation is the
      top-ranked label the zero-shot classifier assigns to the two entity
      texts joined by a space.
    """
    parsed = nlp(text)

    entities = []
    for span in parsed.ents:
        entities.append((span.text, span.label_))

    relationships = []
    for sentence in parsed.sents:
        sentence_ents = sentence.ents
        # A relation needs at least two entities in the same sentence.
        if len(sentence_ents) <= 1:
            continue

        ordered_pairs = []
        for head in sentence_ents:
            for tail in sentence_ents:
                if head != tail:
                    ordered_pairs.append((head.text, tail.text))

        for head_text, tail_text in ordered_pairs:
            verdict = rel_extractor(
                " ".join([head_text, tail_text]),
                candidate_labels=["caused by", "led to", "related to"],
            )
            relationships.append((head_text, verdict['labels'][0], tail_text))

    return entities, relationships

class GraphDBClient:
    """Small wrapper around a Neo4j driver.

    The constructor connects immediately and raises ``ConnectionError`` when
    the database cannot be reached. The client can be used as a context
    manager, or closed explicitly with ``close()``, to release the driver.
    """

    def __init__(self, uri, user, password):
        """Store the credentials and open the connection.

        Raises:
            ConnectionError: If the driver cannot be created or the server
                cannot be reached.
        """
        self.uri = uri
        self.username = user
        self.password = password
        self.driver = None
        try:
            self.connect()
        except Exception as e:
            logger.error(f"Failed to connect to Neo4j: {e}")
            # Release a driver that was created but failed verification.
            try:
                self.close()
            except Exception as close_error:
                logger.error(f"Error closing Neo4j driver: {close_error}")
            raise ConnectionError(f"Failed to connect to Neo4j: {e}") from e

    def connect(self):
        """Create the driver and confirm that the server is reachable."""
        self.driver = GraphDatabase.driver(self.uri, auth=(self.username, self.password))
        # Creating the driver object alone does not prove the server answers;
        # verify before reporting success.
        self.driver.verify_connectivity()
        logger.info("Successfully connected to Neo4j.")

    def close(self):
        """Close the underlying driver; safe to call more than once."""
        if self.driver is not None:
            self.driver.close()
            self.driver = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def query(self, query, parameters=None):
        """Run a Cypher ``query`` and return all records as a list.

        The records are materialised inside the session so they remain usable
        after the session has been closed. Errors are logged and re-raised.
        """
        try:
            with self.driver.session() as session:
                result = session.run(query, parameters)
                return list(result)
        except Exception as e:
            logger.error(f"Error executing query in Neo4j: {e}")
            raise

# Connect to the Neo4j database once and share the client across the notebook.
# (Two separate clients were created here before, leaving one driver open and
# unused.) A connection failure raises ConnectionError from the constructor.
graphdb_client = GraphDBClient(
    uri=os.getenv('NEO4J_URI'),
    user=os.getenv('NEO4J_USERNAME'),
    password=os.getenv('NEO4J_PASSWORD'),
)
# Backwards-compatible alias for code that refers to the client by this name.
graph_client = graphdb_client

# Usage: smoke-test query. A failure here is logged but does not stop the
# notebook, matching the previous best-effort behaviour of this check.
try:
    results = graph_client.query("MATCH (n) RETURN n LIMIT 5")
    logger.info(f"Query results: {results}")
except Exception as e:
    logger.error(f"An error occurred during Neo4j operations: {e}")

def create_knowledge_graph(conn, datasets):
    """Populate the graph with airlines, accidents, entities and relations.

    For every airline with at least one entry this creates:

    * one ``Airline`` node;
    * one ``Accident`` node per text entry, linked with ``HAS_ACCIDENT``;
    * one ``Entity`` node per extracted entity, linked from the accident with
      ``MENTIONS``;
    * one ``RELATION`` edge per extracted (entity, relation, entity) triple.

    Args:
        conn: Object exposing ``query(cypher, parameters=...)``.
        datasets: Mapping of airline name to an iterable of text entries.
    """
    for airline, data in datasets.items():
        entries = list(data)
        # As before, an airline without entries produces no nodes at all.
        if not entries:
            continue

        # Create Airline node. The statement is identical for every entry, so
        # it is issued once per airline instead of once per entry.
        conn.query("""
        MERGE (a:Airline {name: $airline})
        RETURN a
        """, parameters={'airline': airline})
        logger.info(f"Airline node created for {airline}")

        for entry in entries:
            entity_name = entry  # Using the text entry as the entity name

            # Extract entities and relationships
            entities, relationships = extract_entities_and_relationships(entity_name)

            # Create Accident node
            conn.query("""
            MERGE (acc:Accident {name: $entity_name, content: $content})
            RETURN acc
            """, parameters={'entity_name': entity_name, 'content': entity_name})
            logger.info(f"Accident node created for {entity_name}")

            # Create relationships between Airline and Accident
            conn.query("""
            MATCH (a:Airline {name: $airline})
            MATCH (acc:Accident {name: $entity_name})
            MERGE (a)-[:HAS_ACCIDENT]->(acc)
            """, parameters={'airline': airline, 'entity_name': entity_name})
            logger.info(f"Relationship created between Airline {airline} and Accident {entity_name}")

            # Create entities and link them to the Accident. The accident is
            # MATCHed first: MERGE on a whole pattern with an unbound node
            # creates the entire pattern when it is missing, which would add
            # a second Accident node for every mentioned entity.
            for entity_text, entity_label in entities:
                conn.query("""
                MATCH (acc:Accident {name: $entity_name})
                MERGE (e:Entity {name: $entity, type: $type})
                MERGE (acc)-[:MENTIONS]->(e)
                """, parameters={'entity': entity_text, 'type': entity_label, 'entity_name': entity_name})
                logger.info(f"Entity node created for {entity_text} with label {entity_label}, related to Accident {entity_name}")

            # Create relationships between entities based on extracted relationships
            for entity1, relation, entity2 in relationships:
                conn.query("""
                MATCH (e1:Entity {name: $entity1}), (e2:Entity {name: $entity2})
                MERGE (e1)-[r:RELATION {type: $relation, source: $source}]->(e2)
                """, parameters={'entity1': entity1, 'entity2': entity2, 'relation': relation, 'source': airline})
                logger.info(f"Relationship {relation} created between Entity {entity1} and Entity {entity2} for Airline {airline}")


# Combine all datasets: airline label -> list of text entries. The label is
# stored as the Airline node name and as the `source` property of the
# entity-to-entity relations created by create_knowledge_graph.
datasets = {
    "United_Airlines": united_data,
    "Alliance_Airlines": alliance_data,
    "Air_Canada": air_canada_data
}
# Populate the Neo4j knowledge graph; this issues several queries per entry.
create_knowledge_graph(graphdb_client, datasets)

In [None]:
def query_knowledge_graph(client, keyword):
    """Return entity-to-entity relations whose endpoint name matches ``keyword``.

    Args:
        client: Either an object exposing ``query(cypher, parameters)`` (such
            as ``GraphDBClient``) or a driver-like object exposing
            ``session()``.
        keyword: Value bound to ``$name`` in the Cypher regular-expression
            match (``=~``).
            NOTE(review): ``=~`` must match the whole name and treats the
            keyword as a regular expression - confirm that exact matching
            (rather than substring search) is intended.

    Returns:
        A list of records with the fields ``entity``, ``relationship`` and
        ``related_entity``; an empty list when the query fails.
    """
    query = """
    MATCH (e:Entity)-[r]->(related:Entity)
    WHERE e.name =~ $name OR related.name =~ $name
    RETURN e.name AS entity, type(r) AS relationship, related.name AS related_entity
    """
    try:
        # GraphDBClient has no session() method; use its query() API. The
        # session-based path is kept for callers that pass a raw driver.
        if hasattr(client, 'query'):
            return list(client.query(query, {'name': keyword}))
        with client.session() as session:
            # Materialise inside the session so records outlive it.
            return list(session.run(query, name=keyword))
    except Exception as e:
        logger.error(f"Error querying knowledge graph: {e}")
        return []
    
class VectorStoreRetriever:
    """Nearest-neighbour lookup over an index exposing ``search(vectors, k)``."""

    def __init__(self, index):
        self.index = index  # Assume this is a FAISS index

    def search(self, query_vector, k=10):
        """Return ``(distances, indices)`` of the ``k`` closest vectors.

        Args:
            query_vector: A single vector (1-D sequence or array) or an
                already batched 2-D array of shape ``(n_queries, dim)``.
            k: Number of neighbours to retrieve per query.
        """
        # Actually enforce what the index needs: float32, 2-D, contiguous.
        # A 1-D vector becomes a batch of one; 2-D input is passed through
        # instead of being wrapped into an invalid 3-D array.
        query = np.asarray(query_vector, dtype='float32')
        if query.ndim == 1:
            query = query.reshape(1, -1)
        query = np.ascontiguousarray(query)
        distances, indices = self.index.search(query, k)  # Retrieve the top-k closest vectors
        return distances, indices

In [None]:
def extract_entities(query):
    """Return the query's keyword tokens.

    A token is kept when it is purely alphabetic and is not a stop word; the
    token texts are returned in their original order.
    """
    keywords = []
    for tok in nlp(query):
        if not tok.is_alpha or tok.is_stop:
            continue
        keywords.append(tok.text)
    return keywords

class CustomMultiRetriever(BaseModel):
    """Fan a query out to one retriever per airline.

    Attributes are declared as pydantic fields:

    * ``faiss_retrievers`` - mapping of airline name to a retriever object.
    * ``knowledge_graph_client`` - graph client kept alongside the retrievers.
    """
    faiss_retrievers: Dict[str, Any]
    knowledge_graph_client: Any

    # If you need to perform any checks or initializations post-creation, use validators or root validators
    @validator('faiss_retrievers')
    def check_faiss_retrievers(cls, value):
        """Reject an empty retriever mapping."""
        if not value:
            raise ValueError("FAISS retrievers cannot be empty")
        return value

    def get_relevant_documents(self, query):
        """Query every retriever and collect the results per airline.

        Retrievers following the LangChain retriever interface are called via
        ``invoke`` (or the older ``get_relevant_documents``), giving a list of
        documents per airline. Objects that only offer ``search(query)`` keep
        the previous behaviour and contribute a ``(distances, indices)`` tuple.

        Returns:
            Dict mapping airline name to that retriever's result.
        """
        results = {}
        for airline, retriever in self.faiss_retrievers.items():
            if hasattr(retriever, 'invoke'):
                results[airline] = retriever.invoke(query)
            elif hasattr(retriever, 'get_relevant_documents'):
                results[airline] = retriever.get_relevant_documents(query)
            else:
                # Legacy path: custom retriever exposing search()
                distances, indices = retriever.search(query)
                results[airline] = (distances, indices)
        return results

# Build one retriever per airline. `as_retriever` must be *called* for every
# store; passing the bound method itself would store a function, not a retriever.
custom_multi_retriever = CustomMultiRetriever(
    faiss_retrievers={
        "United Airlines": united_faiss.as_retriever(),
        "Alliance Airlines": alliance_faiss.as_retriever(),
        "Air Canada": air_canada_faiss.as_retriever()
    },
    knowledge_graph_client=graphdb_client
)

In [None]:
# Define the prompt template. The template has a single placeholder,
# {context}; nothing else is substituted into the prompt.
# NOTE(review): the instructions refer to "the query", but there is no
# placeholder for it - whatever the model should know about the user's
# question has to be part of the context string.
prompt_template = PromptTemplate(
    template="""
    You are an AI assistant that specializes in providing detailed information about airline accidents. 
    When given a query about a specific flight, you should:
    
    1. Identify the flight number and any other relevant details from the query.
    2. Retrieve specific information about the flight from the provided context, including any relevant accidents or incidents.
    3. Summarize the information in a clear and concise manner.
    4. If there are multiple incidents related to the flight, provide details on each incident separately.
    5. Ensure the response is focused on the specific flight mentioned in the query.

    Use the provided context to generate the response and avoid including unrelated information.

    Context:
    {context}
    """,
    input_variables=["context"]
)

# Initialize the generative model (served through Ollama under the name "gemma:7b")
generative_model = Ollama(model="gemma:7b")

# Create the LLMChain that feeds the filled-in prompt to the generative model
llm_chain = LLMChain(prompt=prompt_template, llm=generative_model)

def integrate_graph_and_docs(graph_results, retrieved_docs):
    """Render graph triples and retrieved documents as one context string.

    The text starts with a fixed header. Each graph result contributes an
    "<entity> <relationship> <related_entity>. " sentence (missing fields fall
    back to "Unknown ..." placeholders); each document contributes its
    ``page_content`` on its own line under a "From FAISS:" heading. A fixed
    notice is emitted for whichever source is empty.
    """
    pieces = ["Retrieved Context:\n"]

    if not graph_results:
        pieces.append("No relevant data found in the knowledge graph.\n")
    else:
        for row in graph_results:
            subject = row.get('entity', 'Unknown Entity')
            predicate = row.get('relationship', 'Unknown Relationship')
            target = row.get('related_entity', 'Unknown Related Entity')
            pieces.append(f"{subject} {predicate} {target}. ")

    if not retrieved_docs:
        pieces.append("No relevant data found in the FAISS retrievers.\n")
    else:
        pieces.append("\nFrom FAISS:\n")
        for document in retrieved_docs:
            pieces.append(document.page_content + "\n")

    return "".join(pieces)

def process_query(user_query, custom_multi_retriever, graphdb_client, llm_chain):
    """Answer ``user_query`` from the knowledge graph and the vector stores.

    Steps:
        1. Extract keywords from the query and look each one up in the graph.
        2. Retrieve documents for the query from every airline retriever.
        3. Build a single context (user query, per-keyword graph misses, graph
           triples, retrieved documents) and pass it to the LLM chain.

    Args:
        user_query: Natural-language question.
        custom_multi_retriever: Object whose ``get_relevant_documents(query)``
            returns a mapping of airline name to an iterable of documents
            exposing ``page_content``.
        graphdb_client: Client handed to ``query_knowledge_graph``.
        llm_chain: Callable chain invoked with ``{"context": ...}``.

    Returns:
        Whatever ``llm_chain`` returns.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        entities = extract_entities(user_query)
        logger.info(f"Extracted entities: {entities}")

        # Collect graph hits across all keywords so the context is rendered
        # once; rendering per keyword repeated the header and stated that the
        # FAISS retrievers returned nothing before they were even queried.
        graph_results = []
        graph_misses = ""
        for entity in entities:
            entity_results = query_knowledge_graph(graphdb_client, entity)
            logger.info(f"Graph results for entity {entity}: {entity_results}")
            if (len(entity_results) > 0):
                graph_results.extend(entity_results)
            else:
                graph_misses += f"No relevant data found in the knowledge graph for entity: {entity}.\n"

        faiss_docs = custom_multi_retriever.get_relevant_documents(user_query)
        logger.info(f"FAISS results: {faiss_docs}")
        retrieved_docs = [doc for docs in faiss_docs.values() for doc in docs]

        # The prompt only substitutes {context}, so the question itself has to
        # travel inside the context for the model to see it.
        final_context = (
            f"User query: {user_query}\n\n"
            + graph_misses
            + integrate_graph_and_docs(graph_results, retrieved_docs)
        )
        logger.info(f"Final integrated context: {final_context}")

        response = llm_chain({"context": final_context})
        return response
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        raise


In [None]:
# Example usage: process_query needs the retriever, the graph client and the
# LLM chain in addition to the question itself.
user_query = "Tell me about what aircraft faced accident on this day"
response = process_query(user_query, custom_multi_retriever, graphdb_client, llm_chain)
logger.info(f"Response: {response}")