Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Examples/Pose3SLAMG2O/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
/// - Does not use a proper general purpose solver.
/// - Has not been compared against other implementations, so it could be wrong.

import _Differentiation
import Foundation
import SwiftFusion
import TensorFlow
Expand Down
13 changes: 2 additions & 11 deletions Package.resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 8 additions & 5 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ let package = Package(
// .package(url: /* package url */, from: "1.0.0"),
.package(url: "https://github.com/google/swift-benchmark.git", from: "0.1.0"),

.package(url: "https://github.com/saeta/penguin.git", .branch("master")),
.package(url: "https://github.com/saeta/penguin.git", .branch("main")),

.package(url: "https://github.com/ProfFan/tensorboardx-s4tf.git", from: "0.1.3"),
.package(url: "https://github.com/apple/swift-tools-support-core.git", .branch("swift-5.2-branch")),
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "0.3.0"),
.package(url: "https://github.com/tensorflow/swift-models.git", .branch("b2fc0325bf9d476bf2d7a4cd0a09d36486c506e4")),
],
targets: [
// Targets are the basic building blocks of a package. A target can define a module or a test suite.
Expand All @@ -46,8 +45,7 @@ let package = Package(
name: "BeeDataset",
dependencies: [
"SwiftFusion",
.product(name: "Datasets", package: "swift-models"),
.product(name: "ModelSupport", package: "swift-models"),
"ModelSupport",
]),
.target(
name: "Pose3SLAMG2O",
Expand All @@ -58,9 +56,14 @@ let package = Package(
dependencies: [
"SwiftFusion",
"PenguinTesting",
.product(name: "ModelSupport", package: "swift-models"),
"ModelSupport",
]),
.testTarget(
name: "BeeDatasetTests",
dependencies: ["BeeDataset"]),
.target(
name: "ModelSupport",
dependencies: ["STBImage"]),
.target(
name: "STBImage"),
])
1 change: 1 addition & 0 deletions Sources/BeeDataset/BeeFrames.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import _Differentiation
import Foundation
import ModelSupport
import SwiftFusion
Expand Down
1 change: 1 addition & 0 deletions Sources/BeeDataset/BeeOrientedBoundingBoxes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import _Differentiation
import Foundation
import SwiftFusion

Expand Down
2 changes: 1 addition & 1 deletion Sources/BeeDataset/BeeVideo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import _Differentiation
import Foundation
import ModelSupport
import SwiftFusion
Expand Down
3 changes: 2 additions & 1 deletion Sources/BeeDataset/Download.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import _Differentiation
import Foundation
import ModelSupport

/// Downloads the bee dataset (if it's not already present), and returns its URL on the local
/// system.
Expand Down
112 changes: 112 additions & 0 deletions Sources/ModelSupport/DatasetUtilities.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation

#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

public enum DatasetUtilities {
public static let currentWorkingDirectoryURL = URL(
fileURLWithPath: FileManager.default.currentDirectoryPath)

public static let defaultDirectory = try! FileManager.default.url(
for: .cachesDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
.appendingPathComponent("swift-models").appendingPathComponent("datasets")

@discardableResult
public static func downloadResource(
filename: String,
fileExtension: String,
remoteRoot: URL,
localStorageDirectory: URL = currentWorkingDirectoryURL,
extract: Bool = true
) -> URL {
print("Loading resource: \(filename)")

let resource = ResourceDefinition(
filename: filename,
fileExtension: fileExtension,
remoteRoot: remoteRoot,
localStorageDirectory: localStorageDirectory)

let localURL = resource.localURL

if !FileManager.default.fileExists(atPath: localURL.path) {
print(
"File does not exist locally at expected path: \(localURL.path) and must be fetched"
)
fetchFromRemoteAndSave(resource, extract: extract)
}

return localURL
}

@discardableResult
public static func fetchResource(
filename: String,
fileExtension: String,
remoteRoot: URL,
localStorageDirectory: URL = currentWorkingDirectoryURL
) -> Data {
let localURL = DatasetUtilities.downloadResource(
filename: filename, fileExtension: fileExtension, remoteRoot: remoteRoot,
localStorageDirectory: localStorageDirectory)

do {
let data = try Data(contentsOf: localURL)
return data
} catch {
fatalError("Failed to contents of resource: \(localURL)")
}
}

struct ResourceDefinition {
let filename: String
let fileExtension: String
let remoteRoot: URL
let localStorageDirectory: URL

var localURL: URL {
localStorageDirectory.appendingPathComponent(filename)
}

var remoteURL: URL {
remoteRoot.appendingPathComponent(filename).appendingPathExtension(fileExtension)
}

var archiveURL: URL {
localURL.appendingPathExtension(fileExtension)
}
}

static func fetchFromRemoteAndSave(_ resource: ResourceDefinition, extract: Bool) {
let remoteLocation = resource.remoteURL
let archiveLocation = resource.localStorageDirectory

do {
print("Fetching URL: \(remoteLocation)...")
try download(from: remoteLocation, to: archiveLocation)
} catch {
fatalError("Failed to fetch and save resource with error: \(error)")
}
print("Archive saved to: \(archiveLocation.path)")

if extract {
extractArchive(
at: resource.archiveURL, to: resource.localStorageDirectory,
fileExtension: resource.fileExtension, deleteArchiveWhenDone: true)
}
}
}
159 changes: 159 additions & 0 deletions Sources/ModelSupport/FileManagement.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation

#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

/// Creates a directory at a path, if missing. If the directory exists, this does nothing.
///
/// - Parameters:
/// - path: The path of the desired directory.
public func createDirectoryIfMissing(at path: String) throws {
guard !FileManager.default.fileExists(atPath: path) else { return }
try FileManager.default.createDirectory(
atPath: path,
withIntermediateDirectories: true,
attributes: nil)
}

/// Downloads a remote file and places it either within a target directory or at a target file name.
/// If `destination` has been explicitly specified as a directory (setting `isDirectory` to true
/// when appending the last path component), the file retains its original name and is placed within
/// this directory. If `destination` isn't marked in this fashion, the file is saved as a file named
/// after `destination` and its last path component. If the encompassing directory is missing in
/// either case, it is created.
///
/// - Parameters:
/// - source: The remote URL of the file to download.
/// - destination: Either the local directory to place the file in, or the local filename.
public func download(from source: URL, to destination: URL) throws {
let destinationFile: String
if destination.hasDirectoryPath {
try createDirectoryIfMissing(at: destination.path)
let fileName = source.lastPathComponent
destinationFile = destination.appendingPathComponent(fileName).path
} else {
try createDirectoryIfMissing(at: destination.deletingLastPathComponent().path)
destinationFile = destination.path
}

let downloadedFile = try Data(contentsOf: source)
try downloadedFile.write(to: URL(fileURLWithPath: destinationFile))
}

/// Collect all file URLs under a folder `url`, potentially recursing through all subfolders.
/// Optionally filters some extension (only jpeg or txt files for instance).
///
/// - Parameters:
/// - url: The folder to explore.
/// - recurse: Will explore all subfolders if set to `true`.
/// - extensions: Only keeps URLs with extensions in that array if it's provided
public func collectURLs(
under directory: URL, recurse: Bool = false, filtering extensions: [String]? = nil
) -> [URL] {
var files: [URL] = []
do {
let dirContents = try FileManager.default.contentsOfDirectory(
at: directory, includingPropertiesForKeys: [.isDirectoryKey],
options: [.skipsHiddenFiles])
for content in dirContents {
if content.hasDirectoryPath && recurse {
files += collectURLs(under: content, recurse: recurse, filtering: extensions)
} else if content.isFileURL
&& (extensions == nil
|| extensions!.contains(content.pathExtension.lowercased()))
{
files.append(content)
}
}
} catch {
fatalError("Could not explore this folder: \(error)")
}
return files
}

/// Extracts a compressed file to a specified directory. This keys off of either the explicit
/// file extension or one determined from the archive to determine which unarchiving method to use.
/// This optionally deletes the original archive when done.
///
/// - Parameters:
/// - archive: The source archive file, assumed to be on the local filesystem.
/// - localStorageDirectory: A directory that the archive will be unpacked into.
/// - fileExtension: An optional explicitly-specified file extension for the archive, determining
/// how it is unpacked.
/// - deleteArchiveWhenDone: Whether or not the original archive is deleted when the extraction
/// process has been completed. This defaults to false.
public func extractArchive(
at archive: URL, to localStorageDirectory: URL, fileExtension: String? = nil,
deleteArchiveWhenDone: Bool = false
) {
let archivePath = archive.path

#if os(macOS)
var binaryLocation = "/usr/bin/"
#else
var binaryLocation = "/bin/"
#endif

let toolName: String
let arguments: [String]
let adjustedPathExtension: String
if archive.path.hasSuffix(".tar.gz") {
adjustedPathExtension = "tar.gz"
} else {
adjustedPathExtension = archive.pathExtension
}
switch fileExtension ?? adjustedPathExtension {
case "gz":
toolName = "gunzip"
arguments = [archivePath]
case "tar":
toolName = "tar"
arguments = ["xf", archivePath, "-C", localStorageDirectory.path]
case "tar.gz", "tgz":
toolName = "tar"
arguments = ["xzf", archivePath, "-C", localStorageDirectory.path]
case "zip":
binaryLocation = "/usr/bin/"
toolName = "unzip"
arguments = ["-qq", archivePath, "-d", localStorageDirectory.path]
default:
print(
"Unable to find archiver for extension \(fileExtension ?? adjustedPathExtension).")
exit(-1)
}
let toolLocation = "\(binaryLocation)\(toolName)"

let task = Process()
task.executableURL = URL(fileURLWithPath: toolLocation)
task.arguments = arguments
do {
try task.run()
task.waitUntilExit()
} catch {
print("Failed to extract \(archivePath) with error: \(error)")
exit(-1)
}

if FileManager.default.fileExists(atPath: archivePath) && deleteArchiveWhenDone {
do {
try FileManager.default.removeItem(atPath: archivePath)
} catch {
print("Could not remove archive, error: \(error)")
exit(-1)
}
}
}
Loading