Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server fixes #660

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mrtd-reader/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ java {
dependencies {
implementation(libs.net.sf.scuba.scuba.sc.android)
implementation(libs.org.jmrtd.jmrtd)
implementation(libs.com.google.mlkit.text.recognition)
implementation(libs.kotlinx.io.bytestring)

testImplementation(libs.junit)
Expand Down
6 changes: 4 additions & 2 deletions server/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ plugins {
id("war")
id("java-library")
id("org.jetbrains.kotlin.jvm")
alias(libs.plugins.ksp)
}

kotlin {
Expand All @@ -15,8 +16,10 @@ java {
}

dependencies {
ksp(project(":processor"))
implementation(project(":identity"))
implementation(project(":identity-flow"))
implementation(project(":processor-annotations"))
implementation(project(":identity-issuance"))

implementation(libs.javax.servlet.api)
Expand All @@ -33,5 +36,4 @@ dependencies {
testImplementation(libs.junit)
}

gretty {
}
gretty {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.android.identity.wallet.server

import kotlinx.datetime.Instant
import com.android.identity.cbor.annotation.CborSerializable
import kotlinx.io.bytestring.ByteString

/** Data stored in authentication cookie for admin interface. */
@CborSerializable
data class AdminAuthCookie(
val expiration: Instant,
val passwordHash: ByteString
) {
companion object
}
227 changes: 197 additions & 30 deletions server/src/main/java/com/android/identity/wallet/server/FlowServlet.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.android.identity.wallet.server

import com.android.identity.cbor.Cbor
import com.android.identity.crypto.Algorithm
import com.android.identity.crypto.Crypto
import com.android.identity.flow.handler.AesGcmCipher
import com.android.identity.flow.handler.FlowDispatcherLocal
import com.android.identity.flow.handler.FlowExceptionMap
Expand All @@ -10,19 +13,28 @@ import com.android.identity.flow.server.FlowEnvironment
import com.android.identity.flow.server.Resources
import com.android.identity.flow.server.Storage
import com.android.identity.flow.transport.HttpTransport
import com.android.identity.issuance.hardcoded.IssuerDocument
import com.android.identity.issuance.hardcoded.IssuingAuthorityState
import com.android.identity.issuance.hardcoded.WalletServerState
import com.android.identity.util.Logger
import io.ktor.utils.io.core.toByteArray
import jakarta.servlet.ServletConfig
import jakarta.servlet.http.Cookie
import jakarta.servlet.http.HttpServlet
import kotlinx.coroutines.runBlocking
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import kotlinx.datetime.Clock
import kotlinx.io.bytestring.ByteString
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.util.encoders.Base64
import java.lang.UnsupportedOperationException
import java.net.URLEncoder
import java.nio.charset.Charset
import java.security.Security
import kotlin.random.Random
import kotlin.time.Duration.Companion.days
import kotlin.time.DurationUnit

// To run this servlet for development, use this command:
//
Expand All @@ -35,19 +47,25 @@ class FlowServlet : HttpServlet() {
companion object {
private const val TAG = "FlowServlet"

private const val PASSWORD_SALT = "1CxCucFhcIzcbMnSrIgB"
private val AUTH_VALIDITY_DURATION = 7.days
private val LIST_HEAD = "<!DOCTYPE html><html><head></head><body><ul>"
private val LIST_TAIL = "</ul></body></html>"
private val TABLE_HEAD = "<!DOCTYPE html><html><head><link rel='stylesheet' href='table.css'/></head><body><table>"
private val TABLE_TAIL = "</table></body></html>"

private lateinit var serverEnvironment: FlowEnvironment
private lateinit var httpHandler: HttpHandler
private lateinit var stateCipher: SimpleCipher
private lateinit var adminPasswordHash: ByteString

@Synchronized
private fun initialize(servletConfig: ServletConfig) {
if (this::serverEnvironment.isInitialized) {
return
}

serverEnvironment = ServerEnvironment(
directory = "environment",
servletConfig
)
serverEnvironment = ServerEnvironment(servletConfig)

val dispatcherBuilder = FlowDispatcherLocal.Builder()
WalletServerState.registerAll(dispatcherBuilder)
Expand All @@ -66,7 +84,13 @@ class FlowServlet : HttpServlet() {
newKey
}
}
adminPasswordHash = runBlocking {
// Don't write initial password hash in the storage
storage.get("RootState", "", "passwordHash")
?: saltedHash(servletConfig.getInitParameter("initialAdminPassword"))
}
val cipher = AesGcmCipher(messageEncryptionKey)
stateCipher = cipher
val localPoll = FlowNotificationsLocalPoll(cipher)
(serverEnvironment as ServerEnvironment).notifications = localPoll
val localDispatcher = dispatcherBuilder.build(
Expand All @@ -77,6 +101,15 @@ class FlowServlet : HttpServlet() {

httpHandler = HttpHandler(localDispatcher, localPoll)
}

@Synchronized
private fun updateAdminPasswordHash(hash: ByteString) {
adminPasswordHash = hash
}

private fun saltedHash(password: String): ByteString {
return ByteString(Crypto.digest(Algorithm.SHA256, "$PASSWORD_SALT$password".toByteArray()))
}
}

@Override
Expand All @@ -98,20 +131,21 @@ class FlowServlet : HttpServlet() {
}

override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) {
val path = req.servletPath.substring(1)
val threadId = Thread.currentThread().id
val remoteHost = getRemoteHost(req)
val prefix = "tid=$threadId host=$remoteHost"
val requestLength = req.contentLength
Logger.i(TAG, "$prefix: POST ${req.servletPath} ($requestLength bytes)")
val parts = req.servletPath.split("/")
if (parts.size < 3) {
Logger.i(TAG, "$prefix: POST $path ($requestLength bytes)")
val parts = path.split("/")
if (parts.size != 2) {
Logger.i(TAG, "$prefix: malformed request")
throw Exception("Illegal request!")
}
val target = parts[parts.lastIndex-1]
val action = parts.last()
val target = parts[0]
val action = parts[1]
if (target == "admin") {
doAdmin(action, req.parameterMap, resp)
doAdmin(req, resp)
return
}
val requestData = req.inputStream.readNBytes(requestLength)
Expand Down Expand Up @@ -144,7 +178,44 @@ class FlowServlet : HttpServlet() {
}
}

private fun doAdmin(action: String, parameters: Map<String, Array<String>>, resp: HttpServletResponse) {

private fun getAuthCookie(req: HttpServletRequest): Cookie? {
if (req.cookies != null) {
for (cookie in req.cookies) {
if (cookie.name == "Auth") {
return cookie
}
}
}
return null
}

private fun adminAuthCheck(req: HttpServletRequest, resp: HttpServletResponse): Boolean {
val cookie = getAuthCookie(req)
if (cookie != null) {
try {
val parsedCookie =
AdminAuthCookie.fromCbor(stateCipher.decrypt(Base64.decode(cookie.value)))
if (parsedCookie.expiration >= Clock.System.now()
&& parsedCookie.passwordHash == adminPasswordHash) {
return true
}
Logger.e(TAG, "Expired or stale Auth cookie: ${parsedCookie.expiration}")
} catch (err: Exception) {
Logger.e(TAG, "Error parsing Auth cookie", err)
}
}
resp.sendRedirect("${req.contextPath}/login.html")
return false
}

private fun doAdmin(req: HttpServletRequest, resp: HttpServletResponse) {
val action = req.servletPath.split("/").last()
if (action != "login" && !adminAuthCheck(req, resp)) {
resp.sendRedirect("${req.contextPath}/login.html")
return
}
val parameters = req.parameterMap
when (action) {
"updateDocument" -> {
val clientId = parameters["clientId"]!![0]
Expand All @@ -162,37 +233,133 @@ class FlowServlet : HttpServlet() {
resp.contentType = "text/plain"
resp.writer.println("Success")
}
"login" -> {
val password = parameters["password"]!![0]
val cookie = getAuthCookie(req)
if (cookie != null) {
cookie.maxAge = 0 // remove existing cookie if present
}
if (saltedHash(password) == adminPasswordHash) {
val expiration = Clock.System.now() + AUTH_VALIDITY_DURATION
val auth = Base64.toBase64String(stateCipher.encrypt(
AdminAuthCookie(expiration, adminPasswordHash).toCbor()))
val newCookie = Cookie("Auth", auth)
newCookie.path = "${req.contextPath}/"
newCookie.maxAge = AUTH_VALIDITY_DURATION.toInt(DurationUnit.SECONDS)
resp.addCookie(newCookie)
Logger.i(TAG, "Successful login")
resp.sendRedirect("${req.contextPath}/index.html")
} else {
Logger.e(TAG, "Incorrect password")
resp.sendRedirect("${req.contextPath}/login.html")
}
}
"password" -> {
val oldPassword = parameters["oldPassword"]!![0]
if (saltedHash(oldPassword) != adminPasswordHash) {
resp.contentType = "text/plain"
resp.writer.println("Old password is not correct")
return
}
val password = parameters["newPassword"]!![0]
if (password != parameters["newPassword1"]!![0]) {
resp.contentType = "text/plain"
resp.writer.println("Passwords do not match")
return
}
val hash = saltedHash(password)
adminPasswordHash = hash
val storage = serverEnvironment.getInterface(Storage::class)!!
runBlocking {
if (storage.get("RootState", "", "passwordHash") == null) {
storage.insert("RootState", "", hash,"passwordHash")
} else {
storage.update("RootState", "", "passwordHash", hash)
}
}
updateAdminPasswordHash(hash)
resp.sendRedirect("${req.contextPath}/login.html")
}
else -> {
resp.sendError(404)
}
}
}

private fun htmlEscape(text: String?): String {
return (text ?: "<null>")
.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
}

override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) {
val rawPath = req.servletPath.substring(1)
val path = rawPath.ifEmpty { "index.html" }
val threadId = Thread.currentThread().id
val remoteHost = getRemoteHost(req)
val prefix = "tid=$threadId host=$remoteHost"
Logger.i(TAG, "$prefix: GET ${req.servletPath}")
val parts = req.servletPath.split("/")
if (req.servletPath.indexOf("..") >= 0) {
Logger.i(TAG, "$prefix: malformed request")
throw Exception("Illegal request!")
Logger.i(TAG, "$prefix: GET $path")
if (path != "login.html" && !adminAuthCheck(req, resp)) {
return
}
val resources = serverEnvironment.getInterface(Resources::class)!!
val path = parts.last()
val data = resources.getRawResource("www/$path")
if (data == null) {
resp.sendError(404)
} else {
val extension = path.substring(path.lastIndexOf(".") + 1)
resp.contentType = when(extension) {
"html" -> "text/html"
"jpeg", "jpg" -> "image/jpeg"
"png" -> "image/png"
"js" -> "application/javascript"
else -> "application/octet-stream"
when (path) {
"documents.html" -> {
val clientIds = req.parameterMap["clientId"]
val clientId = if (clientIds == null) "" else clientIds[0]!!
resp.contentType = "text/html; charset=utf-8"
val writer = resp.outputStream.writer(Charset.forName("utf-8"))
val storage = serverEnvironment.getInterface(Storage::class)!!
writer.write(TABLE_HEAD)
writer.write("<tr><th>Id</th><th>Display Name</th></tr>")
runBlocking {
val documentIds = storage.enumerate("IssuerDocument", clientId)
for (documentId in documentIds) {
writer.write("<tr>")
writer.write("<td class='code'>${htmlEscape(documentId)}</td>")
val documentData = storage.get("IssuerDocument", clientId, documentId)!!
val document = IssuerDocument.fromDataItem(Cbor.decode(documentData.toByteArray()))
writer.write("<td>${htmlEscape(document.documentConfiguration?.displayName)}</td>")
writer.write("<tr>")
}
}
writer.write(TABLE_TAIL)
writer.flush()
}
"clients.html" -> {
resp.contentType = "text/html; charset=utf-8"
val writer = resp.outputStream.writer(Charset.forName("utf-8"))
val storage = serverEnvironment.getInterface(Storage::class)!!
writer.write(LIST_HEAD)
runBlocking {
val clients = storage.enumerate("ClientKeys", "")
for (client in clients) {
val escaped = htmlEscape(client)
val urlenc = URLEncoder.encode(client, "utf-8")
writer.write("<li><a href='documents.html?clientId=$urlenc'>$escaped</a></li>")
}
}
writer.write(LIST_TAIL)
writer.flush()
}
else -> {
val resources = serverEnvironment.getInterface(Resources::class)!!
val data = resources.getRawResource("www/$path")
if (data == null) {
resp.sendError(404)
} else {
val extension = path.substring(path.lastIndexOf(".") + 1)
resp.contentType = when (extension) {
"html" -> "text/html; charset=utf-8"
"jpeg", "jpg" -> "image/jpeg"
"png" -> "image/png"
"js" -> "application/javascript"
"css" -> "text/css"
else -> "application/octet-stream"
}
resp.outputStream.write(data.toByteArray())
}
}
resp.outputStream.write(data.toByteArray())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ import kotlin.reflect.KClass
import kotlin.reflect.cast

class ServerEnvironment(
private val directory: String,
servletConfig: ServletConfig,
) : FlowEnvironment {
private val configuration = ServerConfiguration(servletConfig)
private val settings = WalletServerSettings(configuration)
private val resources = ServerResources("$directory/resources")
private val resources = ServerResources()
private val storage = ServerStorage(
settings.databaseConnection ?: defaultDatabase(),
settings.databaseUser ?: "",
Expand All @@ -40,7 +39,7 @@ class ServerEnvironment(
}

private fun defaultDatabase(): String {
val dbFile = File("$directory/db/db.hsqldb").absoluteFile
val dbFile = File("environment/db/db.hsqldb").absoluteFile
if (!dbFile.canRead()) {
val parent = File(dbFile.parent)
if (!parent.exists()) {
Expand Down
Loading
Loading