# **Functions interacting with Amazon Keyspaces**

## **`createDataFrameFromTable()`**

Create a dataframe from an Amazon Keyspaces / Cassandra table.

**Params**:

- `sparkSession`: `pyspark.sql.SparkSession` variable.

- `keyspaceName`: Name of the keyspace in Cassandra / Amazon Keyspaces.

- `tableName`: Name of the table in the keyspace.

**Returns**:

- PySpark dataframe `pyspark.sql.dataframe.DataFrame` with the contents of the table.

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

In [2]:
def createDataFrameFromTable(sparkSession : SparkSession, 
                             keyspaceName : str, 
                             tableName : str) -> DataFrame:
    
    try:
        return sparkSession.read\
              .format("org.apache.spark.sql.cassandra")\
              .options(table=tableName, keyspace=keyspaceName)\
              .load()
    
    except Exception as err:
        print(f"Exception : {err}")
        return None

## **`createSparkSessionWithCassandraConf()`**

Create a `SparkSession` with configuration required to connect to Amazon Keyspaces / Cassandra.

**Params**:

- `appName`: Name of the `SparkSession`.

**Returns**:

- PySpark Session `pyspark.sql.SparkSession` object.

In [3]:
from pyspark.sql import SparkSession

In [4]:
def createSparkSessionWithCassandraConf(appName : str) -> SparkSession:
    
    try:
        if len(appName)==0:
            print("Error : appName cannot be empty.")
            return None
        
        spark=SparkSession.builder.appName(appName)\
        .config("spark.files", "../application.conf")\
        .config("spark.jars", "../jar-files/spark-cassandra-connector_2.12-3.3.0.jar,"
                                "../jar-files/spark-cassandra-connector-assembly_2.12-3.3.0.jar")\
        .getOrCreate()

        spark.conf.set("spark.cassandra.connection.config.profile.path", "application.conf")
        spark.conf.set("spark.cassandra.connection.ssl.clientAuth.enabled", "true")
        spark.conf.set("spark.cassandra.connection.ssl.enabled", "true")
        
        return spark
    
    except Exception as err:
        print(f"Exception : {err}")
        return None
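
The connector settings used above can also be gathered in one place before the session is built. The following is a minimal sketch, not the notebook's own method: `cassandra_conf()` and `apply_conf()` are hypothetical helper names, and the keys and values are copied from the function above. `apply_conf()` only assumes that the builder exposes a `config(key, value)` method that returns the builder, as `SparkSession.builder` does.

```python
def cassandra_conf(conf_path: str = "application.conf") -> dict:
    """Return the connector settings used above as a plain dict (hypothetical helper)."""
    return {
        "spark.cassandra.connection.config.profile.path": conf_path,
        "spark.cassandra.connection.ssl.clientAuth.enabled": "true",
        "spark.cassandra.connection.ssl.enabled": "true",
    }


def apply_conf(builder, conf: dict):
    """Call builder.config(key, value) for every entry and return the builder."""
    for key, value in conf.items():
        builder = builder.config(key, value)
    return builder
```

With PySpark installed, `apply_conf(SparkSession.builder.appName(appName), cassandra_conf())` could then be followed by `.getOrCreate()`. Whether setting these options at build time behaves the same as calling `spark.conf.set()` afterwards is an assumption to verify against the connector version in use.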

## **`getSizeOfDataFrame()`**

Amazon Keyspaces does not support `COUNT` on its tables. This function instead counts the rows of a DataFrame created from an Amazon Keyspaces table. The DataFrame must contain an `ID` column with unique integer values; only rows whose `ID` is greater than 0 are counted.

**Params**:

- `dataframe`: `pyspark.sql.DataFrame` variable.

**Returns**:

- Number of rows in the PySpark dataframe, or `-1` if an exception occurs.

In [5]:
from pyspark.sql.dataframe import DataFrame

In [6]:
def getSizeOfDataFrame(dataframe : DataFrame) -> int:
    
    try:
        return dataframe[["ID"]].filter(dataframe.ID > 0).count()
    
    except Exception as err:
        print(f"Exception : {err}")
        return -1

## **`saveDataFrameToTable()`**

Saves a dataframe to an Amazon Keyspaces / Cassandra table.

**Params**:

- `dataframe`: `pyspark.sql.DataFrame` variable. Should contain `ID` column with unique integer values.

- `keyspaceName`: Name of the keyspace in Cassandra / Amazon Keyspaces.

- `tableName`: Name of the table in the keyspace.

- `mode`: Mode of saving the dataframe to the table. Default value is `APPEND`.

- `batch_size`: Number of rows to save in each batch. To save the whole dataframe at once, pass the total number of rows. Default value is `1024`.

- `verbose`: Display console output. Default value is `True`.


In [7]:
from pyspark.sql.dataframe import DataFrame

In [8]:
def saveDataFrameToTable(dataframe : DataFrame, 
                         keyspaceName : str,
                         tableName : str,
                         mode : str="APPEND", 
                         batch_size : int=1024, 
                         verbose : bool=True) -> None:
    
    # Invalid batch_size.
    if batch_size <= 0:
        print("Error: batch_size cannot be less than 1")
        return
    
    try:
        # Get the total count of articles.
        NO_OF_ARTICLES=getSizeOfDataFrame(dataframe)

        # Save to table in batches.
        for start in range(0, NO_OF_ARTICLES+1, batch_size):
            dataframe.filter((dataframe.ID >= start) & (dataframe.ID < start+batch_size))\
                    .write.format("org.apache.spark.sql.cassandra")\
                    .options(table=tableName, keyspace=keyspaceName)\
                    .mode(mode)\
                    .save()

            if verbose:
                print(f"Batch [{start}, {start+batch_size-1}] saved to {keyspaceName}.{tableName}.")
            
    except Exception as err:
        print(f"Exception : {err}")
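
The batching loop above walks the `ID` range in fixed-size windows. The pure-Python sketch below reproduces only that window arithmetic, without Spark, so it can be checked by hand. `batch_windows()` is a hypothetical helper, and it assumes, as the function does, that `ID` values run contiguously from 1 up to the row count.

```python
from typing import Iterator, Tuple


def batch_windows(row_count: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) pairs; each batch holds rows with start <= ID < end."""
    if batch_size <= 0:
        raise ValueError("batch_size cannot be less than 1")
    # Same bounds as the loop above: range(0, row_count + 1, batch_size).
    for start in range(0, row_count + 1, batch_size):
        yield start, start + batch_size


# Five rows with IDs 1..5, saved two IDs at a time.
windows = list(batch_windows(5, 2))
print(windows)  # [(0, 2), (2, 4), (4, 6)]
```

If the IDs have gaps or run past the row count, any row whose `ID` lies beyond the last window would never be written, which is why the contiguity assumption matters.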