Merge pull request #366 from asogaard/sqlite-utility-methods
SQLite utility methods
asogaard authored Dec 8, 2022
2 parents f12f7ee + 6aef937 commit f64a29a
Showing 9 changed files with 249 additions and 223 deletions.
46 changes: 7 additions & 39 deletions src/graphnet/data/pipeline.py
@@ -13,7 +13,7 @@
 import torch
 from torch.utils.data import DataLoader
 
-from graphnet.data.sqlite.sqlite_utilities import run_sql_code, save_to_sql
+from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql
 from graphnet.training.utils import get_predictions, make_dataloader
 
 from graphnet.utilities.logging import get_logger
@@ -97,7 +97,7 @@ def __call__(
             df = self._inference(device, dataloader)
             truth = self._get_truth(database, event_batches[i].tolist())
             retro = self._get_retro(database, event_batches[i].tolist())
-            self._append_to_pipeline(outdir, truth, retro, df, i)
+            self._append_to_pipeline(outdir, truth, retro, df)
             i += 1
         else:
             logger.info(outdir)
@@ -210,44 +210,12 @@ def _append_to_pipeline(
         truth: pd.DataFrame,
         retro: pd.DataFrame,
         df: pd.DataFrame,
-        i: int,
     ) -> None:
         os.makedirs(outdir, exist_ok=True)
         pipeline_database = outdir + "/%s.db" % self._pipeline_name
-        if i == 0:
-            # Only setup table schemes if its the first time appending
-            self._create_table(pipeline_database, "reconstruction", df)
-            self._create_table(pipeline_database, "truth", truth)
-        save_to_sql(df, "reconstruction", pipeline_database)
-        save_to_sql(truth, "truth", pipeline_database)
+        create_table_and_save_to_sql(df, "reconstruction", pipeline_database)
+        create_table_and_save_to_sql(truth, "truth", pipeline_database)
         if isinstance(retro, pd.DataFrame):
-            if i == 0:
-                self._create_table(pipeline_database, "retro", retro)
-            save_to_sql(retro, self._retro_table_name, pipeline_database)
-
-    # @FIXME: Duplicate.
-    def _create_table(
-        self, pipeline_database: str, table_name: str, df: pd.DataFrame
-    ) -> None:
-        """Create a table.
-        Args:
-            pipeline_database: Path to the pipeline database.
-            table_name: Name of the table in pipeline database.
-            df: DataFrame of combined predictions.
-        """
-        query_columns_list = list()
-        for column in df.columns:
-            if column == "event_no":
-                type_ = "INTEGER PRIMARY KEY NOT NULL"
-            else:
-                type_ = "FLOAT"
-            query_columns_list.append(f"{column} {type_}")
-        query_columns = ", ".join(query_columns_list)
-
-        code = (
-            "PRAGMA foreign_keys=off;\n"
-            f"CREATE TABLE {table_name} ({query_columns});\n"
-            "PRAGMA foreign_keys=on;"
-        )
-        run_sql_code(pipeline_database, code)
-
+            create_table_and_save_to_sql(
+                retro, self._retro_table_name, pipeline_database
+            )
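With the duplicated `_create_table` helper gone, `_append_to_pipeline` no longer needs the batch counter `i` to decide when to set up table schemas: `create_table_and_save_to_sql` creates the table on the first call and simply appends on subsequent ones. A minimal sketch of the new call pattern (the dataframes, column names, and database path below are illustrative, not from the repository):

    import pandas as pd

    from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql

    pipeline_database = "/tmp/pipeline_example.db"

    # Batch 0 creates the `reconstruction` table; batch 1 just appends.
    # `event_no` values must stay unique across batches, since that column
    # is created as INTEGER PRIMARY KEY by default.
    df_batch_0 = pd.DataFrame({"event_no": [0, 1], "zenith_pred": [0.4, 1.2]})
    df_batch_1 = pd.DataFrame({"event_no": [2, 3], "zenith_pred": [2.1, 0.9]})

    create_table_and_save_to_sql(df_batch_0, "reconstruction", pipeline_database)
    create_table_and_save_to_sql(df_batch_1, "reconstruction", pipeline_database)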
2 changes: 1 addition & 1 deletion src/graphnet/data/sqlite/__init__.py
@@ -3,7 +3,7 @@
 from graphnet.utilities.imports import has_torch_package
 
 from .sqlite_dataconverter import SQLiteDataConverter
-from .sqlite_utilities import run_sql_code, save_to_sql, create_table
+from .sqlite_utilities import create_table_and_save_to_sql
 
 if has_torch_package():
     from .sqlite_dataset import SQLiteDataset
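Note that the low-level helpers (`run_sql_code`, `save_to_sql`, `create_table`) are no longer re-exported at the package level; downstream code is expected to go through the high-level wrapper, e.g.:

    from graphnet.data.sqlite import SQLiteDataConverter, create_table_and_save_to_sql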
95 changes: 28 additions & 67 deletions src/graphnet/data/sqlite/sqlite_dataconverter.py
@@ -10,7 +10,10 @@
 from tqdm import tqdm
 
 from graphnet.data.dataconverter import DataConverter  # type: ignore[attr-defined]
-from graphnet.data.sqlite.sqlite_utilities import run_sql_code, save_to_sql
+from graphnet.data.sqlite.sqlite_utilities import (
+    create_table,
+    create_table_and_save_to_sql,
+)
 
 
 class SQLiteDataConverter(DataConverter):
@@ -51,7 +54,15 @@ def save_data(self, data: List[OrderedDict], output_file: str) -> None:
         saved_any = False
         for table, df in dataframe.items():
             if len(df) > 0:
-                save_to_sql(df, table, output_file)
+                create_table_and_save_to_sql(
+                    df,
+                    table,
+                    output_file,
+                    default_type="FLOAT",
+                    integer_primary_key=not (
+                        is_pulse_map(table) or is_mc_tree(table)
+                    ),
+                )
                 saved_any = True
 
         if saved_any:
@@ -92,12 +103,14 @@ def merge_files(
                 input_files, table_name
             )
             if len(column_names) > 1:
-                is_pulse_map = is_pulsemap_check(table_name)
-                self._create_table(
-                    output_file,
-                    table_name,
+                create_table(
                     column_names,
-                    is_pulse_map=is_pulse_map,
+                    table_name,
+                    output_file,
+                    default_type="FLOAT",
+                    integer_primary_key=not (
+                        is_pulse_map(table_name) or is_mc_tree(table_name)
+                    ),
                 )
 
             # Merge temporary databases into newly created one
@@ -157,60 +170,6 @@ def any_pulsemap_is_non_empty(self, data_dict: Dict[str, Dict]) -> bool:
         pulsemap_dicts = [data_dict[pulsemap] for pulsemap in self._pulsemaps]
         return any(d["dom_x"] for d in pulsemap_dicts)
 
-    def _attach_index(self, database: str, table_name: str) -> None:
-        """Attach the table index.
-        Important for query times!
-        """
-        code = (
-            "PRAGMA foreign_keys=off;\n"
-            "BEGIN TRANSACTION;\n"
-            f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
-            "COMMIT TRANSACTION;\n"
-            "PRAGMA foreign_keys=on;"
-        )
-        run_sql_code(database, code)
-
-    def _create_table(
-        self,
-        database: str,
-        table_name: str,
-        columns: List[str],
-        is_pulse_map: bool = False,
-    ) -> None:
-        """Create a table.
-        Args:
-            database: Path to the database.
-            table_name: Name of the table.
-            columns: The names of the columns of the table.
-            is_pulse_map: Whether or not this is a pulse map table.
-        """
-        query_columns = list()
-        for column in columns:
-            if column == "event_no":
-                if not is_pulse_map:
-                    type_ = "INTEGER PRIMARY KEY NOT NULL"
-                else:
-                    type_ = "NOT NULL"
-            else:
-                type_ = "FLOAT"
-            query_columns.append(f"{column} {type_}")
-        query_columns_string = ", ".join(query_columns)
-
-        code = (
-            "PRAGMA foreign_keys=off;\n"
-            f"CREATE TABLE {table_name} ({query_columns_string});\n"
-            "PRAGMA foreign_keys=on;"
-        )
-        run_sql_code(database, code)
-
-        if is_pulse_map:
-            self.debug(table_name)
-            self.debug("Attaching indices")
-            self._attach_index(database, table_name)
-        return
 
     def _submit_to_database(
         self, database: str, key: str, data: pd.DataFrame
     ) -> None:
@@ -280,9 +239,11 @@ def construct_dataframe(extraction: Dict[str, Any]) -> pd.DataFrame:
     return out
 
 
-def is_pulsemap_check(table_name: str) -> bool:
-    """Check whether `table_name` corresponds to a pulsemap."""
-    if "pulse" in table_name.lower():
-        return True
-    else:
-        return False
+def is_pulse_map(table_name: str) -> bool:
+    """Check whether `table_name` corresponds to a pulse map."""
+    return "pulse" in table_name.lower() or "series" in table_name.lower()
+
+
+def is_mc_tree(table_name: str) -> bool:
+    """Check whether `table_name` corresponds to an MC tree."""
+    return "I3MCTree" in table_name
104 changes: 84 additions & 20 deletions src/graphnet/data/sqlite/sqlite_utilities.py
@@ -1,79 +1,117 @@
 """SQLite-specific utility functions for use in `graphnet.data`."""
 
+import os.path
 from typing import List
 
 import pandas as pd
 import sqlalchemy
 import sqlite3
 
 
-def run_sql_code(database: str, code: str) -> None:
+def database_exists(database_path: str) -> bool:
+    """Check whether database exists at `database_path`."""
+    assert database_path.endswith(
+        ".db"
+    ), "Provided database path does not end in `.db`."
+    return os.path.exists(database_path)
+
+
+def database_table_exists(database_path: str, table_name: str) -> bool:
+    """Check whether `table_name` exists in database at `database_path`."""
+    if not database_exists(database_path):
+        return False
+    query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';"
+    with sqlite3.connect(database_path) as conn:
+        result = pd.read_sql(query, conn)
+    return len(result) == 1
+
+
+def run_sql_code(database_path: str, code: str) -> None:
     """Execute SQLite code.
     Args:
-        database: Path to databases
+        database_path: Path to database
         code: SQLite code
     """
-    conn = sqlite3.connect(database)
+    conn = sqlite3.connect(database_path)
     c = conn.cursor()
     c.executescript(code)
     c.close()
 
 
-def save_to_sql(df: pd.DataFrame, table_name: str, database: str) -> None:
+def save_to_sql(df: pd.DataFrame, table_name: str, database_path: str) -> None:
     """Save a dataframe `df` to a table `table_name` in SQLite `database`.
     Table must exist already.
     Args:
         df: Dataframe with data to be stored in sqlite table
         table_name: Name of table. Must exist already
-        database: Path to SQLite database
+        database_path: Path to SQLite database
     """
-    engine = sqlalchemy.create_engine("sqlite:///" + database)
+    engine = sqlalchemy.create_engine("sqlite:///" + database_path)
     df.to_sql(table_name, con=engine, index=False, if_exists="append")
     engine.dispose()
 
 
-def attach_index(database: str, table_name: str) -> None:
-    """Attaches the table index.
+def attach_index(
+    database_path: str, table_name: str, index_column: str = "event_no"
+) -> None:
+    """Attach the table (i.e., event) index.
     Important for query times!
     """
     code = (
         "PRAGMA foreign_keys=off;\n"
         "BEGIN TRANSACTION;\n"
-        f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
+        f"CREATE INDEX {index_column}_{table_name} "
+        f"ON {table_name} ({index_column});\n"
         "COMMIT TRANSACTION;\n"
         "PRAGMA foreign_keys=on;"
     )
-    run_sql_code(database, code)
+    run_sql_code(database_path, code)
 
 
 def create_table(
-    df: pd.DataFrame,
+    columns: List[str],
     table_name: str,
     database_path: str,
-    is_pulse_map: bool = False,
+    *,
+    index_column: str = "event_no",
+    default_type: str = "NOT NULL",
+    integer_primary_key: bool = True,
 ) -> None:
     """Create a table.
     Args:
-        df: Data to be saved to table
+        columns: Column names to be created in table.
         table_name: Name of the table.
         database_path: Path to the database.
-        is_pulse_map: Whether or not this is a pulse map table.
+        index_column: Name of the index column.
+        default_type: The type used for all non-index columns.
+        integer_primary_key: Whether or not to create the `index_column` with
+            the `INTEGER PRIMARY KEY` type. Such a column is required to have
+            unique, integer values for each row. This is appropriate when the
+            table has one row per event, e.g., event-level MC truth. It is not
+            appropriate for pulse map series, particle-level MC truth, and
+            other such data that is expected to have more than one row per
+            event (i.e., with the same index).
     """
-    query_columns = list()
-    for column in df.columns:
-        if column == "event_no":
-            if not is_pulse_map:
+    # Prepare column names and types
+    query_columns = []
+    for column in columns:
+        type_ = default_type
+        if column == index_column:
+            if integer_primary_key:
                 type_ = "INTEGER PRIMARY KEY NOT NULL"
             else:
                 type_ = "NOT NULL"
-        else:
-            type_ = "NOT NULL"
 
         query_columns.append(f"{column} {type_}")
 
     query_columns_string = ", ".join(query_columns)
 
+    # Run SQL code
     code = (
         "PRAGMA foreign_keys=off;\n"
         f"CREATE TABLE {table_name} ({query_columns_string});\n"
@@ -83,3 +121,29 @@ def create_table(
         database_path,
         code,
     )
+
+    # Attaching index to all non-truth-like tables (e.g., pulse maps).
+    if not integer_primary_key:
+        attach_index(database_path, table_name)
+
+
+def create_table_and_save_to_sql(
+    df: pd.DataFrame,
+    table_name: str,
+    database_path: str,
+    *,
+    index_column: str = "event_no",
+    default_type: str = "NOT NULL",
+    integer_primary_key: bool = True,
+) -> None:
+    """Create table if it doesn't exist and save dataframe to it."""
+    if not database_table_exists(database_path, table_name):
+        create_table(
+            df.columns,
+            table_name,
+            database_path,
+            index_column=index_column,
+            default_type=default_type,
+            integer_primary_key=integer_primary_key,
+        )
+    save_to_sql(df, table_name=table_name, database_path=database_path)
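Taken together, the utilities now split cleanly into "check" (`database_exists`, `database_table_exists`), "create" (`create_table`, `attach_index`), and "save" (`save_to_sql`) steps, with `create_table_and_save_to_sql` as the one-call wrapper. A sketch of the intended usage with toy data (table names, column names, and the scratch path are illustrative only):

    import pandas as pd

    from graphnet.data.sqlite.sqlite_utilities import (
        create_table_and_save_to_sql,
        database_table_exists,
    )

    database_path = "/tmp/example.db"

    # One row per event: `event_no` becomes INTEGER PRIMARY KEY NOT NULL.
    truth = pd.DataFrame({"event_no": [0, 1], "energy": [12.1, 44.0]})
    create_table_and_save_to_sql(truth, "truth", database_path)

    # Several rows per event: no unique primary key. `create_table` attaches
    # an `event_no` index instead (via `attach_index`) to keep queries fast.
    pulses = pd.DataFrame({"event_no": [0, 0, 1], "dom_x": [1.0, 2.0, 3.0]})
    create_table_and_save_to_sql(
        pulses, "pulses", database_path, integer_primary_key=False
    )

    assert database_table_exists(database_path, "truth")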