Repository Structure

Create the following Python modules:

    data_cleaning.py: Handles data loading, cleaning, and preprocessing.
    neo4j_data_preprocess_ingest.py: Manages Neo4j database connections, schema setup, and data ingestion.
    neo4j_test_functions.py: Contains functions to test and query data from Neo4j.
    model_loading.py: Responsible for loading and testing the LLM model.
    graphqa_functions.py: Sets Neo4j permissions and integrates with GraphQA using LangChain.
    utility_functions.py: Houses shared utility functions.
    main.py: Serves as the entry point to run and test individual modules.

Each module will have:

    A main() function for individual testing.
    A --debug command-line option to enable detailed logging.
    Exception handling to trace and identify issues.

In [11]:
%%writefile ../../src/neo4j_model/modules/utility_functions.py

import os
from dotenv import load_dotenv
from neo4j import GraphDatabase

def load_env_variables(env_path='../../.env'):
    """Load Neo4j connection settings from a .env file and return them as a dict.

    Args:
        env_path: Path of the .env file to load. A relative path is resolved
            against the current working directory, so the default only finds the
            file when the process runs two directories below the project root.

    Returns:
        dict: ``NEO4J_URI``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` mapped to
        their environment values (``None`` for any variable that is unset).
    """
    load_dotenv(env_path)
    return {
        key: os.getenv(key)
        for key in ('NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD')
    }

def connect_to_neo4j():
    """Open a Neo4j driver using credentials from the .env file and verify it.

    Returns:
        neo4j.Driver: a driver whose connectivity has been verified. The caller
        owns it and must call ``close()`` when finished.

    Raises:
        Exception: whatever the driver raised during creation or verification is
        re-raised after printing a message. The driver is closed first, so no
        connection pool is left open.
    """
    env_vars = load_env_variables()
    uri = env_vars['NEO4J_URI']
    username = env_vars['NEO4J_USERNAME']
    password = env_vars['NEO4J_PASSWORD']

    driver = None
    try:
        driver = GraphDatabase.driver(uri, auth=(username, password))
        driver.verify_connectivity()
    except Exception as e:
        print(f"Failed to connect to Neo4j: {e}")
        # The driver may already exist when only verification failed; release it.
        if driver is not None:
            driver.close()
        raise
    print("Connected to Neo4j successfully.")
    return driver

def main():
    """Smoke-test the Neo4j connection: connect, close the driver, report the result."""
    print("Testing Neo4j connection...")
    try:
        driver = connect_to_neo4j()
        driver.close()  # Close the connection after testing
        print("Neo4j connection test completed successfully.")
    except Exception as e:
        # Include the underlying error so the failure can be diagnosed.
        print(f"Neo4j connection test failed: {e}")


if __name__ == "__main__":
    main()

Overwriting ../../src/neo4j_model/modules/utility_functions.py


In [12]:
%%writefile ../../src/neo4j_model/modules/data_cleaning.py


import pandas as pd
import os
from dotenv import load_dotenv

def load_data(file_path):
    """Read a CSV source into a pandas DataFrame.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts (path or file-like object).

    Returns:
        pandas.DataFrame: the parsed table.
    """
    return pd.read_csv(file_path)

def initial_data_check(dataframe, debug=False):
    """Print a first-pass profile of the raw data (debug mode only).

    Shows the shape, the columns that contain nulls, every column's dtype, and
    the number of distinct values in Player, Season, Team and Salary. Does
    nothing when ``debug`` is false. In debug mode a missing key column raises
    KeyError.
    """
    if not debug:
        return

    print("Initial data shape:", dataframe.shape)

    per_column_nulls = dataframe.isnull().sum()
    print("Missing Values Summary:\n", per_column_nulls[per_column_nulls > 0])

    print("Data Types:\n", dataframe.dtypes)

    key_columns = ['Player', 'Season', 'Team', 'Salary']
    print("Unique Values in Key Columns:\n", dataframe.nunique()[key_columns])

    # Same counts again, one labelled line per key column.
    labelled_columns = (
        ("Players", "Player"),
        ("Seasons", "Season"),
        ("Teams", "Team"),
        ("Salary values", "Salary"),
    )
    for label, column in labelled_columns:
        print(f"Unique {label}: {dataframe[column].nunique()}")
        
def clean_data(dataframe, debug=False):
    """Clean raw player data for Neo4j ingestion and return a new DataFrame.

    The function drops the ``2nd Apron`` column if it is present. It fills missing
    ``Injury_Periods`` with ``"Not_injured"`` when that column exists. It then
    drops every row that still contains a null. The caller's DataFrame is never
    modified.

    Args:
        dataframe: Raw player data.
        debug: When true, print null counts before and after cleaning and the
            resulting shape.

    Returns:
        pandas.DataFrame: the cleaned data.
    """
    if debug:
        print("Cleaning data...")
        print("Total null values before cleaning:", dataframe.isnull().sum().sum())
        print("Null values by column before cleaning:\n", dataframe.isnull().sum())

    # Work on a copy. Without it, the column assignment below would write into
    # the caller's DataFrame whenever '2nd Apron' is absent (no drop -> no copy).
    dataframe = dataframe.copy()

    if '2nd Apron' in dataframe.columns:
        dataframe = dataframe.drop(columns=['2nd Apron'])
        if debug:
            print("Dropped '2nd Apron' column.")

    if 'Injury_Periods' in dataframe.columns:
        dataframe['Injury_Periods'] = dataframe['Injury_Periods'].fillna("Not_injured")
        if debug:
            print("Filled missing 'Injury_Periods' with 'Not_injured'.")

    # Any row that still has a null cannot be ingested cleanly; drop it.
    dataframe_cleaned = dataframe.dropna()

    if debug:
        print("Total null values after cleaning:", dataframe_cleaned.isnull().sum().sum())
        print("Null values by column after cleaning:\n", dataframe_cleaned.isnull().sum())
        print("Data shape after dropping remaining NaNs:", dataframe_cleaned.shape)

    return dataframe_cleaned


def check_for_duplicates(dataframe, debug=False):
    """Report rows that repeat an earlier (Player, Season, Salary) combination.

    The duplicate count is printed only in debug mode. The duplicate rows
    themselves are printed whenever there are any, regardless of ``debug``.
    """
    contract_key = ["Player", "Season", "Salary"]
    is_repeat = dataframe.duplicated(subset=contract_key)
    repeat_count = is_repeat.sum()

    if debug:
        print("Number of duplicate rows:", repeat_count)
    if repeat_count <= 0:
        return
    print("Duplicate rows based on [Player, Season, Salary]:\n", dataframe[is_repeat])

def map_team_ids(dataframe, debug=False):
    """Return a copy of ``dataframe`` with a numeric ``TeamID`` column added.

    ``TeamID`` is looked up from the abbreviation in the ``Team`` column. Both
    the NBA-stats codes (BKN, CHA, PHX) and the Basketball-Reference spellings of
    the same franchises (BRK, CHO, PHO) are recognized. Unknown abbreviations map
    to NaN. The input DataFrame is not modified.

    Args:
        dataframe: Data with a ``Team`` column of team abbreviations.
        debug: When true, report how many rows were left unmapped and which
            abbreviations caused it.

    Returns:
        pandas.DataFrame: a new frame including ``TeamID``.
    """
    team_id_mapping = {
        "ATL": 1610612737, "BOS": 1610612738, "BKN": 1610612751, "CHA": 1610612766,
        "CHI": 1610612741, "CLE": 1610612739, "DAL": 1610612742, "DEN": 1610612743,
        "DET": 1610612765, "GSW": 1610612744, "HOU": 1610612745, "IND": 1610612754,
        "LAC": 1610612746, "LAL": 1610612747, "MEM": 1610612763, "MIA": 1610612748,
        "MIL": 1610612749, "MIN": 1610612750, "NOP": 1610612740, "NYK": 1610612752,
        "OKC": 1610612760, "ORL": 1610612753, "PHI": 1610612755, "PHX": 1610612756,
        "POR": 1610612757, "SAC": 1610612758, "SAS": 1610612759, "TOR": 1610612761,
        "UTA": 1610612762, "WAS": 1610612764,
        # Basketball-Reference abbreviations for the same franchises.
        "BRK": 1610612751, "CHO": 1610612766, "PHO": 1610612756,
    }
    # assign() returns a new frame, so the caller's DataFrame stays unchanged.
    mapped = dataframe.assign(TeamID=dataframe['Team'].map(team_id_mapping))
    if debug:
        unmapped_mask = mapped['TeamID'].isnull()
        unmapped_rows = unmapped_mask.sum()
        if unmapped_rows > 0:
            unknown_codes = sorted(mapped.loc[unmapped_mask, 'Team'].astype(str).unique())
            print(f"{unmapped_rows} rows were not mapped to TeamIDs. "
                  f"Check team abbreviations: {unknown_codes}")
    return mapped

def add_suffixes_to_columns(dataframe, debug=False):
    """Rename cumulative stat columns such as ``PTS`` to ``PTS_total``.

    Only the listed statistics are renamed. Listed columns that are absent are
    ignored, and every other column is kept as is. Returns a new DataFrame.
    """
    cumulative_stats = ('PTS', 'AST', 'TRB', 'STL', 'BLK', 'TOV', 'PF',
                        'WS', 'OWS', 'DWS', 'VORP')
    renamed = dataframe.rename(columns={stat: f"{stat}_total" for stat in cumulative_stats})
    if debug:
        print("Applied suffixes to cumulative columns.")
    return renamed

def final_data_check(dataframe, debug=False):
    """Print pre-ingestion sanity checks (debug mode only).

    The function reports any expected column that is missing. The expected
    columns are Player, Season, TeamID, Salary and each cumulative ``*_total``
    stat column produced by the suffix step. It also reports columns that still
    hold nulls, prints the dtypes, and shows a preview of the first rows. It
    does nothing when ``debug`` is false.
    """
    if not debug:
        return

    # List the expected *_total columns explicitly. If they were taken from the
    # frame itself, they would always be "present" and the check could never
    # catch a missing one.
    expected_totals = [
        f"{stat}_total"
        for stat in ('PTS', 'AST', 'TRB', 'STL', 'BLK', 'TOV', 'PF',
                     'WS', 'OWS', 'DWS', 'VORP')
    ]
    required_columns = ['Player', 'Season', 'TeamID', 'Salary'] + expected_totals
    missing_columns = [col for col in required_columns if col not in dataframe.columns]
    if missing_columns:
        print(f"Missing expected columns: {missing_columns}")

    # Re-check for missing values
    null_summary = dataframe.isnull().sum()
    if null_summary.any():
        print("Columns with remaining null values:\n", null_summary[null_summary > 0])

    # Check if data types are compatible for Neo4j
    print("Final Data Types:\n", dataframe.dtypes)

    # Preview the first few rows of the final dataframe
    print("Preview of cleaned data:\n", dataframe.head())

def check_data_statistics(dataframe, debug=False):
    """Print descriptive statistics and distinct-value counts (debug mode only).

    Covers ``describe(include='all')`` for every column, plus the number of
    distinct Players, Seasons, Teams and Salary values. Does nothing when
    ``debug`` is false.
    """
    if not debug:
        return

    print("Descriptive Statistics of the Dataset:\n", dataframe.describe(include='all'))
    print("\nUnique Value Counts for Key Columns:")
    distinct_per_column = dataframe.nunique()
    print(distinct_per_column[['Player', 'Season', 'Team', 'Salary']])

    summary_lines = (
        ("\nUnique Players", "Player"),
        ("Unique Seasons", "Season"),
        ("Unique Teams", "Team"),
        ("Unique Contracts (based on Salary)", "Salary"),
    )
    for label, column in summary_lines:
        print(f"{label}: {dataframe[column].nunique()}")

def identify_conflicting_contracts(dataframe, debug=False):
    """Print (Player, Season, Salary) combinations that appear more than once.

    Output is printed regardless of ``debug``, which is accepted only for a
    signature consistent with the other checks. For each repeated combination,
    all matching rows are printed.
    """
    contract_key = ['Player', 'Season', 'Salary']
    occurrences = dataframe.groupby(contract_key).size().reset_index(name='count')
    repeated = occurrences[occurrences['count'] > 1]

    if repeated.empty:
        return

    print("Potential Conflicting Contracts Found:")
    print(repeated)
    for _, entry in repeated.iterrows():
        player = entry['Player']
        season = entry['Season']
        salary = entry['Salary']
        print(f"\nDetails of Conflicting Contracts for Player: {player}, Season: {season}, Salary: {salary}")
        same_contract = (
            (dataframe['Player'] == player)
            & (dataframe['Season'] == season)
            & (dataframe['Salary'] == salary)
        )
        print(dataframe[same_contract])

def data_initial_summary(dataframe, debug=False):
    """Run the pre-cleaning diagnostics in order.

    The order is: basic checks, descriptive statistics, then the
    conflicting-contract report.
    """
    diagnostics = (initial_data_check, check_data_statistics, identify_conflicting_contracts)
    for diagnostic in diagnostics:
        diagnostic(dataframe, debug)


def main(debug=True):
    """Run the full cleaning pipeline and write the cleaned CSV.

    Input and output paths come from the ``DATA_FILE`` and ``CLEANED_DATA_FILE``
    environment variables (loaded from the workspace .env), with hard-coded
    fallbacks. Any error is printed, not raised.
    """
    load_dotenv('/workspaces/custom_ollama_docker/.env')
    data_file = os.getenv('DATA_FILE', '/workspaces/custom_ollama_docker/data/neo4j/raw/nba_player_data_final_inflated.csv')
    output_file = os.getenv('CLEANED_DATA_FILE', '/workspaces/custom_ollama_docker/data/neo4j/processed/nba_player_data_cleaned.csv')

    try:
        dataframe = load_data(data_file)
        print(f"Loaded data with shape: {dataframe.shape}")

        # Initial checks, data statistics, and conflict identification
        data_initial_summary(dataframe, debug)

        # Clean data, check duplicates, map IDs, add suffixes, and run final checks
        dataframe_cleaned = clean_data(dataframe, debug)
        check_for_duplicates(dataframe_cleaned, debug)
        dataframe_mapped = map_team_ids(dataframe_cleaned, debug)
        dataframe_final = add_suffixes_to_columns(dataframe_mapped, debug)
        final_data_check(dataframe_final, debug)

        # to_csv does not create missing parent directories, so ensure they exist.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        dataframe_final.to_csv(output_file, index=False)
        print(f"Preprocessed data saved to {output_file}")

    except Exception as e:
        print("An error occurred during data preprocessing:", e)


if __name__ == "__main__":
    main(debug=True)


Overwriting ../../src/neo4j_model/modules/data_cleaning.py


In [13]:
%%writefile ../../src/neo4j_model/modules/neo4j_data_preprocess_ingest.py
from utility_functions import load_env_variables, connect_to_neo4j

import pandas as pd
from neo4j import GraphDatabase
import os
from dotenv import load_dotenv



# Function to create constraints only if they don't already exist
def create_constraint_if_not_exists(session, constraint_query, constraint_name):
    """Run ``constraint_query`` unless a constraint named ``constraint_name`` exists.

    Existence is checked with ``SHOW CONSTRAINTS WHERE name = ...``. A status
    message is printed in either case. ``constraint_name`` is formatted into the
    query text, so it must be a trusted identifier.
    """
    existing = session.run(f"SHOW CONSTRAINTS WHERE name = '{constraint_name}'").single()
    if not existing:
        session.run(constraint_query)
        print(f"Successfully created constraint: {constraint_name}")
        return
    print(f"Constraint '{constraint_name}' already exists.")



# Function to delete duplicate nodes before creating uniqueness constraints
def delete_duplicate_nodes(session, label, property_name):
    """Keep one ``label`` node per distinct ``property_name`` value; delete the rest.

    Surplus nodes are DETACH DELETEd, so their relationships are removed too.
    ``label`` and ``property_name`` are formatted straight into the Cypher text,
    so they must be trusted identifiers, never user input.
    NOTE(review): nodes lacking the property share a null grouping key and are
    deduplicated together — confirm that is acceptable.
    """
    print(f"Deleting duplicate nodes for {label} based on {property_name}...")
    cypher = f"""
    MATCH (n:{label})
    WITH n.{property_name} AS prop, COLLECT(n) AS nodes
    WHERE SIZE(nodes) > 1
    UNWIND TAIL(nodes) AS duplicateNode
    DETACH DELETE duplicateNode
    """
    session.run(cypher)
    print(f"Duplicate nodes deleted for {label} based on {property_name}.")

# Setup schema with constraints and cleanup
def setup_schema_with_cleanup(session):
    """Remove duplicate nodes, then create the uniqueness constraints.

    Duplicates must be removed first, because Neo4j refuses to create a
    uniqueness constraint while existing data violates it.

    Contract nodes are deduplicated on the full (salary, player_name, season)
    key that ``contract_unique`` enforces, via ``delete_duplicate_contract_nodes``.
    Deduplicating on ``salary`` alone would delete distinct contracts of
    different players that merely share a salary.
    """
    constraints = [
        {"query": "CREATE CONSTRAINT player_name_unique FOR (p:Player) REQUIRE p.name IS UNIQUE", "name": "player_name_unique"},
        {"query": "CREATE CONSTRAINT team_name_unique FOR (t:Team) REQUIRE t.name IS UNIQUE", "name": "team_name_unique"},
        {"query": "CREATE CONSTRAINT season_name_unique FOR (s:Season) REQUIRE s.name IS UNIQUE", "name": "season_name_unique"},
        {"query": "CREATE CONSTRAINT position_name_unique FOR (pos:Position) REQUIRE pos.name IS UNIQUE", "name": "position_name_unique"},
        {"query": "CREATE CONSTRAINT contract_unique FOR (c:Contract) REQUIRE (c.salary, c.player_name, c.season) IS UNIQUE", "name": "contract_unique"},
    ]

    # Single-property keys only. Contract has a composite key and is handled below.
    cleanup_mappings = [
        {"label": "Player", "property": "name"},
        {"label": "Team", "property": "name"},
        {"label": "Season", "property": "name"},
        {"label": "Position", "property": "name"},
    ]

    for mapping in cleanup_mappings:
        delete_duplicate_nodes(session, mapping["label"], mapping["property"])
    delete_duplicate_contract_nodes(session)

    for constraint in constraints:
        create_constraint_if_not_exists(session, constraint["query"], constraint["name"])



def create_player_node(tx, player_data):
    """MERGE a Player node by name and set its properties on first creation only.

    Because of ``ON CREATE SET``, the season-specific values (salary, season,
    stats) are those of the first row seen for the player. Later rows for the
    same player do not update them.

    Args:
        tx: Neo4j transaction used to run the query.
        player_data: One CSV record as a dict. Requires Player, Age, Position,
            "Years of Service", Injury_Risk, Salary and Season. PER, BPM,
            WS (or WS_total) and VORP (or VORP_total) are optional.
    """
    query = """
    MERGE (p:Player {name: $name})
    ON CREATE SET p.age = $age,
                  p.position = $position,
                  p.years_of_service = $years_of_service,
                  p.injury_risk = $injury_risk,
                  p.season_salary = $salary,
                  p.season = $season,
                  p.per = $per,
                  p.ws = $ws,
                  p.bpm = $bpm,
                  p.vorp = $vorp
    """
    # The cleaning step renames WS -> WS_total and VORP -> VORP_total. Fall back
    # to the suffixed columns so these properties are not always null.
    ws = player_data.get("WS")
    if ws is None:
        ws = player_data.get("WS_total")
    vorp = player_data.get("VORP")
    if vorp is None:
        vorp = player_data.get("VORP_total")

    tx.run(query,
           name=player_data["Player"],
           age=player_data["Age"],
           position=player_data["Position"],
           years_of_service=player_data["Years of Service"],
           injury_risk=player_data["Injury_Risk"],
           salary=player_data["Salary"],
           season=player_data["Season"],
           per=player_data.get("PER"),
           ws=ws,
           bpm=player_data.get("BPM"),
           vorp=vorp)


def create_team_node(tx, team_name, team_id, team_data):
    """MERGE a Team node by name. On creation only, set team_id and optional metadata.

    ``Needs``, ``Strategy`` and ``Cap Space`` are read with ``.get``. If a key is
    absent, the value is passed as null and the property is left unset.
    """
    optional_fields = {
        "needs": team_data.get("Needs"),
        "strategy": team_data.get("Strategy"),
        "cap_space": team_data.get("Cap Space"),
    }
    cypher = """
    MERGE (t:Team {name: $name})
    ON CREATE SET t.team_id = $team_id,
                  t.needs = $needs,
                  t.strategy = $strategy,
                  t.cap_space = $cap_space
    """
    tx.run(cypher, name=team_name, team_id=team_id, **optional_fields)


def create_season_node(tx, season):
    """Ensure a Season node whose ``name`` is ``season`` exists (MERGE)."""
    tx.run("\n    MERGE (s:Season {name: $season})\n    ", season=season)


def create_position_node(tx, position):
    """Ensure a Position node whose ``name`` is ``position`` exists (MERGE)."""
    tx.run("\n    MERGE (pos:Position {name: $position})\n    ", position=position)


def create_contract_node(tx, contract_data):
    """MERGE a Contract keyed by (player_name, season) and fill details on creation only.

    Player, Season, Salary, "Salary Cap" and "Luxury Tax" are required keys. The
    duration, option and no-trade fields are optional (null when absent).
    Salary is not part of the MERGE key, so a second row for the same player and
    season neither creates another Contract nor changes the stored salary.
    """
    cypher = """
    MERGE (c:Contract {player_name: $player_name, season: $season})
    ON CREATE SET c.salary = $salary,
                  c.cap = $cap,
                  c.luxury_tax = $luxury_tax,
                  c.duration = $duration,
                  c.player_option = $player_option,
                  c.team_option = $team_option,
                  c.no_trade_clause = $no_trade_clause
    """
    params = {
        "player_name": contract_data["Player"],
        "season": contract_data["Season"],
        "salary": contract_data["Salary"],
        "cap": contract_data["Salary Cap"],
        "luxury_tax": contract_data["Luxury Tax"],
    }
    optional_columns = (
        ("duration", "Contract Duration"),
        ("player_option", "Player Option"),
        ("team_option", "Team Option"),
        ("no_trade_clause", "No Trade Clause"),
    )
    for param, column in optional_columns:
        params[param] = contract_data.get(column)
    tx.run(cypher, **params)


def delete_duplicate_contract_nodes(session):
    """Keep one Contract per (salary, player_name, season) and DETACH DELETE the rest."""
    session.run("""
    MATCH (c:Contract)
    WITH c.salary AS salary, c.player_name AS player_name, c.season AS season, COLLECT(c) AS contracts
    WHERE SIZE(contracts) > 1
    UNWIND TAIL(contracts) AS duplicateContract
    DETACH DELETE duplicateContract
    """)
    print("Duplicate Contract nodes deleted based on salary, player_name, and season.")


def create_statistics_node(tx, player_name, stats_data):
    """MERGE a per-season Statistics node for ``player_name`` and set totals on creation.

    The node is keyed by (player, season). The function reads Season, GP and the
    ``*_total`` columns (PTS, AST, TRB, STL, BLK, TOV, PF, WS, OWS, DWS, VORP)
    from ``stats_data``; a missing key raises KeyError. ``stat.ppg`` receives
    PTS_total, i.e. the season total rather than a per-game figure.
    """
    query = """
    MERGE (stat:Statistics {player: $player, season: $season})
    ON CREATE SET 
        stat.ppg = $pts_total,               // Use total fields
        stat.assists_total = $ast_total,
        stat.rebounds_total = $trb_total,
        stat.steals_total = $stl_total,
        stat.blocks_total = $blk_total,
        stat.turnovers_total = $tov_total,
        stat.personal_fouls_total = $pf_total,
        stat.win_shares_total = $ws_total,
        stat.offensive_win_shares_total = $ows_total,
        stat.defensive_win_shares_total = $dws_total,
        stat.vorp_total = $vorp_total,
        stat.games_played = $games_played  // Retain games played without "total" as it isn't cumulative
    """
    # Cypher parameter name -> CSV column holding the cumulative value.
    total_columns = (
        ("pts_total", "PTS_total"),
        ("ast_total", "AST_total"),
        ("trb_total", "TRB_total"),
        ("stl_total", "STL_total"),
        ("blk_total", "BLK_total"),
        ("tov_total", "TOV_total"),
        ("pf_total", "PF_total"),
        ("ws_total", "WS_total"),
        ("ows_total", "OWS_total"),
        ("dws_total", "DWS_total"),
        ("vorp_total", "VORP_total"),
    )
    params = {"player": player_name, "season": stats_data["Season"]}
    params.update({param: stats_data[column] for param, column in total_columns})
    params["games_played"] = stats_data["GP"]
    tx.run(query, **params)




def create_injury_node(tx, player_name, injury_data):
    """MERGE an Injury node for ``player_name``, unless required injury fields are null.

    Nothing is written when any of Total_Days_Injured, Injury_Periods or
    Injury_Risk is null/NaN. The node is keyed by player only, so its properties
    are those of the first qualifying row (``ON CREATE SET``).
    """
    required_fields = ("Total_Days_Injured", "Injury_Periods", "Injury_Risk")
    if any(pd.isna(injury_data[field]) for field in required_fields):
        return

    cypher = """
    MERGE (i:Injury {player: $player})
    ON CREATE SET i.total_days = $total_days,
                  i.injury_periods = $injury_periods,
                  i.risk = $risk,
                  i.injury_history = $injury_history
    """
    tx.run(
        cypher,
        player=player_name,
        total_days=injury_data["Total_Days_Injured"],
        injury_periods=injury_data["Injury_Periods"],
        risk=injury_data["Injury_Risk"],
        injury_history=injury_data.get("Injury_History"),
    )


def create_relationships(tx, player_data):
    """Create relationships between Player, Team, Season, Contract, Statistics and Injury nodes.

    Endpoints are found with MATCH, so a relationship is skipped when one of its
    nodes does not exist (for example, when no Injury node was written for the
    row). Relationships are MERGEd, so running the function again does not
    duplicate them.

    The relationships created are:
      * (Player)-[:HAS_PLAYED_FOR {season}]->(Team) and (Player)-[:PARTICIPATED_IN]->(Season)
      * (Player)-[:HAS_CONTRACT {season}]->(Contract). The Contract is matched on
        player_name as well as salary and season, so a player is never linked to
        another player's contract that happens to have the same salary.
      * (Player)-[:POSSESSES {season}]->(Statistics)
      * (Player)-[:SUFFERED]->(Injury)

    Args:
        tx: Neo4j transaction.
        player_data: One CSV record; requires Player, Team, Season and Salary.
    """
    relationships = [
        {
            "query": """
                MATCH (p:Player {name: $player}), (t:Team {name: $team}), (s:Season {name: $season})
                MERGE (p)-[:HAS_PLAYED_FOR {season: $season}]->(t)
                MERGE (p)-[:PARTICIPATED_IN]->(s)
            """,
            "params": {"player": player_data["Player"], "team": player_data["Team"], "season": player_data["Season"]}
        },
        {
            "query": """
                MATCH (p:Player {name: $player}),
                      (c:Contract {player_name: $player, salary: $salary, season: $season})
                MERGE (p)-[:HAS_CONTRACT {season: $season}]->(c)
            """,
            "params": {"player": player_data["Player"], "salary": player_data["Salary"], "season": player_data["Season"]}
        },
        {
            "query": """
                MATCH (p:Player {name: $player}), (stat:Statistics {player: $player, season: $season})
                MERGE (p)-[:POSSESSES {season: $season}]->(stat)
            """,
            "params": {"player": player_data["Player"], "season": player_data["Season"]}
        },
        {
            # Injury nodes are keyed by player; downstream queries traverse SUFFERED.
            "query": """
                MATCH (p:Player {name: $player}), (i:Injury {player: $player})
                MERGE (p)-[:SUFFERED]->(i)
            """,
            "params": {"player": player_data["Player"]}
        },
    ]
    for rel in relationships:
        tx.run(rel["query"], **rel["params"])
    print(f"Relationships created for Player: {player_data['Player']} for season: {player_data['Season']}.")


def calculate_and_set_trade_value(tx, player_name):
    """Write ``trade_value`` onto the Player node named ``player_name``.

    This is a placeholder: the value written is always 0.
    """
    placeholder_value = 0  # TODO: replace with a real trade-value calculation
    tx.run("""
    MATCH (p:Player {name: $player})
    SET p.trade_value = $trade_value
    """, player=player_name, trade_value=placeholder_value)


# Example query to check indexes
def check_indexes(session):
    """Print every index in the database, one record per line.

    Uses ``SHOW INDEXES``. The ``db.indexes`` procedure was deprecated in
    Neo4j 4.2 and removed in Neo4j 5, where calling it fails.
    """
    for record in session.run("SHOW INDEXES"):
        print(record)


def clear_database(session):
    """DETACH DELETE every node, and with them every relationship.

    Runs as a single query. Schema objects (constraints and indexes) are left
    untouched.
    """
    session.run("MATCH (n) DETACH DELETE n")
    print("All nodes and relationships deleted from the database.")


# Function to clear all constraints and indexes
def clear_constraints_and_indexes(session):
    """Drop every constraint, then every remaining non-LOOKUP index.

    All names are collected before any DROP runs, so no DROP is issued while a
    SHOW result is still being iterated. Names are backtick-quoted, so names
    containing spaces or other special characters remain valid Cypher.

    Token LOOKUP indexes (Neo4j's built-in label / relationship-type lookup
    indexes) are kept. Nothing here recreates them, and without them label-based
    matches lose their index support.
    """
    def quote(name):
        # Cypher escapes a backtick inside a quoted identifier by doubling it.
        return "`" + name.replace("`", "``") + "`"

    print("Clearing all constraints...")
    constraint_names = [record['name'] for record in session.run("SHOW CONSTRAINTS")]
    for constraint_name in constraint_names:
        session.run(f"DROP CONSTRAINT {quote(constraint_name)}")
        print(f"Constraint '{constraint_name}' has been deleted.")

    # Indexes owned by constraints were removed together with those constraints.
    print("Clearing all indexes...")
    index_names = [record['name'] for record in session.run("SHOW INDEXES WHERE type <> 'LOOKUP'")]
    for index_name in index_names:
        session.run(f"DROP INDEX {quote(index_name)}")
        print(f"Index '{index_name}' has been deleted.")


# Function to create indexes if they don't already exist
def create_index_if_not_exists(session, index_query, index_name):
    """Run ``index_query`` unless an index named ``index_name`` already exists.

    Any error, from either the existence check or the creation, is printed and
    swallowed, so a failed index never aborts the caller.
    """
    try:
        already_present = session.run(f"SHOW INDEXES WHERE name = '{index_name}'").single()
        if already_present:
            print(f"Index '{index_name}' already exists.")
            return
        session.run(index_query)
        print(f"Successfully created index: {index_name}")
    except Exception as e:
        print(f"Failed to create index: {index_name}. Error: {e}")


# Function to set up indexes
def setup_indexes(session):
    """Create indexes on Player.name, Team.name and Contract.season.

    Each index is created through ``create_index_if_not_exists``.
    """
    index_specs = (
        ("player_name_index", "CREATE INDEX player_name_index IF NOT EXISTS FOR (p:Player) ON (p.name)"),
        ("team_name_index", "CREATE INDEX team_name_index IF NOT EXISTS FOR (t:Team) ON (t.name)"),
        ("contract_season_index", "CREATE INDEX contract_season_index IF NOT EXISTS FOR (c:Contract) ON (c.season)"),
    )
    for index_name, index_query in index_specs:
        create_index_if_not_exists(session, index_query, index_name)


# Function to insert data into Neo4j
def insert_enhanced_data(tx, player_data):
    """Write all nodes and relationships for one CSV row inside transaction ``tx``.

    The node-creating steps run first because ``create_relationships`` locates
    its endpoints with MATCH. Position nodes are not created here.
    """
    row = player_data
    create_player_node(tx, row)
    create_team_node(tx, row["Team"], row["TeamID"], row)
    create_season_node(tx, row["Season"])
    create_contract_node(tx, row)
    player_name = row["Player"]
    create_statistics_node(tx, player_name, row)
    create_injury_node(tx, player_name, row)
    create_relationships(tx, row)


# Main function to insert data
def main():
    """Rebuild the Neo4j graph from the cleaned CSV.

    This is destructive: it runs ``clear_database`` and
    ``clear_constraints_and_indexes`` before recreating the schema. Each CSV row
    is written in its own write transaction, followed by a separate transaction
    that sets the placeholder trade value. Errors are printed rather than raised,
    and the driver is always closed.
    """
    load_dotenv('../../.env')
    data_file = '../../data/neo4j/processed/nba_player_data_cleaned.csv'

    driver = None
    try:
        driver = connect_to_neo4j()
        dataframe = pd.read_csv(data_file)
        data_dicts = dataframe.to_dict(orient='records')
        print(f"Loaded {len(data_dicts)} records from {data_file}.")

        with driver.session() as session:
            clear_database(session)
            clear_constraints_and_indexes(session)
            setup_schema_with_cleanup(session)
            setup_indexes(session)
            print("Database schema and indexes set up successfully.")

            for player_data in data_dicts:
                session.execute_write(insert_enhanced_data, player_data)
                session.execute_write(calculate_and_set_trade_value, player_data["Player"])

            print("Data inserted into Neo4j successfully.")

    except Exception as e:
        print("An error occurred during Neo4j data ingestion:", e)
    finally:
        # The driver owns a connection pool; release it on success and failure alike.
        if driver is not None:
            driver.close()


if __name__ == '__main__':
    main()


Overwriting ../../src/neo4j_model/modules/neo4j_data_preprocess_ingest.py


In [24]:
%%writefile ../../src/neo4j_model/modules/neo4j_test_functions.py
from utility_functions import load_env_variables, connect_to_neo4j
from neo4j import GraphDatabase
import pandas as pd
import os
from dotenv import load_dotenv

def query_to_dataframe(driver, query, parameters=None):
    """Run a Cypher query in a fresh session and return its rows as a DataFrame.

    Column names are the result keys, i.e. the RETURN aliases. The query and its
    parameters are printed before execution, and the row count afterwards.
    """
    print(f"Running query: {query}")
    print(f"With parameters: {parameters}")

    with driver.session() as session:
        cursor = session.run(query, parameters)
        column_names = cursor.keys()
        rows = [row.values() for row in cursor]

        print(f"Query returned {len(rows)} records.")
        return pd.DataFrame(rows, columns=column_names)

def get_player_statistics(driver, season):
    """Return the top 10 players for ``season``, ordered by ``stat.ppg`` descending.

    Columns: player, points_per_game, assists, rebounds.
    NOTE(review): the ingest step stores PTS_total in ``stat.ppg``, so
    ``points_per_game`` likely holds season totals, not per-game averages —
    confirm before relying on the label.
    """
    print(f"Fetching player statistics for season: {season}")
    query = """
    MATCH (p:Player)-[:POSSESSES]->(stat:Statistics {season: $season})
    RETURN p.name AS player, stat.ppg AS points_per_game, stat.assists_total AS assists, stat.rebounds_total AS rebounds
    ORDER BY stat.ppg DESC
    LIMIT 10
    """
    return query_to_dataframe(driver, query, {"season": season})

def get_player_contracts(driver, season):
    """Return the 10 highest-salary contracts linked to players for ``season``.

    Columns: player, salary, contract_duration.
    """
    print(f"Fetching player contracts for season: {season}")
    query = """
    MATCH (p:Player)-[:HAS_CONTRACT]->(c:Contract {season: $season})
    RETURN p.name AS player, c.salary AS salary, c.duration AS contract_duration
    ORDER BY c.salary DESC
    LIMIT 10
    """
    return query_to_dataframe(driver, query, {"season": season})

def get_top_teams_by_salary(driver, season):
    """Return the 10 teams with the highest summed contract salary for ``season``.

    Only HAS_PLAYED_FOR relationships tagged with the same season are followed.
    A player's salary therefore counts toward the team(s) he played for in that
    season, not every team he has ever played for. A player who played for
    several teams in the season still counts toward each of them.

    Columns: team, total_salary.
    """
    print(f"Fetching top teams by salary for season: {season}")
    query = """
    MATCH (t:Team)<-[:HAS_PLAYED_FOR {season: $season}]-(p:Player)-[:HAS_CONTRACT]->(c:Contract {season: $season})
    RETURN t.name AS team, SUM(c.salary) AS total_salary
    ORDER BY total_salary DESC
    LIMIT 10
    """
    parameters = {"season": season}
    return query_to_dataframe(driver, query, parameters)

def get_players_with_high_injury_risk(driver):
    """Return up to 10 players whose Injury node has ``risk >= 0.8``, riskiest first.

    Columns: player, total_days_injured, injury_risk.
    NOTE(review): this requires SUFFERED relationships from Player to Injury and
    a numeric ``risk`` property — confirm the ingest step writes both.
    """
    print("Fetching players with high injury risk...")
    query = """
    MATCH (p:Player)-[:SUFFERED]->(i:Injury)
    WHERE i.risk >= 0.8
    RETURN p.name AS player, i.total_days AS total_days_injured, i.risk AS injury_risk
    ORDER BY i.risk DESC
    LIMIT 10
    """
    return query_to_dataframe(driver, query, parameters=None)

def get_team_strategies(driver):
    """Return every team with its ``strategy`` and ``needs`` properties.

    Columns: team, strategy, needs.
    """
    print("Fetching team strategies and needs...")
    query = """
    MATCH (t:Team)
    RETURN t.name AS team, t.strategy AS strategy, t.needs AS needs
    """
    return query_to_dataframe(driver, query, parameters=None)

def get_top_players_by_vorp(driver, season):
    """Return the top 10 players for ``season`` by ``stat.vorp_total`` (descending).

    Columns: player, vorp.
    """
    print(f"Fetching top players by VORP for season: {season}")
    query = """
    MATCH (p:Player)-[:POSSESSES]->(stat:Statistics {season: $season})
    RETURN p.name AS player, stat.vorp_total AS vorp
    ORDER BY stat.vorp_total DESC
    LIMIT 10
    """
    return query_to_dataframe(driver, query, {"season": season})

def main(season='2023-24'):
    """Connect to Neo4j, run every test query for ``season``, and print the results.

    Errors are printed rather than raised. The driver is always closed.
    """
    print("Starting Neo4j test function execution...")

    driver = None
    try:
        driver = connect_to_neo4j()

        print("Retrieving player statistics...")
        player_stats_df = get_player_statistics(driver, season)
        print("Player Statistics:\n", player_stats_df)

        print("\nRetrieving player contracts...")
        player_contracts_df = get_player_contracts(driver, season)
        print("Player Contracts:\n", player_contracts_df)

        print("\nRetrieving top teams by salary...")
        team_salary_df = get_top_teams_by_salary(driver, season)
        print("Top Teams by Salary:\n", team_salary_df)

        print("\nRetrieving players with high injury risk...")
        injury_risk_df = get_players_with_high_injury_risk(driver)
        print("Players with High Injury Risk:\n", injury_risk_df)

        print("\nRetrieving team strategies and needs...")
        team_strategy_df = get_team_strategies(driver)
        print("Team Strategies:\n", team_strategy_df)

        print("\nRetrieving top players by VORP...")
        top_vorp_df = get_top_players_by_vorp(driver, season)
        print("Top Players by VORP:\n", top_vorp_df)

    except Exception as e:
        print(f"An error occurred during Neo4j data testing: {e}")
    finally:
        # Release the driver's connection pool whether or not the queries succeeded.
        if driver is not None:
            driver.close()


if __name__ == '__main__':
    main()


Overwriting ../../src/neo4j_model/modules/neo4j_test_functions.py


In [15]:
%%writefile ../../src/neo4j_model/modules/model_loading.py

from langchain_ollama import ChatOllama
from langchain.schema import HumanMessage

def load_llm(model_name):
    """Return a ChatOllama client for ``model_name`` with temperature 0.

    NOTE(review): constructing the client may not contact the Ollama server, so
    the success message does not prove the model is actually available — confirm.
    """
    chat_model = ChatOllama(model=model_name, temperature=0)
    print(f"LLM model '{model_name}' loaded successfully.")
    return chat_model

def test_llm(llm, prompt):
    """Send ``prompt`` as a single human message, print the exchange, and return the reply.

    Uses the Runnable ``invoke`` API. Calling a chat model directly
    (``llm([...])``) goes through the deprecated ``__call__`` interface.

    Args:
        llm: A LangChain chat model (e.g. the ChatOllama from ``load_llm``).
        prompt: The text to send.

    Returns:
        str: the ``content`` of the model's reply.
    """
    response = llm.invoke([HumanMessage(content=prompt)])
    print(f"Test Prompt: {prompt}\nLLM Response: {response.content}")
    return response.content

def main(model_name='tomasonjo/llama3-text2cypher-demo', debug=False):
    """Load ``model_name`` via Ollama and send it one sample prompt.

    Returns the reply text. If any step raises, the error is printed and the
    function returns None.
    """
    try:
        if debug:
            print("Debug mode enabled.")

        reply = test_llm(load_llm(model_name), "Why is the sky blue?")

        if debug:
            print("LLM Response Retrieved.")
        return reply

    except Exception as e:
        print(f"An error occurred during LLM model loading: {e}")


# Example usage without argparse
if __name__ == '__main__':
    # debug=True adds progress messages around loading and the sample prompt
    main(debug=True)


Overwriting ../../src/neo4j_model/modules/model_loading.py


In [16]:
%%writefile ../../src/neo4j_model/modules/graphqa_functions.py
from utility_functions import load_env_variables, connect_to_neo4j

from langchain_ollama import ChatOllama
from langchain.schema import HumanMessage
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.prompts import PromptTemplate
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase
from langchain_experimental.utilities import PythonREPL
from langchain_core.tools import Tool
import os
from dotenv import load_dotenv
import re

# Step 2: Setup the GraphCypherQAChain
def setup_graphqa_chain(llm, graph, cypher_prompt):
    """Build a GraphCypherQAChain that uses ``llm`` for both Cypher generation and QA.

    With ``return_direct=True`` the chain returns the raw graph query results
    instead of an LLM-written answer. Intermediate steps (the generated Cypher)
    are returned as well. ``allow_dangerous_requests=True`` opts in to executing
    model-generated Cypher, so the Neo4j credentials behind ``graph`` should be
    narrowly scoped.
    """
    chain_options = dict(
        cypher_llm=llm,
        qa_llm=llm,
        validate_cypher=True,
        graph=graph,
        verbose=True,
        return_intermediate_steps=True,
        return_direct=True,
        cypher_prompt=cypher_prompt,
        allow_dangerous_requests=True,
    )
    return GraphCypherQAChain.from_llm(**chain_options)

# Step 3: Validate the generated Cypher query against the schema
def validate_generated_query(query, driver, schema):
    """Report schema-documented properties that do not exist in the live database.

    Every comma-separated name following ``Properties:`` on any line of *schema*
    (all node types, lower-cased, first occurrence order, de-duplicated) is
    compared with the property names returned by
    ``CALL db.schema.nodeTypeProperties()``.

    Args:
        query: The generated Cypher query. Not inspected here; accepted so the
            call signature stays unchanged.
        driver: Neo4j driver used to read the database's property names.
        schema: Human-readable schema text with ``Properties: a, b, c.`` lines.

    Returns:
        List of lower-cased property names listed in *schema* but absent from
        the database. Empty if *schema* lists no properties.
    """
    schema_properties = []
    for line in schema.splitlines():
        _, found, listed = line.partition("Properties:")
        if not found:
            continue
        # Drop the sentence-ending period before splitting the list.
        for raw_name in listed.strip().rstrip(".").split(","):
            name = raw_name.strip().lower()
            if name and name not in schema_properties:
                schema_properties.append(name)

    with driver.session() as session:
        result = session.run("CALL db.schema.nodeTypeProperties()")
        # propertyName can be null (e.g. for labels whose nodes have no
        # properties); calling .lower() on it would raise.
        db_properties = {
            record["propertyName"].lower()
            for record in result
            if record["propertyName"] is not None
        }

    missing_properties = [prop for prop in schema_properties if prop not in db_properties]
    print(f"Missing Properties: {missing_properties}")

    return missing_properties

# Step 4: Handle missing properties in the generated query
def handle_missing_properties(missing_properties, query):
    """Remove ``stat.<prop>`` references for properties reported as missing.

    Args:
        missing_properties: Property names, e.g. as returned by
            ``validate_generated_query``.
        query: Cypher query text.

    Returns:
        *query* with every whole ``stat.<prop>`` reference (plus trailing
        whitespace) removed for each missing property; unchanged when
        *missing_properties* is empty.
    """
    if not missing_properties:
        return query

    for prop in missing_properties:
        # re.escape: property names are data, not regex syntax.
        # (?!\w): remove only the whole property, so "per" does not eat the
        # prefix of "stat.personal_fouls_total".
        # NOTE(review): removal can still leave dangling commas/aliases in the query.
        query = re.sub(rf"\bstat\.{re.escape(prop)}(?!\w)\s*", '', query)

    print(f"Adjusted Query After Removing Missing Properties:\n{query}")
    return query

# Step 5: Add per-game calculations to the query if necessary
def add_per_game_calculations(query):
    """Rewrite cumulative ``stat.<total>`` references in *query* as per-game averages.

    Each whole-word reference ``stat.<name>``, where ``<name>`` is one of the
    properties listed below, becomes ``(toFloat(stat.<name>) / stat.games_played)``.
    The variable (``stat``) is matched case-insensitively and kept as written.
    References already divided by ``<var>.games_played`` are left alone, so the
    rewrite is idempotent.

    Args:
        query: Cypher query text, typically produced by the LLM.

    Returns:
        The rewritten query; unchanged when no listed property is referenced.
    """
    print(f"Original Query Before Per-Game Calculations:\n{query}")

    # NOTE(review): "ppg" is kept because the original mapping divided it too,
    # although the name suggests it may already be per game -- confirm with the data.
    total_stats = (
        "ppg",
        "assists_total",
        "rebounds_total",
        "steals_total",
        "blocks_total",
        "turnovers_total",
        "personal_fouls_total",
        "win_shares_total",
        "offensive_win_shares_total",
        "defensive_win_shares_total",
        "vorp_total",
    )
    pattern = re.compile(
        r"\b(stat)\.(" + "|".join(total_stats) + r")\b"
        # Skip references that are already divided by games_played (optionally
        # inside toFloat(...)), so the division is never applied twice.
        r"(?!\)?\s*/\s*\w+\.games_played\b)",
        flags=re.IGNORECASE,
    )

    def _per_game(match):
        # toFloat: in Cypher, integer / integer is integer division and would
        # truncate the average. Parentheses keep the division intact inside
        # larger expressions such as ``x / stat.ppg``.
        var, prop = match.group(1), match.group(2)
        return f"(toFloat({var}.{prop}) / {var}.games_played)"

    query = pattern.sub(_per_game, query)

    print(f"Adjusted Query After Per-Game Calculations:\n{query}")
    return query

# Step 6: Run the Cypher query
def run_cypher_query(query, driver):
    """Run *query* on Neo4j through *driver* and collect every returned row.

    Returns:
        A list containing ``record.data()`` (a plain dict) for each record,
        fully consumed while the session is still open.
    """
    print(f"Executing Query:\n{query}")
    rows = []
    with driver.session() as session:
        for record in session.run(query):
            rows.append(record.data())
    return rows


# Step 7: Agent REPL loop for error handling and query debugging
def agent_repl_loop(sample_question, schema, driver, llm, repl_tool, cypher_prompt):
    """Generate a Cypher query for *sample_question*, adjust it, run it, and report each step.

    Args:
        sample_question: Natural-language question to answer.
        schema: Schema text substituted into *cypher_prompt*.
        driver: Neo4j driver used to execute the query.
        llm: Chat model, invoked with a single HumanMessage holding the prompt.
        repl_tool: Tool whose ``func`` receives Python source; used only to
            echo the error message. May be ``None`` to skip that step.
        cypher_prompt: PromptTemplate formatted with ``schema`` and ``question``.

    Returns:
        dict with ``prompt_text``, ``generated_query``, ``adjusted_query``,
        ``query_result`` and ``error`` (the exception message, or ``None`` on
        success). Steps not reached keep their empty defaults.
    """
    prompt_text = cypher_prompt.format(schema=schema, question=sample_question)
    llm_response_content = ""
    adjusted_query = ""
    query_result = []
    error_message = None

    try:
        # Step 1: Generate the initial Cypher query. invoke() replaces calling
        # the model object directly, which LangChain deprecated.
        llm_response = llm.invoke([HumanMessage(content=prompt_text)])
        llm_response_content = llm_response.content.strip()
        print(f"Generated Query:\n{llm_response_content}")

        # Step 2: Add per-game calculations if needed
        adjusted_query = add_per_game_calculations(llm_response_content)

        # Step 3: Execute the adjusted Cypher query on Neo4j
        query_result = run_cypher_query(adjusted_query, driver)

        print("Query Results:")
        for record in query_result:
            print(record)

    except Exception as e:
        error_message = str(e)
        print(f"Error occurred: {e}")
        if repl_tool is not None:
            # The message can contain quotes or text derived from LLM output;
            # repr() turns it into an inert string literal instead of splicing
            # it into executable Python source.
            repl_output = repl_tool.func(f"print({error_message!r})")
            print(f"Python REPL Output for Debugging:\n{repl_output}")

    # Return all relevant data for display in the Streamlit app
    return {
        "prompt_text": prompt_text,
        "generated_query": llm_response_content,
        "adjusted_query": adjusted_query,
        "query_result": query_result,
        "error": error_message,
    }


# Step 8: Main function to initialize and run the agent with a test question
def main(question='Who are the top 5 players in the 2023-24 season based on assist total?'):
    """Build the text2cypher pipeline and answer *question* once.

    Connects to Neo4j, creates the ChatOllama model and a Python REPL tool,
    then runs ``agent_repl_loop`` with the schema below injected into the
    prompt. Errors are printed, and the driver is always closed.

    Returns:
        The dict returned by ``agent_repl_loop``, or ``None`` if setup failed.
    """
    driver = None
    try:
        driver = connect_to_neo4j()
        llm = ChatOllama(model='tomasonjo/llama3-text2cypher-demo', temperature=0)
        python_repl = PythonREPL()
        repl_tool = Tool(
            name="python_repl",
            description="A Python shell for executing commands.",
            func=python_repl.run,
        )

        # Schema given to the LLM; relationships are included so it can write joins.
        schema = """
        Nodes:
        - Player: Represents an NBA player. Properties: name, age, position, years_of_service, injury_risk, season_salary, season, per, ws, bpm, vorp.
        - Team: Represents an NBA team. Properties: name, team_id, needs, strategy, cap_space.
        - Season: Represents a specific NBA season. Properties: name.
        - Contract: Represents player contracts. Properties: player_name, salary, cap, luxury_tax, duration, player_option, team_option, no_trade_clause.
        - Statistics: Represents player statistics. Properties: player, season, ppg, assists_total, rebounds_total, steals_total, blocks_total, turnovers_total, personal_fouls_total, win_shares_total, offensive_win_shares_total, defensive_win_shares_total, vorp_total, games_played.
        - Injury: Represents player injury details. Properties: player, total_days, injury_periods, risk, injury_history.

        Relationships:
        - Player -[:HAS_PLAYED_FOR]-> Team
        - Player -[:PARTICIPATED_IN]-> Season
        - Player -[:HAS_CONTRACT]-> Contract
        - Player -[:POSSESSES]-> Statistics
        - Player -[:SUFFERED]-> Injury
        - Team -[:HAS_PLAYER]-> Player
        - Team -[:CURRENT_TEAM]-> Player
        """

        # {schema} must appear in the template; otherwise the declared input
        # variable is silently dropped and the LLM never sees the real schema.
        cypher_prompt = PromptTemplate(
            template="""
            You are a Cypher query expert for a Neo4j database with the following schema:

            Schema:
            {schema}

            Use the schema above to generate a Cypher query that answers the given question.
            Make the query flexible by using case-insensitive matching and partial string matching where appropriate.

            Question: {question}

            Cypher Query:
            """,
            input_variables=["schema", "question"]
        )

        return agent_repl_loop(question, schema, driver, llm, repl_tool, cypher_prompt)

    except Exception as e:
        print(f"An error occurred during GraphQA operations: {e}")
        return None
    finally:
        if driver is not None:
            driver.close()

if __name__ == '__main__':
    main()


Overwriting ../../src/neo4j_model/modules/graphqa_functions.py


In [21]:
%%writefile ../../src/neo4j_model/modules/graphqa_module.py

from langchain_ollama import ChatOllama
from langchain.schema import HumanMessage
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.prompts import PromptTemplate
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase
from langchain_experimental.utilities import PythonREPL
from langchain_core.tools import Tool
import os
from dotenv import load_dotenv
import re

# --- SECTION 1: ENVIRONMENT SETUP AND CONNECTIONS ---

def load_environment():
    """Load variables from ``../../.env`` resolved against the current working directory.

    The location depends on where the process was started, not on this file.
    """
    load_dotenv(os.path.join(os.getcwd(), '../../.env'))

def create_neo4j_driver():
    """Return a Neo4j driver for NEO4J_URI, authenticated with NEO4J_USERNAME / NEO4J_PASSWORD.

    Values are read from the process environment (call ``load_environment``
    first). Connectivity is not verified here.
    """
    uri, user, secret = (
        os.getenv(key) for key in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    )
    return GraphDatabase.driver(uri, auth=(user, secret))

def initialize_graph_connection():
    """Return a LangChain ``Neo4jGraph`` built from the NEO4J_* environment variables."""
    uri, user, secret = (
        os.getenv(key) for key in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    )
    return Neo4jGraph(url=uri, username=user, password=secret)


# --- SECTION 2: PROMPT AND LLM INITIALIZATION ---

def initialize_llm():
    """Return a ChatOllama client for the text2cypher model with temperature 0."""
    return ChatOllama(temperature=0, model='tomasonjo/llama3-text2cypher-demo')

def create_cypher_prompt_template():
    """Return the PromptTemplate used to turn a question into a Cypher query.

    The schema text below is interpolated into the template when this function
    runs (the template is an f-string), so the only LangChain input variable is
    ``question``. Its braces are doubled (``{{question}}``) so the f-string leaves
    a literal ``{question}`` placeholder for LangChain.
    """
    
    # NOTE(review): this duplicates the module-level `schema` string; keep them in sync.
    schema = """
    Nodes:
    - Player: Represents an NBA player. Properties: name, age, position, years_of_service, injury_risk, season_salary, season, per, ws, bpm, vorp.
    - Team: Represents an NBA team. Properties: name, team_id, needs, strategy, cap_space.
    - Season: Represents a specific NBA season. Properties: name.
    - Contract: Represents player contracts. Properties: player_name, salary, cap, luxury_tax, duration, player_option, team_option, no_trade_clause.
    - Statistics: Represents player statistics. Properties: player, season, ppg, assists_total, rebounds_total, steals_total, blocks_total, turnovers_total, personal_fouls_total, win_shares_total, offensive_win_shares_total, defensive_win_shares_total, vorp_total, games_played.
    - Injury: Represents player injury details. Properties: player, total_days, injury_periods, risk, injury_history.

    Relationships:
    - Player -[:HAS_PLAYED_FOR]-> Team
    - Player -[:PARTICIPATED_IN]-> Season
    - Player -[:HAS_CONTRACT]-> Contract
    - Player -[:POSSESSES]-> Statistics
    - Player -[:SUFFERED]-> Injury
    - Team -[:HAS_PLAYER]-> Player
    - Team -[:CURRENT_TEAM]-> Player
    """

    # The schema is baked in here, so callers passing schema= to .format() have
    # no effect on the text; only `question` is a template variable.
    return PromptTemplate(
        template=f"""
        You are a Cypher query expert for a Neo4j database with the following schema:
        
        Schema:
        {schema}
        
        Use the schema above to generate a Cypher query that answers the given question.
        Make the query flexible by using case-insensitive matching and partial string matching where appropriate.
        Focus on searching player statistics, contracts, and team details.

        Now, generate a Cypher query for the following question:

        Question: {{question}}
        
        Cypher Query:
        """,
        input_variables=["question"],
    )


# --- SECTION 3: QUERY GENERATION AND MODIFICATION ---

def generate_query(llm, prompt_template, schema, question):
    """Format the prompt for *question* and return the LLM's reply text, stripped.

    Args:
        llm: LangChain chat model, called via ``invoke``.
        prompt_template: PromptTemplate formatted with ``schema`` and ``question``.
        schema: Schema text passed to ``prompt_template.format``.
        question: Natural-language question.

    Returns:
        The reply's ``content`` with surrounding whitespace removed.
    """
    prompt_text = prompt_template.format(schema=schema, question=question)
    print(f"Generated Prompt:\n{prompt_text}")
    # invoke() replaces calling the model object directly, which LangChain deprecated.
    llm_response = llm.invoke([HumanMessage(content=prompt_text)])
    return llm_response.content.strip()

def add_per_game_calculations(query):
    """Rewrite cumulative ``stat.<total>`` references in *query* as per-game averages.

    Each whole-word reference ``stat.<name>``, where ``<name>`` is one of the
    properties listed below, becomes ``(toFloat(stat.<name>) / stat.games_played)``.
    The variable (``stat``) is matched case-insensitively and kept as written.
    References already divided by ``<var>.games_played`` are left alone, so the
    rewrite is idempotent.

    Args:
        query: Cypher query text, typically produced by the LLM.

    Returns:
        The rewritten query; unchanged when no listed property is referenced.
    """
    print(f"Original Query Before Per-Game Calculations:\n{query}")

    # NOTE(review): "ppg" is kept because the original mapping divided it too,
    # although the name suggests it may already be per game -- confirm with the data.
    total_stats = (
        "ppg",
        "assists_total",
        "rebounds_total",
        # Other stat names can go here
    )
    pattern = re.compile(
        r"\b(stat)\.(" + "|".join(total_stats) + r")\b"
        # Skip references that are already divided by games_played (optionally
        # inside toFloat(...)), so the division is never applied twice.
        r"(?!\)?\s*/\s*\w+\.games_played\b)",
        flags=re.IGNORECASE,
    )

    def _per_game(match):
        # toFloat: in Cypher, integer / integer is integer division and would
        # truncate the average. Parentheses keep the division intact inside
        # larger expressions such as ``x / stat.ppg``.
        var, prop = match.group(1), match.group(2)
        return f"(toFloat({var}.{prop}) / {var}.games_played)"

    query = pattern.sub(_per_game, query)

    print(f"Adjusted Query After Per-Game Calculations:\n{query}")
    return query

def handle_missing_properties(missing_properties, query):
    """Remove ``stat.<prop>`` references for properties reported as missing.

    Args:
        missing_properties: Property names, e.g. as returned by
            ``validate_generated_query``.
        query: Cypher query text.

    Returns:
        *query* with every whole ``stat.<prop>`` reference (plus trailing
        whitespace) removed for each missing property; unchanged when
        *missing_properties* is empty.
    """
    if not missing_properties:
        return query

    for prop in missing_properties:
        # re.escape: property names are data, not regex syntax.
        # (?!\w): remove only the whole property, so "per" does not eat the
        # prefix of "stat.personal_fouls_total".
        # NOTE(review): removal can still leave dangling commas/aliases in the query.
        query = re.sub(rf"\bstat\.{re.escape(prop)}(?!\w)\s*", '', query)
    print(f"Adjusted Query After Removing Missing Properties:\n{query}")
    return query


# --- SECTION 4: VALIDATION AND EXECUTION ---

def validate_generated_query(query, driver, schema):
    """Report schema-documented properties that do not exist in the live database.

    Every comma-separated name following ``Properties:`` on any line of *schema*
    (all node types, lower-cased, first occurrence order, de-duplicated) is
    compared with the property names returned by
    ``CALL db.schema.nodeTypeProperties()``.

    Args:
        query: The generated Cypher query. Not inspected here; accepted so the
            call signature stays unchanged.
        driver: Neo4j driver used to read the database's property names.
        schema: Human-readable schema text with ``Properties: a, b, c.`` lines.

    Returns:
        List of lower-cased property names listed in *schema* but absent from
        the database. Empty if *schema* lists no properties.
    """
    schema_properties = []
    for line in schema.splitlines():
        _, found, listed = line.partition("Properties:")
        if not found:
            continue
        # Drop the sentence-ending period before splitting the list.
        for raw_name in listed.strip().rstrip(".").split(","):
            name = raw_name.strip().lower()
            if name and name not in schema_properties:
                schema_properties.append(name)

    with driver.session() as session:
        result = session.run("CALL db.schema.nodeTypeProperties()")
        # propertyName can be null (e.g. for labels whose nodes have no
        # properties); calling .lower() on it would raise.
        db_properties = {
            record["propertyName"].lower()
            for record in result
            if record["propertyName"] is not None
        }

    missing_properties = [prop for prop in schema_properties if prop not in db_properties]
    print(f"Missing Properties: {missing_properties}")

    return missing_properties

def run_cypher_query(query, driver):
    """Run *query* on Neo4j through *driver* and collect every returned row.

    Returns:
        A list containing ``record.data()`` (a plain dict) for each record,
        fully consumed while the session is still open.
    """
    print(f"Executing Query:\n{query}")
    rows = []
    with driver.session() as session:
        for record in session.run(query):
            rows.append(record.data())
    return rows


# --- SECTION 5: MAIN LOOP AND EXECUTION ---

def agent_repl_loop(question, driver, llm, schema, prompt_template, repl_tool):
    """Generate, repair, and execute one Cypher query for *question*.

    Pipeline: format prompt -> LLM query -> drop properties missing from the
    database -> per-game rewrite -> execute. Despite the name, this runs once
    and does not loop.

    Args:
        question: Natural-language question.
        driver: Neo4j driver used for validation and execution.
        llm: Chat model passed to ``generate_query``.
        schema: Schema text used for prompt formatting and property validation.
        prompt_template: PromptTemplate passed to ``generate_query``.
        repl_tool: Optional Tool whose ``func`` receives Python source; used
            only to echo the error message. Falsy values skip it.

    Returns:
        dict with ``prompt_text``, ``generated_query``, ``adjusted_query``,
        ``query_result`` and ``error``. A key stays ``None`` if its step was
        not reached; ``error`` holds the exception message on failure.
    """
    results = {
        "prompt_text": None,
        "generated_query": None,
        "adjusted_query": None,
        "query_result": None,
        "error": None,
    }

    try:
        # Generate initial query
        results["prompt_text"] = prompt_template.format(schema=schema, question=question)
        raw_query = generate_query(llm, prompt_template, schema, question)
        print(f"Generated Raw Query:\n{raw_query}")

        results["generated_query"] = raw_query

        # Validate and adjust query
        missing_properties = validate_generated_query(raw_query, driver, schema)
        adjusted_query = handle_missing_properties(missing_properties, raw_query)
        final_query = add_per_game_calculations(adjusted_query)

        results["adjusted_query"] = final_query

        # Execute query
        query_results = run_cypher_query(final_query, driver)
        print("Query Results:")
        for record in query_results:
            print(record)

        results["query_result"] = query_results

    except Exception as e:
        results["error"] = str(e)
        print(f"Error occurred: {e}")
        if repl_tool:
            # The message can contain quotes or text derived from LLM output;
            # repr() turns it into an inert string literal instead of splicing
            # it into executable Python source.
            repl_output = repl_tool.func(f"print({str(e)!r})")
            print(f"Python REPL Output for Debugging:\n{repl_output}")

    return results  # Return the results dictionary


# Define the schema for the graph, fixing VORP issue and capitalized properties.
# main() passes this text to agent_repl_loop as `schema`; validate_generated_query
# reads the names listed after "Properties:" from it.
# NOTE(review): create_cypher_prompt_template embeds its own copy of this text; keep them in sync.
schema = """
Nodes:
- Player: Represents an NBA player. Properties: name, age, position, years_of_service, injury_risk, season_salary, season, per, ws, bpm, vorp.
- Team: Represents an NBA team. Properties: name, team_id, needs, strategy, cap_space.
- Season: Represents a specific NBA season. Properties: name.
- Contract: Represents player contracts. Properties: player_name, salary, cap, luxury_tax, duration, player_option, team_option, no_trade_clause.
- Statistics: Represents player statistics. Properties: player, season, ppg, assists_total, rebounds_total, steals_total, blocks_total, turnovers_total, personal_fouls_total, win_shares_total, offensive_win_shares_total, defensive_win_shares_total, vorp_total, games_played.
- Injury: Represents player injury details. Properties: player, total_days, injury_periods, risk, injury_history.

Relationships:
- Player -[:HAS_PLAYED_FOR]-> Team
- Player -[:PARTICIPATED_IN]-> Season
- Player -[:HAS_CONTRACT]-> Contract
- Player -[:POSSESSES]-> Statistics
- Player -[:SUFFERED]-> Injury
- Team -[:HAS_PLAYER]-> Player
- Team -[:CURRENT_TEAM]-> Player
"""

def main():
    """Set up connections and the prompt, then run the agent once on a sample question.

    The Neo4j driver is closed when the run finishes, even if it raises.

    Returns:
        The results dict produced by ``agent_repl_loop``.
    """
    load_environment()
    driver = create_neo4j_driver()
    try:
        llm = initialize_llm()
        # The LangChain Neo4jGraph from initialize_graph_connection() is not
        # used by this pipeline, so no second connection is opened for it.
        repl_tool = Tool(name="python_repl", description="Python shell for REPL.", func=PythonREPL().run)

        # Sample question
        sample_question = "Who are the top 5 players in the 2023-24 season based on assist total?"

        # Prompt setup
        prompt_template = create_cypher_prompt_template()
        return agent_repl_loop(sample_question, driver, llm, schema, prompt_template, repl_tool)
    finally:
        # The driver owns a connection pool; release it on every exit path.
        driver.close()


if __name__ == "__main__":
    main()


Overwriting ../../src/neo4j_model/modules/graphqa_module.py


In [22]:
%%writefile ../../src/neo4j_model/modules/main.py

from graphqa_module import (
    load_environment,
    create_neo4j_driver,
    initialize_llm,
    create_cypher_prompt_template,
    validate_generated_query,
    add_per_game_calculations,
    handle_missing_properties,
    agent_repl_loop,
)
from utility_functions import connect_to_neo4j
from data_cleaning import (
    load_data, 
    data_initial_summary, 
    clean_data, 
    check_for_duplicates, 
    map_team_ids, 
    add_suffixes_to_columns, 
    final_data_check
)
from neo4j_data_preprocess_ingest import (
    clear_database,
    clear_constraints_and_indexes,
    setup_schema_with_cleanup,
    setup_indexes,
    insert_enhanced_data,
    calculate_and_set_trade_value
)
from neo4j_test_functions import get_player_statistics, get_player_contracts
from model_loading import load_llm, test_llm

import os
from dotenv import load_dotenv
import pandas as pd


def check_neo4j_connection():
    """Open and immediately close a Neo4j connection, printing the outcome.

    Failures are reported on stdout rather than raised.
    """
    print("Testing Neo4j connection...")
    try:
        connect_to_neo4j().close()
    except Exception as exc:
        print(f"Neo4j connection test failed: {exc}")
    else:
        print("Neo4j connection successful.")


def run_data_cleaning():
    """Run the raw-CSV cleaning pipeline and write the cleaned CSV.

    Paths come from DATA_FILE / CLEANED_DATA_FILE (after loading the project
    .env), falling back to the workspace defaults. Failures are printed.
    """
    load_dotenv('/workspaces/custom_ollama_docker/.env')
    source_path = os.getenv('DATA_FILE', '/workspaces/custom_ollama_docker/data/neo4j/raw/nba_player_data_final_inflated.csv')
    target_path = os.getenv('CLEANED_DATA_FILE', '/workspaces/custom_ollama_docker/data/neo4j/processed/nba_player_data_cleaned.csv')

    try:
        raw_df = load_data(source_path)
        print(f"Loaded data with shape: {raw_df.shape}")
        data_initial_summary(raw_df, debug=True)

        # Cleaning and validation stages, applied in order.
        cleaned_df = clean_data(raw_df, debug=True)
        check_for_duplicates(cleaned_df, debug=True)
        final_df = add_suffixes_to_columns(map_team_ids(cleaned_df, debug=True), debug=True)
        final_data_check(final_df, debug=True)

        final_df.to_csv(target_path, index=False)
        print(f"Preprocessed data saved to {target_path}")
    except Exception as err:
        print(f"Data cleaning process failed: {err}")


def run_data_ingestion():
    """Rebuild the Neo4j graph from the cleaned CSV.

    Clears the database, constraints and indexes, recreates schema and
    indexes, then writes each CSV row with ``insert_enhanced_data`` followed by
    ``calculate_and_set_trade_value``. Failures are printed, and the driver is
    always closed.
    """
    load_dotenv('/workspaces/custom_ollama_docker/.env')
    # Read from the same location run_data_cleaning writes to (CLEANED_DATA_FILE override).
    data_file = os.getenv('CLEANED_DATA_FILE', '/workspaces/custom_ollama_docker/data/neo4j/processed/nba_player_data_cleaned.csv')

    driver = None
    try:
        driver = connect_to_neo4j()
        dataframe = pd.read_csv(data_file)
        data_dicts = dataframe.to_dict(orient='records')
        print(f"Loaded {len(data_dicts)} records from {data_file}.")

        with driver.session() as session:
            clear_database(session)
            clear_constraints_and_indexes(session)
            setup_schema_with_cleanup(session)
            setup_indexes(session)
            print("Database schema and indexes set up successfully.")

            for player_data in data_dicts:
                session.execute_write(insert_enhanced_data, player_data)
                session.execute_write(calculate_and_set_trade_value, player_data["Player"])

            print("Data inserted into Neo4j successfully.")
    except Exception as e:
        print(f"Data ingestion process failed: {e}")
    finally:
        if driver is not None:
            driver.close()


def run_neo4j_test_queries(season='2023-24'):
    """Print the first rows of the player-statistics and player-contracts queries for *season*.

    Failures are printed rather than raised, and the driver is always closed.
    """
    print(f"Running Neo4j test queries for season {season}...")
    driver = None
    try:
        driver = connect_to_neo4j()

        print("Retrieving player statistics...")
        player_stats_df = get_player_statistics(driver, season)
        print("Player Statistics:\n", player_stats_df.head())

        print("\nRetrieving player contracts...")
        player_contracts_df = get_player_contracts(driver, season)
        print("Player Contracts:\n", player_contracts_df.head())
    except Exception as e:
        print(f"Neo4j test queries failed: {e}")
    finally:
        if driver is not None:
            driver.close()


def run_model_loading(model_name='tomasonjo/llama3-text2cypher-demo'):
    """Load *model_name* and print its answer to a fixed smoke-test prompt.

    Failures are printed rather than raised.
    """
    try:
        response = test_llm(load_llm(model_name), "Why is the sky blue?")
    except Exception as err:
        print(f"Model loading process failed: {err}")
    else:
        print(f"Model test response: {response}")


def run_graphqa(question='Who are the top 5 players in the 2023-24 season based on assist total?'):
    """Run the GraphQA agent for *question*.

    All setup happens inside the error handler, so failures are printed
    consistently, and the Neo4j driver is always closed.

    Returns:
        The results dict from ``agent_repl_loop``, or ``None`` on failure.
    """
    driver = None
    try:
        # Imported here because main.py's module-level imports do not
        # provide these names.
        from graphqa_module import schema
        from langchain_core.tools import Tool
        from langchain_experimental.utilities import PythonREPL

        load_environment()
        driver = create_neo4j_driver()
        llm = initialize_llm()
        prompt_template = create_cypher_prompt_template()

        return agent_repl_loop(
            question=question,
            driver=driver,
            llm=llm,
            schema=schema,
            prompt_template=prompt_template,
            repl_tool=Tool(name="python_repl", description="Python shell for REPL.", func=PythonREPL().run)
        )
    except Exception as e:
        print(f"GraphQA agent execution failed: {e}")
        return None
    finally:
        if driver is not None:
            driver.close()


def main():
    """Run the pipeline stages in order.

    Order: connection check, Neo4j test queries, model smoke test, then GraphQA.
    Data cleaning and ingestion are disabled by default; uncomment them below
    to rebuild the graph first.
    """
    check_neo4j_connection()

    # Rebuild steps (disabled by default):
    # run_data_cleaning()
    # run_data_ingestion()

    run_neo4j_test_queries(season='2023-24')
    run_model_loading(model_name='tomasonjo/llama3-text2cypher-demo')
    run_graphqa(question='Who are the top 5 players in the 2023-24 season based on assist total?')

if __name__ == "__main__":
    main()


Overwriting ../../src/neo4j_model/modules/main.py


In [27]:
%%writefile ../../src/neo4j_model/streamlit_app.py

import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'modules')))
import streamlit as st
import pandas as pd
from dotenv import load_dotenv
from modules.graphqa_module import (
    load_environment,
    create_neo4j_driver,
    initialize_llm,
    create_cypher_prompt_template,
    agent_repl_loop,
)
from modules.utility_functions import connect_to_neo4j
from modules.data_cleaning import (
    load_data,
    data_initial_summary,
    clean_data,
    check_for_duplicates,
    map_team_ids,
    add_suffixes_to_columns,
    final_data_check
)
from modules.neo4j_data_preprocess_ingest import (
    clear_database,
    clear_constraints_and_indexes,
    setup_schema_with_cleanup,
    setup_indexes,
    insert_enhanced_data,
    calculate_and_set_trade_value
)
from modules.neo4j_test_functions import (
    get_player_statistics,
    get_player_contracts,
    get_top_teams_by_salary,
    get_players_with_high_injury_risk,
    get_team_strategies,
    get_top_players_by_vorp
)


# Graph schema text passed as `schema` to agent_repl_loop in Step 4.
# NOTE(review): identical to the `schema` string in modules/graphqa_module.py;
# keep the two in sync (or import it from there).
schema = """
Nodes:
- Player: Represents an NBA player. Properties: name, age, position, years_of_service, injury_risk, season_salary, season, per, ws, bpm, vorp.
- Team: Represents an NBA team. Properties: name, team_id, needs, strategy, cap_space.
- Season: Represents a specific NBA season. Properties: name.
- Contract: Represents player contracts. Properties: player_name, salary, cap, luxury_tax, duration, player_option, team_option, no_trade_clause.
- Statistics: Represents player statistics. Properties: player, season, ppg, assists_total, rebounds_total, steals_total, blocks_total, turnovers_total, personal_fouls_total, win_shares_total, offensive_win_shares_total, defensive_win_shares_total, vorp_total, games_played.
- Injury: Represents player injury details. Properties: player, total_days, injury_periods, risk, injury_history.

Relationships:
- Player -[:HAS_PLAYED_FOR]-> Team
- Player -[:PARTICIPATED_IN]-> Season
- Player -[:HAS_CONTRACT]-> Contract
- Player -[:POSSESSES]-> Statistics
- Player -[:SUFFERED]-> Injury
- Team -[:HAS_PLAYER]-> Player
- Team -[:CURRENT_TEAM]-> Player
"""

# App setup
st.title("NBA Player Data Analysis with Neo4j and Llama3")
st.sidebar.title("Workflow Steps")

# --- Step 1: Load and Clean Data ---
if st.sidebar.checkbox("Step 1: Load and Clean Data"):
    st.subheader("Data Loading and Cleaning")
    load_dotenv('/workspaces/custom_ollama_docker/.env')
    # Use the variables load_dotenv provides: the same DATA_FILE / CLEANED_DATA_FILE
    # overrides as main.py's run_data_cleaning, with the workspace paths as defaults.
    data_file = st.text_input(
        "Enter path to raw data file:",
        os.getenv('DATA_FILE', "/workspaces/custom_ollama_docker/data/neo4j/raw/nba_player_data_final_inflated.csv"),
    )
    output_file = os.getenv('CLEANED_DATA_FILE', "/workspaces/custom_ollama_docker/data/neo4j/processed/nba_player_data_cleaned.csv")

    if st.button("Load and Clean Data"):
        try:
            dataframe = load_data(data_file)
            st.write("Initial Data Loaded:", dataframe.head())

            # Data cleaning steps
            data_initial_summary(dataframe, debug=True)
            dataframe_cleaned = clean_data(dataframe, debug=True)
            check_for_duplicates(dataframe_cleaned, debug=True)
            dataframe_mapped = map_team_ids(dataframe_cleaned, debug=True)
            dataframe_final = add_suffixes_to_columns(dataframe_mapped, debug=True)
            final_data_check(dataframe_final, debug=True)

            dataframe_final.to_csv(output_file, index=False)
            st.success(f"Preprocessed data saved to {output_file}")
            st.write("Cleaned Data:", dataframe_final.head())
        except Exception as e:
            st.error(f"Data cleaning process failed: {e}")

# --- Step 2: Ingest Data into Neo4j ---
if st.sidebar.checkbox("Step 2: Ingest Data into Neo4j"):
    st.subheader("Data Ingestion into Neo4j")
    load_dotenv('/workspaces/custom_ollama_docker/.env')
    # Same CLEANED_DATA_FILE override that main.py's run_data_cleaning writes to.
    data_file = os.getenv('CLEANED_DATA_FILE', "/workspaces/custom_ollama_docker/data/neo4j/processed/nba_player_data_cleaned.csv")

    if st.button("Ingest Data"):
        driver = None
        try:
            driver = connect_to_neo4j()
            dataframe = pd.read_csv(data_file)
            data_dicts = dataframe.to_dict(orient='records')
            st.write(f"Loaded {len(data_dicts)} records for ingestion.")

            with driver.session() as session:
                clear_database(session)
                clear_constraints_and_indexes(session)
                setup_schema_with_cleanup(session)
                setup_indexes(session)
                st.write("Database schema and indexes set up successfully.")

                for player_data in data_dicts:
                    session.execute_write(insert_enhanced_data, player_data)
                    session.execute_write(calculate_and_set_trade_value, player_data["Player"])

                st.success("Data inserted into Neo4j successfully.")
        except Exception as e:
            st.error(f"Data ingestion process failed: {e}")
        finally:
            # Streamlit reruns this script on every interaction; release the pool each time.
            if driver is not None:
                driver.close()

# --- Step 3: Test Neo4j Queries ---
if st.sidebar.checkbox("Step 3: Test Neo4j Queries"):
    st.subheader("Test Data Retrieval from Neo4j")
    season = st.text_input("Enter season (e.g., '2023-24'):", "2023-24")
    query_type = st.selectbox("Select a query to run:", [
        "Player Statistics",
        "Top Teams by Salary",
        "Top Players by VORP"
    ])

    if st.button("Run Test Query"):
        driver = None
        try:
            driver = connect_to_neo4j()

            # The Cypher shown with st.code is for display only; the executed
            # queries live in neo4j_test_functions. NOTE(review): keep them in sync.
            if query_type == "Player Statistics":
                st.write("Retrieving player statistics...")
                query = """
                MATCH (p:Player)-[:POSSESSES]->(stat:Statistics {season: $season})
                RETURN p.name AS player, stat.ppg AS points_per_game, stat.assists_total AS assists
                """
                st.code(query, language="cypher")
                player_stats_df = get_player_statistics(driver, season)
                st.write("Player Statistics:", player_stats_df.head())

            elif query_type == "Top Teams by Salary":
                st.write("Retrieving top teams by salary...")
                query = """
                MATCH (t:Team)<-[:HAS_PLAYED_FOR]-(p:Player)-[:HAS_CONTRACT]->(c:Contract {season: $season})
                RETURN t.name AS team, SUM(c.salary) AS total_salary
                ORDER BY total_salary DESC
                LIMIT 5
                """
                st.code(query, language="cypher")
                team_salary_df = get_top_teams_by_salary(driver, season)
                st.write("Top Teams by Salary:", team_salary_df.head())

            elif query_type == "Top Players by VORP":
                st.write("Retrieving top players by VORP...")
                query = """
                MATCH (p:Player)-[:POSSESSES]->(stat:Statistics {season: $season})
                RETURN p.name AS player, stat.vorp_total AS vorp
                ORDER BY stat.vorp_total DESC
                LIMIT 5
                """
                st.code(query, language="cypher")
                top_vorp_df = get_top_players_by_vorp(driver, season)
                st.write("Top Players by VORP:", top_vorp_df.head())

        except Exception as e:
            st.error(f"Neo4j test queries failed: {e}")
        finally:
            # Streamlit reruns this script on every interaction; release the pool each time.
            if driver is not None:
                driver.close()



# --- Step 4: GraphQA with Llama3 ---
if st.sidebar.checkbox("Step 4 : GraphQA with Llama3"):
    st.subheader("GraphQA - Ask Questions on NBA Data")

    question = st.text_input("Enter your question about NBA data:", "Who are the top 5 players in the 2023-24 season based on assist total?")

    if st.button("Run GraphQA"):
        # Streamlit re-executes this script on every widget interaction, so the
        # driver and LLM are created only when the button is pressed, and the
        # driver is closed afterwards instead of leaking one per rerun.
        driver = None
        try:
            load_environment()
            driver = create_neo4j_driver()
            llm = initialize_llm()
            prompt_template = create_cypher_prompt_template()

            # Run the agent and capture all intermediate outputs
            results = agent_repl_loop(
                question=question,
                schema=schema,
                driver=driver,
                llm=llm,
                repl_tool=None,  # No Python REPL in the web app
                prompt_template=prompt_template
            )

            # .get: older agent versions may not report an "error" key.
            if results.get("error"):
                st.error(f"GraphQA agent reported an error: {results['error']}")

            # Steps the agent did not reach are None; skip them rather than render "None".
            if results["prompt_text"] is not None:
                st.write("Generated Prompt:")
                st.code(results["prompt_text"], language="plaintext")

            if results["generated_query"] is not None:
                st.write("Generated Query:")
                st.code(results["generated_query"], language="cypher")

            # Display the adjusted query only if it was modified
            adjusted = results["adjusted_query"]
            if adjusted is not None and adjusted != results["generated_query"]:
                st.write("Adjusted Query After Per-Game Calculations:")
                st.code(adjusted, language="cypher")

            if results["query_result"]:
                st.write("Query Results:")
                st.write(pd.DataFrame(results["query_result"]))  # Display results as a DataFrame
            else:
                st.write("No results returned from the query.")

        except Exception as e:
            st.error(f"GraphQA agent execution failed: {e}")
        finally:
            if driver is not None:
                driver.close()




Overwriting ../../src/neo4j_model/streamlit_app.py
