In [None]:
%%configure -f
{ "jars": ["wasb:///sql/sqljdbc41.jar"] }

In [None]:
/**
 * A named rectangle with its two side lengths.
 *
 * Instances are turned into a DataFrame later in the notebook, so each field becomes a column.
 *
 * @param name   identifier of the rectangle
 * @param width  width of the rectangle
 * @param height height of the rectangle
 */
case class Rectangle(
  name: String,
  width: Double,
  height: Double
)

In [None]:
object DatabaseUtilities {

  /**
   * Builds a Microsoft SQL Server JDBC connection string with TLS encryption enabled.
   *
   * The server name is the first dot-separated label of `sqlServerFQDN`. It is appended to the
   * user name as `user@server`. The certificate host name is the FQDN with only that first
   * label replaced by `*`.
   *
   * @param sqlServerFQDN    fully qualified host name of the SQL server
   * @param sqlDatabaseName  database to connect to
   * @param databaseUsername login name (without the `@server` suffix)
   * @param databasePassword login password, embedded verbatim
   * @return a `jdbc:sqlserver://` URL on port 1433 with a 30-second login timeout
   */
  def getSqlJdbcConnectionString(sqlServerFQDN: String, sqlDatabaseName: String,
                                 databaseUsername: String, databasePassword: String): String = {

    // Leading host label; equals the whole string when it contains no dot.
    val serverName = sqlServerFQDN.takeWhile(_ != '.')

    // Wildcard only the leading label. String.replace would also rewrite any later label that
    // happens to equal the server name (e.g. "db.db.example.com" -> "*.*.example.com").
    val certificateHostname = "*" + sqlServerFQDN.drop(serverName.length)
    val serverPort = "1433"

    s"jdbc:sqlserver://$sqlServerFQDN:$serverPort;database=$sqlDatabaseName;" +
      s"user=$databaseUsername@$serverName;password=$databasePassword;" +
      s"encrypt=true;hostNameInCertificate=$certificateHostname;loginTimeout=30;"
  }
}

In [None]:
// Connection settings: replace the placeholders before running the notebook.
val sqlServerFQDN = "***Enter SQL Server FQDN here***"
val sqlDatabaseName = "***Enter SQL Database Name here***"
val databaseUsername = "***Enter SQL Database Username here***"
val databasePassword = "***Enter SQL Database Password here***"
val databaseTableName = "RectangleDetails"

val sqlDatabaseConnectionString: String = DatabaseUtilities.getSqlJdbcConnectionString(
  sqlServerFQDN, sqlDatabaseName, databaseUsername, databasePassword)

// Echo the connection string as the cell result for a quick sanity check. The password is
// masked so the credential is not persisted in the notebook's saved output.
sqlDatabaseConnectionString.replaceAll("password=[^;]*", "password=****")


In [None]:
import java.sql.{Statement, Connection, DriverManager}

// Driver-side connection used only to ensure the target table and its clustered index exist
// before any executor writes rows.
val sqlDriverConnection: Connection = DriverManager.getConnection(sqlDatabaseConnectionString)

try {
  // Run both DDL statements in one transaction so they are committed together.
  sqlDriverConnection.setAutoCommit(false)

  val sqlDriverStatement: Statement = sqlDriverConnection.createStatement()
  try {
    // Create the table only if [dbo].[<table>] does not exist. It is created explicitly in [dbo]
    // so it matches both this existence check and the INSERT INTO [dbo].[...] used when saving.
    sqlDriverStatement.addBatch(s"IF NOT EXISTS(SELECT * FROM sys.objects WHERE object_id" +
      s" = OBJECT_ID(N'[dbo].[$databaseTableName]') AND type in (N'U'))" +
      s"\nCREATE TABLE [dbo].[$databaseTableName](Name NVARCHAR(128) NOT NULL, Width FLOAT, Height FLOAT)")

    // Create the clustered index on Name only if it does not exist yet.
    sqlDriverStatement.addBatch(s"IF IndexProperty(Object_Id('[dbo].[$databaseTableName]'), 'IX_RectangleName', 'IndexId') IS NULL" +
      s"\nCREATE CLUSTERED INDEX IX_RectangleName ON [dbo].[$databaseTableName](Name)")

    sqlDriverStatement.executeBatch()
    sqlDriverConnection.commit()
  } catch {
    case scala.util.control.NonFatal(e) =>
      // Undo any partially applied DDL before surfacing the error.
      sqlDriverConnection.rollback()
      throw e
  } finally {
    sqlDriverStatement.close()
  }
} finally {
  // Always release the connection, even when DDL execution fails.
  sqlDriverConnection.close()
}

In [None]:
import org.apache.spark.sql.DataFrame

object DataFrameExtensions {

  /** Implicitly wraps a [[DataFrame]] so `saveToAzureSql` can be called on it directly. */
  implicit def extendedDataFrame(dataFrame: DataFrame): ExtendedDataFrame = new ExtendedDataFrame(dataFrame)

  /** Adds Azure SQL write support to a wrapped [[DataFrame]]. */
  class ExtendedDataFrame(dataFrame: DataFrame) {

    /**
     * Appends all rows of the wrapped DataFrame to `[dbo].[sqlTableName]`.
     *
     * Rows are written with multi-row `INSERT ... VALUES` statements, using one JDBC connection
     * per non-empty Spark partition.
     *
     * Columns are written in DataFrame column order under the DataFrame's column names, so the
     * target table must already have matching columns. Value rendering:
     *  - `StringType` values become Unicode string literals (`N'...'`) with embedded quotes doubled.
     *  - Nulls become SQL `NULL`.
     *  - All other values are written via `toString`.
     *
     * @param sqlDatabaseConnectionString JDBC connection string opened on each executor
     * @param sqlTableName                table name inside the `dbo` schema
     */
    def saveToAzureSql(sqlDatabaseConnectionString: String, sqlTableName: String): Unit = {

      // Bracket-quote an identifier, doubling any ']' so the name cannot escape the quotes.
      val quoteIdentifier: String => String = name => "[" + name.replace("]", "]]") + "]"

      val tableHeader: String = dataFrame.columns.map(quoteIdentifier).mkString(",")
      val insertPrefix: String = s"INSERT INTO [dbo].${quoteIdentifier(sqlTableName)} ($tableHeader) VALUES "

      // One flag per column, in column order: true when values must be written as string literals.
      val isStringColumn: Array[Boolean] = dataFrame.dtypes.map(_._2 == "StringType")

      // Renders one row's values as a comma-separated list of T-SQL literals. This is a function
      // value (not a method) so the Spark closure below captures only serializable locals.
      val formatRecord: Seq[Any] => String = values =>
        values.zip(isStringColumn).map {
          case (null, _)      => "NULL"
          case (value, true)  => "N'" + value.toString.replace("'", "''") + "'"
          case (value, false) => value.toString
        }.mkString(",")

      dataFrame.foreachPartition { (partition: Iterator[org.apache.spark.sql.Row]) =>
        // Skip opening a connection for partitions that have nothing to write.
        if (partition.hasNext) {
          val sqlExecutorConnection: Connection = DriverManager.getConnection(sqlDatabaseConnectionString)
          try {
            val sqlExecutorStatement = sqlExecutorConnection.createStatement()
            try {
              // Batch size of 1000 is used since Azure SQL Database cannot insert more than
              // 1000 rows with a single INSERT ... VALUES statement.
              partition.grouped(1000).foreach { group =>
                val rows = group.map(record => "(" + formatRecord(record.toSeq) + ")").mkString(",")
                sqlExecutorStatement.executeUpdate(insertPrefix + rows)
              }
            } finally {
              sqlExecutorStatement.close()
            }
          } finally {
            sqlExecutorConnection.close()
          }
        }
      }
    }
  }
}

In [None]:
// Sample rectangles that are written to the SQL table further down.
val rectangleList: List[Rectangle] = List(
  Rectangle(name = "RectangleA", width = 10, height = 20),
  Rectangle(name = "RectangleB", width = 30, height = 40),
  Rectangle(name = "RectangleC", width = 50, height = 60)
)

In [None]:
// Build a DataFrame whose columns come from the Rectangle case class fields.
val rectangleDataFrame = hiveContext.createDataFrame[Rectangle](rectangleList)

In [None]:
import DataFrameExtensions._

// Append the sample rectangles to the SQL table via the DataFrameExtensions implicit.
rectangleDataFrame.saveToAzureSql(
  sqlDatabaseConnectionString = sqlDatabaseConnectionString,
  sqlTableName = databaseTableName)