In [None]:
import json
import os
from datetime import datetime, timedelta
from typing import Any, List, Dict, Optional
from dotenv import load_dotenv
from sqlalchemy import create_engine, Column, Integer, String, Date, inspect, event, types
from sqlalchemy.orm import sessionmaker, Session, declarative_base
from sqlalchemy.exc import SQLAlchemyError
import pandas as pd
from wifor_db import setup_logger

# Set up the module-wide logger. ``setup_logger`` comes from the project's
# ``wifor_db`` package; "log_files" is passed as its second argument
# (presumably the log directory - confirm against wifor_db).
script_dir = "log_files"
logger = setup_logger("add_regions", script_dir)
if logger:
    logger.info("Logging setup complete")
else:
    print("Logger setup failed.")
    # Every helper below calls ``logger.error(...)`` in its error path. Leaving
    # ``logger`` as None would turn each of those into an AttributeError that
    # hides the real problem, so fall back to a plain stdlib logger.
    import logging

    logger = logging.getLogger("add_regions")

# Declarative base shared by every model class built by ``create_class``.
Base = declarative_base()

def load_table_schema(file_path: str) -> Optional[Any]:
    """
    Load a table schema from a JSON file.

    The file is always decoded as UTF-8 so the result does not depend on the
    platform's locale encoding.

    Args:
        file_path (str): The path to the JSON file containing the table schema.

    Returns:
        Optional[Any]: The parsed JSON document if successful, None otherwise.

    Note:
        This function does not raise. Any failure while opening, decoding or
        parsing the file is logged through the module logger and reported to
        the caller as a None return value.
    """
    try:
        # JSON is UTF-8 by specification; without an explicit encoding, open()
        # uses the locale default (e.g. cp1252 on Windows) and corrupts or
        # rejects non-ASCII content.
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except Exception as e:
        # Deliberate catch-all: callers rely on "None means failure".
        logger.error(f"Error loading table schema: {e}")
        return None

def create_repr_string(name: str, columns: List[Dict[str, str]]) -> Optional[str]:
    """
    Build the ``str.format`` template used as ``__repr__`` for a model class.

    Every schema column contributes one ``column=value`` part; the value is
    wrapped in single quotes when the column's type string contains 'String'.
    The three bookkeeping fields (version_number, effective_date, expiry_date)
    are always appended after the schema columns.

    Args:
        name (str): The name of the model class.
        columns (List[Dict[str, str]]): One dictionary per table column, each
                                        providing the keys 'name' and 'type'.

    Returns:
        Optional[str]: A template containing ``{self.<column>}`` placeholders,
                       or None if building it failed (the error is logged).
    """
    try:
        pieces = []
        for column in columns:
            # Look at the type first, then the name, so a malformed column
            # entry fails on the same key as before.
            quoted = 'String' in column['type']
            col_name = column['name']
            template = "{0}='{{self.{0}}}'" if quoted else "{0}={{self.{0}}}"
            pieces.append(template.format(col_name))

        # Bookkeeping fields present on every generated model.
        pieces.extend((
            "version_number={self.version_number}",
            "effective_date='{self.effective_date}'",
            "expiry_date='{self.expiry_date}'",
        ))

        body = ', '.join(pieces)
        return "<{}(".format(name) + body + ")>"
    except Exception as e:
        logger.error(f"Error creating representation string: {e}")
        return None

def before_flush(cls, session: Session, flush_context: Any, instances: Any, columns: List[Dict[str, str]], unique_identifier: str) -> None:
    """
    Custom 'before_flush' event handler for SQLAlchemy models.

    For every pending instance of ``cls`` in the session this handler:

    * drops the instance if an identical *current* row (same values in all
      schema columns, no expiry date) already exists, and
    * otherwise, if a current row with the same unique identifier exists,
      expires that row as of yesterday and gives the new instance the next
      version number.

    Args:
        cls (Type[Any]): The model class whose pending instances are handled.
        session (Session): The session which is flushing.
        flush_context (Any): The context for the flush (unused).
        instances (Any): The set of instances participating in the flush (unused).
        columns (List[Dict[str, str]]): A list of dictionaries representing the columns of the table.
        unique_identifier (str): The field name used as a unique identifier for the entries.

    Note:
        This function does not raise. Any exception is logged and the flush
        continues; pending instances not yet examined are left untouched.
    """
    try:
        for instance in session.new:
            if not isinstance(instance, cls):
                continue

            filter_args = {col['name']: getattr(instance, col['name']) for col in columns}

            # Only the current version (no expiry date) counts as a duplicate.
            # Matching expired versions too would silently drop a value that
            # reverts to an earlier state (A -> B -> A), leaving the stale B
            # as the current row.
            same_entry = session.query(cls).filter_by(**{**filter_args, 'expiry_date': None}).first()
            if same_entry:
                # Nothing changed: remove the new instance from the flush.
                session.expunge(instance)
                continue

            # Check for previous entry with the same unique identifier and no expiry date
            previous_entry = session.query(cls).filter_by(**{unique_identifier: getattr(instance, unique_identifier), 'expiry_date': None}).first()
            if previous_entry:
                # expiry_date is a Date column, so store a date (not a
                # datetime): the old version expires yesterday.
                previous_entry.expiry_date = datetime.now().date() - timedelta(days=1)

                # Treat a missing version number on the old row as version 1
                # instead of failing and aborting the rest of the batch.
                instance.version_number = (previous_entry.version_number or 1) + 1
    except Exception as e:
        logger.error(f"Error in before_flush: {e}")

def create_engine_from_env():
    """
    Creates a SQLAlchemy engine based on environment variables.

    ``load_dotenv()`` is called first, then ``CURRENT_DB`` selects the backend:

    * ``sqlite``: requires ``SQLITE_DB_PATH``.
    * ``mysql``: requires ``MYSQL_DB_USER``, ``MYSQL_DB_PASSWORD``,
      ``MYSQL_DB_HOST`` and ``MYSQL_DB_NAME``; the ``pymysql`` driver is used.

    Returns:
        sqlalchemy.engine.Engine: SQLAlchemy engine if successful, None otherwise.

    Note:
        This function does not raise. Configuration problems, SQLAlchemy
        errors and unexpected errors are logged and reported as None.
    """
    # Function-scope import: only the MySQL branch needs it.
    from urllib.parse import quote

    try:
        load_dotenv()
        current_db = os.environ.get('CURRENT_DB')

        if current_db == 'sqlite':
            db_path = os.environ.get('SQLITE_DB_PATH')
            if not db_path:
                raise ValueError("SQLite database path is not set in .env file.")
            return create_engine(f"sqlite:///{db_path}", echo=False)

        elif current_db == 'mysql':
            db_user = os.environ.get('MYSQL_DB_USER')
            db_password = os.environ.get('MYSQL_DB_PASSWORD')
            db_host = os.environ.get('MYSQL_DB_HOST')
            db_name = os.environ.get('MYSQL_DB_NAME')

            if not all([db_user, db_password, db_host, db_name]):
                raise ValueError("MySQL credentials are not set properly in .env file.")

            # Credentials are embedded in a URL, so characters such as '@',
            # ':', '/' or '%' must be percent-encoded or the URL is parsed
            # incorrectly. quote(..., safe="") also encodes spaces as %20.
            # The host is left untouched so a "host:port" value keeps working.
            safe_user = quote(db_user, safe="")
            safe_password = quote(db_password, safe="")
            db_url = f"mysql+pymysql://{safe_user}:{safe_password}@{db_host}/{db_name}"
            return create_engine(db_url, echo=False)

        else:
            raise ValueError("Database type is not defined or not supported.")

    except ValueError as e:
        logger.error(f"Configuration Error: {e}")
    except SQLAlchemyError as e:
        logger.error(f"SQLAlchemy Engine Creation Error: {e}")
    except Exception as e:
        logger.error(f"Unexpected Error: {e}")
    return None

def create_session():
    """
    Open a new SQLAlchemy session on the engine configured via the environment.

    Returns:
        A new session bound to the engine from ``create_engine_from_env``, or
        None if no engine is available or the session could not be created.
        The reason is printed, not raised.
    """
    engine = create_engine_from_env()
    try:
        # A missing engine is routed through the ValueError handler below.
        if engine is None:
            raise ValueError("Engine cannot be None.")
        session_factory = sessionmaker(bind=engine)
        new_session = session_factory()
    except SQLAlchemyError as e:
        print(f"SQLAlchemy Error: {e}")
    except ValueError as e:
        print(f"Configuration Error: {e}")
    except Exception as e:
        print(f"Unexpected Error: {e}")
    else:
        return new_session
    return None

def create_table(cls):
    """
    Initialize the database, create the table if it does not exist, and return a session for querying.

    Args:
        cls: The model class whose ``__tablename__`` is checked.

    Returns:
        An open session bound to the configured engine, or None on failure.
        The reason is printed, not raised. The caller owns the returned
        session and is responsible for closing it.
    """
    session = None
    try:
        # Create a session
        session = create_session()
        if session is None:
            raise ValueError("Failed to create a session.")

        engine = session.get_bind()

        # Check if the table exists and create it if not.
        # NOTE(review): create_all() is called on the shared metadata, so it
        # emits DDL for every model registered on it, not just ``cls``.
        if not inspect(engine).has_table(cls.__tablename__):
            cls.metadata.create_all(engine)
            print(f"The table '{cls.__tablename__}' has been created in the database.")

        return session

    except SQLAlchemyError as e:
        print(f"SQLAlchemy Error: {e}")
    except ValueError as e:
        print(f"Configuration Error: {e}")
    except Exception as e:
        print(f"Unexpected Error: {e}")

    # Failure path only: don't leak a session (and its connection) that was
    # opened before the error occurred.
    if session is not None:
        session.close()
    return None
    
def add_data(cls, data, column_names):
    """
    Add data from a pandas or geopandas DataFrame to the database table.

    One ``cls`` instance is created per row from the selected columns, and all
    of them are committed in a single transaction.

    Args:
        cls: The model class to instantiate for each row.
        data (pd.DataFrame): The source frame (a GeoDataFrame is accepted too).
        column_names: The columns of ``data`` to store; each must be a valid
            keyword argument for ``cls``.

    Raises:
        ValueError: If ``data`` is not a DataFrame.
        RuntimeError: If no database session could be opened.
        Exception: Any error raised while building or committing the rows is
            re-raised after the transaction has been rolled back.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("Data must be a pandas or geopandas DataFrame")

    filtered_data = data[column_names]
    session = create_session()
    if session is None:
        # create_session() reports failure by returning None; without this
        # check the rollback/close calls below would raise AttributeError and
        # hide the real cause.
        raise RuntimeError("Could not open a database session.")

    def _null_if_missing(value):
        # Missing cells (NaN / NaT / pd.NA / None) are stored as NULL.
        # Non-scalar values are passed through unchanged.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    try:
        # to_dict(orient="records") keeps each column's own dtype, whereas
        # iterrows() coerces every row to one common dtype.
        instances = [
            cls(**{key: _null_if_missing(value) for key, value in record.items()})
            for record in filtered_data.to_dict(orient="records")
        ]
        session.add_all(instances)
        session.commit()
    except Exception:
        session.rollback()
        raise  # bare raise keeps the original traceback
    finally:
        session.close()

def create_class(schema):
    """
    Create a SQLAlchemy model class dynamically from a JSON table schema.

    The generated class derives from the module-level ``Base`` and has:

    * an auto-incrementing integer primary key ``id``,
    * one column per entry in ``schema['columns']``,
    * the bookkeeping columns ``version_number``, ``effective_date`` and
      ``expiry_date``,
    * a generated ``__repr__``,
    * the class methods ``before_flush``, ``create_session``, ``create_table``
      and ``add_data``.

    The class's ``before_flush`` method is also registered as a Session
    'before_flush' listener, so duplicate detection and versioning run on
    every flush.

    Args:
        schema (dict): Must provide 'table_name' (str), 'columns' (a list of
            dicts with 'name' and 'type') and 'identifier' (the column name
            used as unique identifier). A 'type' is the name of a type in
            ``sqlalchemy.types``, optionally followed by integer arguments in
            parentheses, e.g. ``"Integer"``, ``"String(50)"``, ``"Numeric(10, 2)"``.

    Returns:
        type: The newly created model class.

    Raises:
        KeyError: If a required schema key is missing.
        AttributeError: If a column type name does not exist in ``sqlalchemy.types``.
        ValueError: If a type argument is not an integer.
    """
    name = schema['table_name']
    columns = schema['columns']
    unique_identifier = schema['identifier']

    attrs = {'__tablename__': name}

    # Add standard columns at the beginning
    attrs['id'] = Column(Integer, primary_key=True, autoincrement=True)

    # Add columns from JSON schema
    for column in columns:
        type_name, _, raw_args = column['type'].partition('(')
        column_type = getattr(types, type_name.strip())
        # Accept any number of integer arguments, so "Numeric(10, 2)" works
        # as well as "String(50)".
        type_args = [int(part) for part in raw_args.rstrip(')').split(',') if part.strip()]
        if type_args:
            column_type = column_type(*type_args)
        attrs[column['name']] = Column(column_type)

    # Add standard columns at the end
    attrs['version_number'] = Column(Integer, default=1)
    attrs['effective_date'] = Column(Date, default=datetime.now)
    attrs['expiry_date'] = Column(Date, default=None)

    # Add dynamic __repr__ method
    repr_string = create_repr_string(name, columns)
    if repr_string is None:
        # create_repr_string returns None on failure; use a minimal template
        # so repr() of an instance never raises AttributeError on None.
        repr_string = "<" + str(name) + "(id={self.id})>"
    attrs['__repr__'] = lambda self: repr_string.format(self=self)

    # Add dynamic before_flush class method
    attrs['before_flush'] = classmethod(lambda cls, session, flush_context, instances: before_flush(cls, session, flush_context, instances, columns, unique_identifier))

    # Assign functions as class methods
    attrs['create_session'] = classmethod(lambda cls: create_session())
    attrs['create_table'] = classmethod(create_table)
    attrs['add_data'] = classmethod(add_data)

    model = type(name, (Base,), attrs)

    # Without this registration the handler is never invoked and no
    # versioning or duplicate detection takes place. Listening on the Session
    # class covers every session created afterwards; the handler ignores
    # instances of other classes.
    event.listen(Session, 'before_flush', model.before_flush)

    return model

In [None]:
# Load the JSON schema for the regions table and build its model class.
schema = load_table_schema('../wifor_db/tables/regions.json')
if schema is None:
    # load_table_schema logs the cause and returns None on failure; stop here
    # instead of failing inside create_class with an opaque TypeError.
    raise RuntimeError("Could not load the regions table schema; see the log for details.")
Regions = create_class(schema)

In [None]:
# Create the regions table if it does not exist yet. create_table returns an
# open session on success and None on failure (the reason is printed).
session = Regions.create_table()

In [None]:
import geopandas as gpd

# Build the path with os.path.join: a hard-coded backslash separator only
# resolves on Windows.
geo_df = gpd.read_file(os.path.join('ref-nuts-2021', 'NUTS_RG_01M_2021_4326.geojson'))

# Store only the columns that the table schema declares.
column_names = [column['name'] for column in schema['columns']]

Regions.add_data(geo_df, column_names)

In [None]:
# Open a fresh session and fetch all stored rows whose NUTS_ID is "AT".
session = Regions.create_session()
region = (
    session.query(Regions)
    .filter(Regions.NUTS_ID == "AT")
    .all()
)
region