### Configurations

Please enter your configurations in the cell below. Ensure you fill out all of the placeholder values (the `<...>` strings) before running the remaining cells.

In [None]:
%%spark
// Fabric config
// Target Fabric workspace and lakehouse.
var WorkspaceId = "<workspace_id>"
var LakehouseId = "<lakehouse_id>"
// OneLake folder that the import cells read from (databases/, tables/, partitions/, catalogObjectStats/).
var IntermediateFolderPath = s"abfss://${WorkspaceId}@onelake.dfs.fabric.microsoft.com/${LakehouseId}/Files/hms_output/syn/"

// Source storage container/account and Synapse workspace used to build the source warehouse prefixes.
var ContainerName = "<container_name>"
var StorageName = "<storage_name>"
// Must be a quoted string like the other placeholders; a bare `<...>` is not valid Scala.
var SynapseWorkspaceName = "<synapse_workspace_name>"
// Source location prefix -> OneLake location prefix. Catalog object locations that start with a
// source prefix are rewritten to start with the corresponding target prefix instead.
var WarehouseMappings: Map[String, String] = Map(
    s"abfss://${ContainerName}@${StorageName}.dfs.core.windows.net/synapse/workspaces/${SynapseWorkspaceName}/warehouse" ->
        s"abfss://${WorkspaceId}@onelake.dfs.fabric.microsoft.com/${LakehouseId}/Files/warehouse_dir_syn",
    s"dbfs:/mnt/${StorageName}/databricks/warehouse" ->
        s"abfss://${WorkspaceId}@onelake.dfs.fabric.microsoft.com/${LakehouseId}/Files/warehouse_dir_dbx",
    s"abfss://${ContainerName}@${StorageName}.dfs.core.windows.net/apps/spark/warehouse" ->
        s"abfss://${WorkspaceId}@onelake.dfs.fabric.microsoft.com/${LakehouseId}/Files/warehouse_dir_hdi"
)

// Metastore config
// Prepended to every imported database / table name.
var DatabasePrefix = ""
var TablePrefix = ""
// When true, objects that already exist are skipped instead of failing the import.
var IgnoreIfExists = true

In [None]:
%%spark
import java.net.URI
import java.util.Calendar

import scala.collection.mutable.{ListBuffer, Map}
import org.apache.spark.sql._
import org.apache.spark.sql.types.{ObjectType, _}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.json4s._
import org.json4s.JsonAST.JString
import org.json4s.jackson.Serialization
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.http.client.methods.{CloseableHttpResponse, HttpPost}
import org.apache.http.entity.StringEntity
import org.apache.http.impl.client.{CloseableHttpClient, HttpClients}
import scala.io.Source


// Mappings ordered by source prefix, descending. A prefix that extends another sorts ahead of it,
// so the first startsWith match found by ImportMetadata.ConvertLocation is the most specific one.
var locationPrefixMappingList = WarehouseMappings.toList.sortBy(_._1)(Ordering[String].reverse)

// Entity-type keys matched against CatalogStat.entityType during validation.
val DatabaseType = "database"
val TableType = "table"
val PartitionType = "partition"

/**
 * Imports catalog objects (databases, tables, partitions) from JSON-lines text files into the
 * Spark session catalog. Storage locations are rewritten with `ConvertLocation`, and
 * `DatabasePrefix` / `TablePrefix` are prepended to database and table names.
 */
object ImportMetadata {

  val spark = SparkSession.builder().getOrCreate()

  /** json4s serializer that reads/writes a `java.net.URI` as a JSON string. */
  case object URISerializer extends CustomSerializer[URI](format => ( {
    case JString(uri) => new URI(uri)
  }, {
    case uri: URI => JString(uri.toString())
  }))

  /**
   * json4s serializer that reads/writes a Spark `StructType` as its JSON schema string.
   * (The misspelled name is kept so existing references keep compiling.)
   */
  case object SturctTypeSerializer extends CustomSerializer[StructType](format => ( {
    case JString(structType) => DataType.fromJson(structType).asInstanceOf[StructType]
  }, {
    case structType: StructType => JString(structType.json)
  }))

  // Declared type: implicits without one are fragile and are warned about by newer compilers.
  implicit val formats: Formats = DefaultFormats + URISerializer + SturctTypeSerializer

  // Shapes of the JSON lines read below. Field names are the JSON keys and must not be renamed
  // (including the misspelled `tablePartitons`).

  /** All partitions of one table. */
  case class CatalogPartitions(database: String, table: String, tablePartitons: Seq[CatalogTablePartition])

  /** All tables of one database. */
  case class CatalogTables(database: String, tables: Seq[CatalogTable])

  /**
   * Object count for `entityType`. Entries whose `database` and `table` are both `None` are the
   * totals that `ValidateImportResult` compares against.
   */
  case class CatalogStat(entityType: String, count: Int, database: Option[String], table: Option[String])

  /**
   * Rewrites `location` using the first entry of `locationPrefixMappingList` whose source prefix
   * it starts with; returns `location` unchanged when no prefix matches.
   */
  def ConvertLocation(location: String): String =
    locationPrefixMappingList.find { case (sourcePrefix, _) => location.startsWith(sourcePrefix) } match {
      // Splice by length instead of String.replaceFirst: replaceFirst interprets the prefix as a
      // regex and `$` / `\` in the target as group references, which breaks on such paths.
      case Some((sourcePrefix, targetPrefix)) => targetPrefix + location.substring(sourcePrefix.length)
      case None                               => location
    }

  /** Returns `database` renamed with `DatabasePrefix` and with its location remapped. */
  def ConvertCatalogDatabase(database: CatalogDatabase): CatalogDatabase =
    database.copy(
      name = DatabasePrefix + database.name,
      locationUri = new URI(ConvertLocation(database.locationUri.toString())))

  /** Returns `format` with its location (when present) remapped; all other fields are kept. */
  def ConvertCatalogStorageFormat(format: CatalogStorageFormat): CatalogStorageFormat =
    format.copy(locationUri = format.locationUri.map(uri => new URI(ConvertLocation(uri.toString()))))

  /**
   * Returns `table` as an EXTERNAL table named `TablePrefix + table` in database
   * `DatabasePrefix + database`, with its storage location remapped.
   *
   * `copy` keeps every other field, including fields added in newer Spark versions that a
   * positional constructor call would silently reset to their defaults.
   *
   * @throws IllegalArgumentException if the table identifier carries no database
   */
  def ConvertCatalogTable(table: CatalogTable): CatalogTable = {
    val sourceDatabase = table.identifier.database.getOrElse(
      throw new IllegalArgumentException("Table " + table.identifier.table + " has no database"))

    table.copy(
      identifier = new TableIdentifier(TablePrefix + table.identifier.table, Some(DatabasePrefix + sourceDatabase)),
      tableType = CatalogTableType.EXTERNAL,
      storage = ConvertCatalogStorageFormat(table.storage))
  }

  /** Returns `partition` with its storage location remapped; all other fields are kept. */
  def ConvertCatalogTablePartition(partition: CatalogTablePartition): CatalogTablePartition =
    partition.copy(storage = ConvertCatalogStorageFormat(partition.storage))

  /** Retries allowed after the first failed attempt (up to MaxRetryCount + 1 attempts in total). */
  val MaxRetryCount = 3

  // Shared retry loop: re-runs `func` after an Exception while retryCount < MaxRetryCount,
  // otherwise lets the exception propagate unchanged.
  private def RunWithRetries[T](func: () => T, retryCount: Int): T =
    try {
      func()
    } catch {
      case e: Exception if retryCount < MaxRetryCount =>
        println("Attempt " + (retryCount + 1) + " failed, retrying. Exception: " + e)
        RunWithRetries(func, retryCount + 1)
    }

  /** Runs the side-effecting `func`, retrying on `Exception` (see `MaxRetryCount`). */
  def RetriableFunc(func: () => Unit, retryCount: Int = 0): Unit = RunWithRetries(func, retryCount)

  /** Runs `func` and returns its result, retrying on `Exception` (see `MaxRetryCount`). */
  def RetriableQueryFunc(func: () => Object, retryCount: Int = 0): Object = RunWithRetries(func, retryCount)

// Create DBs

  /**
   * Creates a lakehouse for every database line in `dataPath`. Existing ones are skipped when
   * `IgnoreIfExists` is true. Returns the number of databases processed, including skipped ones.
   *
   * @throws Exception if a database already exists and `IgnoreIfExists` is false
   */
  def CreateDatabases(dataPath: String): Int = {

    println("Start to create databases " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)

    var createdCount = 0
    // Existing names are captured once, before any database is created.
    val existsDbs = spark.sharedState.externalCatalog.listDatabases().toSet
    val data = ds.collect()
    val total = data.length

    data.foreach { row =>
      val newDb = ConvertCatalogDatabase(Serialization.read[CatalogDatabase](row.getString(0)))

      val exists = existsDbs.contains(newDb.name)
      if (exists && !IgnoreIfExists) {

        println(createdCount + "/" + total + " databases created. " + Calendar.getInstance().getTime())
        println("Database " + newDb.name + " already exists")

        throw new Exception("Database " + newDb.name + " already exists")
      } else if (!exists) {
        CreateDatabase(newDb.name)
      }

      createdCount += 1

      if (createdCount % 100 == 0) {
        println(createdCount + "/" + total + " databases created" + Calendar.getInstance().getTime())
      }
    }

    println("Databases Created completed. Total " + createdCount + " database created. " + Calendar.getInstance().getTime())
    createdCount
  }

  /** Creates a lakehouse named `dbName` in workspace `WorkspaceId` through `mssparkutils`. */
  def CreateDatabase(dbName: String) = {
    mssparkutils.lakehouse.create(dbName, "imported db", WorkspaceId)
  }

  // Create Tables

  /**
   * Creates every table listed in `dataPath` (one database per line) as an external table with
   * remapped locations. Returns the number of tables processed, including skipped existing ones.
   *
   * @throws Exception if a table already exists and `IgnoreIfExists` is false
   */
  def CreateTables(dataPath: String): Int = {
    println("Start to create tables " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)

    // Tables of one database are created from a parallel collection, so the counter must be
    // thread-safe: `+=` on a captured var from parallel tasks loses updates.
    val createdCount = new java.util.concurrent.atomic.AtomicInteger(0)
    ds.collect().foreach { row =>
      val tables = Serialization.read[CatalogTables](row.getString(0))

      // Set for O(1) lookups inside the per-table loop.
      val existsTables = spark.sharedState.externalCatalog.listTables(DatabasePrefix + tables.database).toSet

      tables.tables.toParArray.foreach { table =>
        val newTable = ConvertCatalogTable(table)
        val exists = existsTables.contains(newTable.identifier.table)
        if (exists && !IgnoreIfExists) {
          val qualifiedName = newTable.identifier.database.getOrElse("") + "." + newTable.identifier.table

          println(createdCount.get + " tables created. " + Calendar.getInstance().getTime())
          println("Table " + qualifiedName + " already exists")

          throw new Exception("Table " + qualifiedName + " already exists")
        } else if (!exists) {
          CreateTable(newTable)
        }

        createdCount.incrementAndGet()
      }

      println(createdCount.get + " tables created" + Calendar.getInstance().getTime())
    }

    println("Tables Created completed. Total " + createdCount.get + " table created. " + Calendar.getInstance().getTime())
    createdCount.get
  }

  /** Creates `table` in the external catalog, with retries. */
  def CreateTable(table: CatalogTable): Unit = {
    RetriableFunc(() => {
      spark.sharedState.externalCatalog.createTable(table, IgnoreIfExists)
    })
  }

  /**
   * Checks, for every table listed in `dataPath`, that a Hadoop FileSystem can be obtained for
   * its converted location. It does not check that the path exists.
   *
   * @throws Exception naming the table and location when the check fails
   */
  def ValidateTablePath(dataPath: String): Unit = {
    println("Start to validate table path " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)
    val hadoopConf = spark.sparkContext.hadoopConfiguration

    ds.collect().foreach { row =>
      val tables = Serialization.read[CatalogTables](row.getString(0))

      tables.tables.toParArray.foreach { table =>
        val newTable = ConvertCatalogTable(table)
        try {
          new Path(newTable.location).getFileSystem(hadoopConf)
        } catch {
          case e: Exception =>
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new Exception("Validate table path failed. Table: " + newTable.identifier.database.getOrElse("") + "." + newTable.identifier.table + ", Location: " +  newTable.location + " , exception: " + e, e)
        }
      }
    }

    println("Validate table path completed")
  }

  // Create Partitions

  /**
   * Registers the partitions listed in `dataPath` (one table per line) with remapped locations,
   * with retries. Returns the total number of partitions submitted.
   */
  def CreatePartitions(dataPath: String): Int = {
    println("Start to create partitions " + Calendar.getInstance().getTime())

    val ds = spark.read.format("text").load(dataPath)

    var createdCount = 0
    ds.collect().foreach { row =>
      val parts = Serialization.read[CatalogPartitions](row.getString(0))
      val catalogTablePartitions = parts.tablePartitons.map(ConvertCatalogTablePartition)

      RetriableFunc(() => {
        spark.sharedState.externalCatalog.createPartitions(
          DatabasePrefix + parts.database, TablePrefix + parts.table, catalogTablePartitions, IgnoreIfExists)
      })

      createdCount += catalogTablePartitions.size
      println(createdCount +  " partitions created" + Calendar.getInstance().getTime())
    }

    println("Partition Created completed. Total " + createdCount + " partition created. " + Calendar.getInstance().getTime())
    createdCount
  }

  /** Reads one `CatalogStat` per line of `dataPath`. */
  def LoadStats(dataPath: String): List[CatalogStat] = {
    println("Start to load stats " + Calendar.getInstance().getTime())

    spark.read.format("text").load(dataPath).collect()
      .map(row => Serialization.read[CatalogStat](row.getString(0)))
      .toList
  }

  /**
   * Compares `createdCount` with the total recorded for `entityType` in `stats` (the entry with
   * no database and no table). Prints the outcome and returns whether the counts match.
   */
  def ValidateImportResult(entityType: String, createdCount: Int, stats: List[CatalogStat]): Boolean =
    stats.find(stat => stat.entityType == entityType && stat.database.isEmpty && stat.table.isEmpty) match {
      case None =>
        println("Validated failed. Failed to get orignal " + entityType + " count")
        false
      case Some(mappingStat) if mappingStat.count != createdCount =>
        println("Validated failed. Catalog object count missmatch. Expected " + entityType + " count is " + mappingStat.count + ", but created " + entityType + " count is " + createdCount)
        false
      case Some(_) =>
        println("Validated passed. Catalog objects are created as expected. " + createdCount + " " + entityType + " are created." )
        true
    }

  /**
   * Imports databases, then tables, then partitions from the `databases`, `tables` and
   * `partitions` sub-folders of `inputPath`. Returns the number of partitions created.
   */
  def ImportCatalogObjectsFromFile(inputPath: String): Int = {
    // Accept the folder with or without a trailing slash.
    val basePath = if (inputPath.isEmpty || inputPath.endsWith("/")) inputPath else inputPath + "/"

    CreateDatabases(basePath + "databases")
    CreateTables(basePath + "tables")
    CreatePartitions(basePath + "partitions")
  }
}

// Reference object counts that ImportMetadata.ValidateImportResult checks each import against.
var stats = ImportMetadata.LoadStats(s"${IntermediateFolderPath}/catalogObjectStats")

In [None]:
%%spark
// Validate table path: fails fast if any converted table location cannot be resolved to a FileSystem
ImportMetadata.ValidateTablePath(dataPath = s"${IntermediateFolderPath}/tables")

In [None]:
%%spark
// Import databases (one lakehouse per exported database), then compare the count with the stats
var createdDb = ImportMetadata.CreateDatabases(dataPath = s"${IntermediateFolderPath}/databases")
ImportMetadata.ValidateImportResult(entityType = DatabaseType, createdCount = createdDb, stats = stats)

In [None]:
%%spark
// Validate Lakehouse (database) creation
// Lists the database names now visible to the Spark external catalog.
spark.sharedState.externalCatalog.listDatabases()

In [None]:
%%spark
// Import Tables as external tables with remapped locations, then compare the count with the stats
var createdTbl = ImportMetadata.CreateTables(dataPath = s"${IntermediateFolderPath}/tables")
ImportMetadata.ValidateImportResult(entityType = TableType, createdCount = createdTbl, stats = stats)

In [None]:
%%spark
// Import Partitions with remapped locations, then compare the count with the stats
var createdPart = ImportMetadata.CreatePartitions(dataPath = s"${IntermediateFolderPath}/partitions")
ImportMetadata.ValidateImportResult(entityType = PartitionType, createdCount = createdPart, stats = stats)

In [None]:
%%pyspark
# Spot-check the import: list the tables the Spark catalog reports (current database when no name is given).
spark.catalog.listTables()