diff --git a/src/main/scala/fr/brouillard/gitbucket/h2/controller/H2BackupController.scala b/src/main/scala/fr/brouillard/gitbucket/h2/controller/H2BackupController.scala index 97b0716..23a6d99 100644 --- a/src/main/scala/fr/brouillard/gitbucket/h2/controller/H2BackupController.scala +++ b/src/main/scala/fr/brouillard/gitbucket/h2/controller/H2BackupController.scala @@ -3,7 +3,7 @@ package fr.brouillard.gitbucket.h2.controller import java.io.File import java.util.Date import fr.brouillard.gitbucket.h2._ -import fr.brouillard.gitbucket.h2.controller.H2BackupController.{defaultBackupFileName, doBackup} +import fr.brouillard.gitbucket.h2.controller.H2BackupController.{defaultBackupFileName, doBackup, exportConnectedDatabase, logger} import gitbucket.core.controller.ControllerBase import gitbucket.core.model.Account import gitbucket.core.util.AdminAuthenticator @@ -13,7 +13,13 @@ import org.scalatra.{ActionResult, Ok, Params} import org.slf4j.LoggerFactory import org.scalatra.forms._ +import java.sql.Connection +import scala.util.Using + object H2BackupController { + + private val logger = LoggerFactory.getLogger(classOf[H2BackupController]) + def defaultBackupFileName(): String = { val format = new java.text.SimpleDateFormat("yyyy-MM-dd_HH-mm") "gitbucket-db-" + format.format(new Date()) + ".zip" @@ -28,10 +34,23 @@ object H2BackupController { case _ => org.scalatra.Unauthorized() } } + + def exportConnectedDatabase(conn: Connection, exportFile: File): Unit = { + val destFile = if (exportFile.isAbsolute) exportFile else new File(GitBucketHome + "/backup", exportFile.toString) + + logger.info("Exporting database to {}", destFile) + + Using.resource(conn.prepareStatement("BACKUP TO ?")){ statement => + statement.setString(1, destFile.toString) + statement.execute() + } + + logger.info("Exported {} bytes.", exportFile.length()) + } + } class H2BackupController extends ControllerBase with AdminAuthenticator { - private val logger = LoggerFactory.getLogger(classOf[H2BackupController]) case class BackupForm(destFile: String) @@ -39,17 +58,8 @@ class H2BackupController extends ControllerBase with AdminAuthenticator { "dest" -> trim(label("Destination", text(required))) )(BackupForm.apply) - // private val defaultBackupFile:String = new File(GitBucketHome, "gitbucket-database-backup.zip").toString; - def exportDatabase(exportFile: File): Unit = { - val destFile = if (exportFile.isAbsolute) exportFile else new File(GitBucketHome + "/backup", exportFile.toString) - - val session = Database.getSession(request) - val conn = session.conn - - logger.info("exporting database to {}", destFile) - - conn.prepareStatement("BACKUP TO '" + destFile + "'").execute() + exportConnectedDatabase(Database.getSession(request).conn, exportFile) } get("/admin/h2backup")(adminOnly { diff --git a/src/test/scala/fr/brouillard/gitbucket/h2/controller/H2BackupControllerTests.scala b/src/test/scala/fr/brouillard/gitbucket/h2/controller/H2BackupControllerTests.scala index 37df236..7777fd7 100644 --- a/src/test/scala/fr/brouillard/gitbucket/h2/controller/H2BackupControllerTests.scala +++ b/src/test/scala/fr/brouillard/gitbucket/h2/controller/H2BackupControllerTests.scala @@ -2,13 +2,18 @@ package fr.brouillard.gitbucket.h2.controller import gitbucket.core.model.Account import gitbucket.core.servlet.ApiAuthenticationFilter +import org.apache.commons.io.FileSystemUtils +import org.h2.Driver +import org.h2.engine.Database import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal} import org.scalatra.{Ok, Params, ScalatraParams} import org.scalatra.test.scalatest.ScalatraFunSuite import java.io.File -import java.util.Date +import java.nio.file.{Files, Path, Paths} +import java.util.{Date, Properties} +import scala.util.Using class H2BackupControllerTests extends ScalatraFunSuite { addFilter(classOf[ApiAuthenticationFilter], path="/api/*") @@ -59,6 +64,49 @@ class H2BackupControllerObjectTests extends AnyFunSuite { description = None) } + private def h2Url(file: File): String = { + "jdbc:h2:file:" + file + ";DATABASE_TO_UPPER=false" + } + + test("exports connected database with safe file name") { + exportsConnectedDatabase("backup.zip") + } + + test("exports connected database with unsafe file name") { + exportsConnectedDatabase("data.zip' drop database xyx") + } + + private def exportsConnectedDatabase(backupFileName: String): Unit = { + val tempDir = Files.createTempDirectory(classOf[H2BackupControllerObjectTests].getName + "-exports-connected-database") + try { + val requestedDbFile = new File(tempDir.toFile, "data") + // H2 can create several files; in this case, it will only create a data file and no lock files. + val createdDbFile = new File(tempDir.toFile, "data.mv.db") + val backup = new File(tempDir.toFile, backupFileName) + + val driver = new Driver() + val conn = driver.connect(h2Url(requestedDbFile), new Properties()); + try { + assert(createdDbFile.exists()) + + H2BackupController.exportConnectedDatabase(conn, backup) + try { + assert(backup.length() > 0) + } + finally { + assert(backup.delete()) + } + } + finally { + conn.close() + assert(createdDbFile.delete()) + } + } + finally { + assert(tempDir.toFile.delete()) + } + } + test("generates default file name") { assertDefaultFileName(H2BackupController.defaultBackupFileName()) }