Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ios): don't call urlSchemeTask methods if it was stopped #7453

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 117 additions & 54 deletions ios/Capacitor/Capacitor/WebViewAssetHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import MobileCoreServices
open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {
private var router: Router
private var serverUrl: URL?
private var pendingTasks = ConcurrentTasks()

public init(router: Router) {
self.router = router
Expand All @@ -25,6 +26,7 @@ open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {
}

open func webView(_ webView: WKWebView, start urlSchemeTask: WKURLSchemeTask) {
pendingTasks.insert(urlSchemeTask.hash)
let startPath: String
let url = urlSchemeTask.request.url!
let stringToLoad = url.path
Expand All @@ -47,67 +49,82 @@ open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {
}

let fileUrl = URL.init(fileURLWithPath: startPath)
if self.pendingTasks.contains(urlSchemeTask.hash) {
do {
var data = Data()
let mimeType = mimeTypeForExtension(pathExtension: url.pathExtension)
var headers = [
"Content-Type": mimeType,
"Cache-Control": "no-cache"
]

do {
var data = Data()
let mimeType = mimeTypeForExtension(pathExtension: url.pathExtension)
var headers = [
"Content-Type": mimeType,
"Cache-Control": "no-cache"
]

// if using live reload, then set CORS headers
if isUsingLiveReload(localUrl) {
headers["Access-Control-Allow-Origin"] = self.serverUrl?.absoluteString
headers["Access-Control-Allow-Methods"] = "GET, HEAD, OPTIONS, TRACE"
}

if let rangeString = urlSchemeTask.request.value(forHTTPHeaderField: "Range"),
let totalSize = try fileUrl.resourceValues(forKeys: [.fileSizeKey]).fileSize,
isMediaExtension(pathExtension: url.pathExtension) {
let fileHandle = try FileHandle(forReadingFrom: fileUrl)
let parts = rangeString.components(separatedBy: "=")
let streamParts = parts[1].components(separatedBy: "-")
let fromRange = Int(streamParts[0]) ?? 0
var toRange = totalSize - 1
if streamParts.count > 1 {
toRange = Int(streamParts[1]) ?? toRange
// if using live reload, then set CORS headers
if isUsingLiveReload(localUrl) {
headers["Access-Control-Allow-Origin"] = self.serverUrl?.absoluteString
headers["Access-Control-Allow-Methods"] = "GET, HEAD, OPTIONS, TRACE"
}
let rangeLength = toRange - fromRange + 1
try fileHandle.seek(toOffset: UInt64(fromRange))
data = fileHandle.readData(ofLength: rangeLength)
headers["Accept-Ranges"] = "bytes"
headers["Content-Range"] = "bytes \(fromRange)-\(toRange)/\(totalSize)"
headers["Content-Length"] = String(data.count)
let response = HTTPURLResponse(url: localUrl, statusCode: 206, httpVersion: nil, headerFields: headers)
urlSchemeTask.didReceive(response!)
try fileHandle.close()
} else {
if !stringToLoad.contains("cordova.js") {

if let rangeString = urlSchemeTask.request.value(forHTTPHeaderField: "Range"),
let totalSize = try fileUrl.resourceValues(forKeys: [.fileSizeKey]).fileSize,
isMediaExtension(pathExtension: url.pathExtension) {
let fileHandle = try FileHandle(forReadingFrom: fileUrl)
let parts = rangeString.components(separatedBy: "=")
let streamParts = parts[1].components(separatedBy: "-")
let fromRange = Int(streamParts[0]) ?? 0
var toRange = totalSize - 1
if streamParts.count > 1 {
toRange = Int(streamParts[1]) ?? toRange
}
let rangeLength = toRange - fromRange + 1
try fileHandle.seek(toOffset: UInt64(fromRange))
data = fileHandle.readData(ofLength: rangeLength)
headers["Accept-Ranges"] = "bytes"
headers["Content-Range"] = "bytes \(fromRange)-\(toRange)/\(totalSize)"
headers["Content-Length"] = String(data.count)
let response = HTTPURLResponse(url: localUrl, statusCode: 206, httpVersion: nil, headerFields: headers)
pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didReceive(response!)
}
try fileHandle.close()
} else {
if !stringToLoad.contains("cordova.js") {
if isMediaExtension(pathExtension: url.pathExtension) {
data = try Data(contentsOf: fileUrl, options: Data.ReadingOptions.mappedIfSafe)
} else {
data = try Data(contentsOf: fileUrl)
}
}
let urlResponse = URLResponse(url: localUrl, mimeType: mimeType, expectedContentLength: data.count, textEncodingName: nil)
let httpResponse = HTTPURLResponse(url: localUrl, statusCode: 200, httpVersion: nil, headerFields: headers)
if isMediaExtension(pathExtension: url.pathExtension) {
data = try Data(contentsOf: fileUrl, options: Data.ReadingOptions.mappedIfSafe)
pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didReceive(urlResponse)
}
} else {
data = try Data(contentsOf: fileUrl)
pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didReceive(httpResponse!)
}
}
}
let urlResponse = URLResponse(url: localUrl, mimeType: mimeType, expectedContentLength: data.count, textEncodingName: nil)
let httpResponse = HTTPURLResponse(url: localUrl, statusCode: 200, httpVersion: nil, headerFields: headers)
if isMediaExtension(pathExtension: url.pathExtension) {
urlSchemeTask.didReceive(urlResponse)
} else {
urlSchemeTask.didReceive(httpResponse!)
pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didReceive(data)
}
} catch let error as NSError {
pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didFailWithError(error)
}
self.pendingTasks.remove(urlSchemeTask.hash)
return
}
urlSchemeTask.didReceive(data)
} catch let error as NSError {
urlSchemeTask.didFailWithError(error)
return
pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didFinish()
}
self.pendingTasks.remove(urlSchemeTask.hash)
}
urlSchemeTask.didFinish()
}

open func webView(_ webView: WKWebView, stop urlSchemeTask: WKURLSchemeTask) {
CAPLog.print("scheme stop")
pendingTasks.remove(urlSchemeTask.hash)
}

open func mimeTypeForExtension(pathExtension: String) -> String {
Expand Down Expand Up @@ -137,7 +154,10 @@ open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {

func handleCapacitorHttpRequest(_ urlSchemeTask: WKURLSchemeTask, _ localUrl: URL, _ isHttpsRequest: Bool) {
var urlRequest = urlSchemeTask.request
guard let url = urlRequest.url else { return }
guard let url = urlRequest.url else {
pendingTasks.remove(urlSchemeTask.hash)
return
}
var targetUrl = url.absoluteString
.replacingOccurrences(of: CapacitorBridge.httpInterceptorStartIdentifier, with: "")
.replacingOccurrences(of: CapacitorBridge.httpsInterceptorStartIdentifier, with: "")
Expand All @@ -156,7 +176,10 @@ open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {
let urlSession = URLSession.shared
let task = urlSession.dataTask(with: urlRequest) { (data, response, error) in
if let error = error {
urlSchemeTask.didFailWithError(error)
self.pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didFailWithError(error)
}
self.pendingTasks.remove(urlSchemeTask.hash)
return
}

Expand All @@ -181,16 +204,23 @@ open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {
httpVersion: nil,
headerFields: mergedHeaders
) {
urlSchemeTask.didReceive(modifiedResponse)
self.pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didReceive(modifiedResponse)
}
}
}

if let data = data {
urlSchemeTask.didReceive(data)
self.pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didReceive(data)
}
}
}
}
urlSchemeTask.didFinish()
self.pendingTasks.withTask(urlSchemeTask) {
urlSchemeTask.didFinish()
}
self.pendingTasks.remove(urlSchemeTask.hash)
return
}

Expand Down Expand Up @@ -546,3 +576,36 @@ open class WebViewAssetHandler: NSObject, WKURLSchemeHandler {
"zip": "application/x-zip-compressed"
]
}

private class ConcurrentTasks {
private var tasks: Set<Int>
private let lock = NSLock()

init() {
tasks = []
}

func contains(_ value: Int) -> Bool {
lock.withLock { tasks.contains(value) }
}

@discardableResult
func remove(_ value: Int) -> Int? {
lock.withLock { tasks.remove(value) }
}

@discardableResult
func insert(_ value: Int) -> (inserted: Bool, memberAfterInsert: Int) {
lock.withLock { tasks.insert(value) }
}

func withTask(_ schemeTask: WKURLSchemeTask, action: @escaping () -> Void) {
lock.withLock {
if tasks.contains(schemeTask.hash) {
DispatchQueue.main.async {
action()
}
}
}
}
Comment on lines +602 to +610
Copy link
Member

@Steven0351 Steven0351 May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using lock.withLock, we should probably use the individual locking methods so the code would look like this:

func withTask(_ schemeTask: WKURLSchemeTask, action: @escaping () -> Void) {
  lock.lock()
  if tasks.contains(schemeTask.hash) {
    DispatchQueue.main.async {
      action()
      lock.unlock()
    }
  } else {
    lock.unlock()
  }
}

In it's currently state, the lock is can actually be released before the action is run because DispatchQueue.main.async returns immediately and would still allow for the race condition to occur.

}