In [0]:
%run ./base_connector

In [0]:
class JDBCConnector(BaseConnector):
    """Read and write Spark DataFrames over JDBC.

    Supported providers (case-insensitive): "postgres", "mysql", "sqlserver"
    and "oracle". Uses the global ``spark`` session that the notebook
    environment provides.

    NOTE(review): ``BaseConnector.__init__`` is not called here; confirm the
    base class needs no initialisation.
    """

    # JDBC driver class per provider; passed to Spark as the "driver" option.
    _DRIVERS = {
        "postgres": "org.postgresql.Driver",
        "mysql": "com.mysql.cj.jdbc.Driver",
        "sqlserver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
        "oracle": "oracle.jdbc.driver.OracleDriver",
    }

    def __init__(self, host, port, database, user, password, provider="postgres"):
        """Build the connection URL and JDBC properties for one database.

        Args:
            host: Database server host name.
            port: Database server port.
            database: Database name embedded in the JDBC URL.
            user: Login user.
            password: Login password.
            provider: One of the supported providers; matched case-insensitively.

        Raises:
            ValueError: if ``provider`` is not supported.
        """
        self.provider = provider.lower()
        # Per-instance copy so mutating ``self.drivers`` cannot leak into
        # other connectors.
        self.drivers = dict(self._DRIVERS)

        # Validate and look up with the normalised name, so e.g. "MySQL" works.
        if self.provider not in self.drivers:
            raise ValueError(f"Unsupported provider: {provider}")

        self.url = self.build_url(host, port, database)
        self.props = {
            "user": user,
            "password": password,
            "driver": self.drivers[self.provider],
        }

    def build_url(self, host, port, db):
        """Return the JDBC URL for ``self.provider``.

        NOTE(review): for Oracle the ``@host:port/db`` form is used, so ``db``
        is presumably a service name rather than a SID — confirm.

        Raises:
            ValueError: if ``self.provider`` is not a supported provider.
        """
        if self.provider == "postgres":
            return f"jdbc:postgresql://{host}:{port}/{db}"
        elif self.provider == "mysql":
            return f"jdbc:mysql://{host}:{port}/{db}"
        elif self.provider == "sqlserver":
            return f"jdbc:sqlserver://{host}:{port};databaseName={db}"
        elif self.provider == "oracle":
            return f"jdbc:oracle:thin:@{host}:{port}/{db}"
        else:
            raise ValueError("Unsupported provider")

    def read(self, table=None, query=None, partitions=None):
        """Load a table, or the result of a query, into a Spark DataFrame.

        Args:
            table: Table to read. Takes precedence over ``query`` if both are given.
            query: SQL SELECT statement whose result should be read.
            partitions: Optional dict enabling a partitioned (parallel) read.
                Must contain "partitionColumn", "lowerBound" and "upperBound";
                "numPartitions" defaults to 4.

        Returns:
            The DataFrame returned by Spark's JDBC reader.

        Raises:
            ValueError: if neither ``table`` nor ``query`` is provided.
            KeyError: if ``partitions`` lacks a required key.
        """
        if table:
            source = table
        elif query:
            # Spark uses ``dbtable`` inside a FROM clause, so a query must be
            # wrapped as an aliased derived table. A trailing ";" would break
            # that, and the alias is written without AS because Oracle rejects
            # "AS" before a table alias.
            source = f"({query.strip().rstrip(';')}) t"
        else:
            raise ValueError("Either table or query must be provided")

        reader = (
            spark.read.format("jdbc")
            .option("url", self.url)
            .option("dbtable", source)
        )

        if partitions:
            reader = (
                reader
                .option("numPartitions", partitions.get("numPartitions", 4))
                .option("partitionColumn", partitions["partitionColumn"])
                .option("lowerBound", partitions["lowerBound"])
                .option("upperBound", partitions["upperBound"])
            )

        return reader.options(**self.props).load()

    def write(self, df_source, table, mode="append"):
        """Write ``df_source`` into ``table`` over JDBC.

        Args:
            df_source: Spark DataFrame to write.
            table: Destination table name.
            mode: Spark save mode passed to ``DataFrameWriter.mode``
                (default "append").
        """
        (
            df_source.write
            .format("jdbc")
            .option("url", self.url)
            .option("dbtable", table)
            .options(**self.props)
            .mode(mode)
            .save()
        )

In [0]:
""" TEST """
# Smoke test: connect to a MySQL test database and show a table.
# The password is read from the environment instead of being stored in the
# notebook.
# NOTE(review): the password previously hard-coded here is exposed in version
# control history and should be rotated.
import os

_test_password = os.environ.get("JDBC_TEST_PASSWORD")
if not _test_password:
    raise RuntimeError("Set the JDBC_TEST_PASSWORD environment variable to run this test")

jdbc = JDBCConnector(
    host=os.environ.get("JDBC_TEST_HOST", "sql12.freesqldatabase.com"),
    port=int(os.environ.get("JDBC_TEST_PORT", "3306")),
    database=os.environ.get("JDBC_TEST_DATABASE", "sql12809587"),
    user=os.environ.get("JDBC_TEST_USER", "sql12809587"),
    password=_test_password,
    provider="mysql",
)

# MySQL has no "public" schema (that is PostgreSQL's default); in MySQL a
# "public." prefix names a separate database. Read from the database in the URL.
df = jdbc.read(table="users")
df.show()