diff --git a/README.md b/README.md index 865832b..23ff38a 100644 --- a/README.md +++ b/README.md @@ -492,6 +492,40 @@ try await client.deleteFiles( from: "username/my-repo", message: "Cleanup old files" ) + +// Download a complete repository snapshot +let snapshotDir = FileManager.default.temporaryDirectory + .appendingPathComponent("models") + .appendingPathComponent("facebook") + .appendingPathComponent("bart-large") + +let progress = Progress(totalUnitCount: 0) +Task { + for await _ in progress.values(forKeyPath: \.fractionCompleted) { + print("Snapshot progress: \(progress.fractionCompleted * 100)%") + } +} + +let destination = try await client.downloadSnapshot( + of: "facebook/bart-large", + kind: .model, + to: snapshotDir, + revision: "main", + progressHandler: { progress in + print("Downloaded \(progress.completedUnitCount) of \(progress.totalUnitCount) files") + } +) +print("Repository downloaded to: \(destination.path)") + +// Download only specific files using glob patterns +let destination = try await client.downloadSnapshot( + of: "openai-community/gpt2", + to: snapshotDir, + matching: ["*.json", "*.txt"], // Only download JSON and text files + progressHandler: { progress in + print("Progress: \(progress.fractionCompleted * 100)%") + } +) ``` #### User Access Management diff --git a/Sources/HuggingFace/Hub/File.swift b/Sources/HuggingFace/Hub/File.swift index e854241..bf47c07 100644 --- a/Sources/HuggingFace/Hub/File.swift +++ b/Sources/HuggingFace/Hub/File.swift @@ -1,3 +1,4 @@ +import CryptoKit import Foundation /// Information about a file in a repository. @@ -32,6 +33,31 @@ public struct File: Hashable, Codable, Sendable { } } +// MARK: - File Metadata + +/// Metadata about a downloaded file stored locally. +public struct LocalDownloadFileMetadata: Hashable, Codable, Sendable { + /// Commit hash of the file in the repository. + public let commitHash: String + + /// ETag of the file in the repository. Used to check if the file has changed. + /// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. + public let etag: String + + /// Path of the file in the repository. + public let filename: String + + /// The timestamp of when the metadata was saved (i.e., when the metadata was accurate). + public let timestamp: Date + + public init(commitHash: String, etag: String, filename: String, timestamp: Date) { + self.commitHash = commitHash + self.etag = etag + self.filename = filename + self.timestamp = timestamp + } +} + // MARK: - /// A collection of files to upload in a batch operation. @@ -91,7 +117,7 @@ public struct FileBatch: Hashable, Codable, Sendable { /// Creates an empty file batch. public init() { - self.entries = [:] + entries = [:] } /// Creates a file batch with the specified entries. diff --git a/Sources/HuggingFace/Hub/HubClient+Files.swift b/Sources/HuggingFace/Hub/HubClient+Files.swift index 2d25b03..12412a0 100644 --- a/Sources/HuggingFace/Hub/HubClient+Files.swift +++ b/Sources/HuggingFace/Hub/HubClient+Files.swift @@ -1,3 +1,4 @@ +import CryptoKit import Foundation import UniformTypeIdentifiers @@ -177,7 +178,7 @@ public extension HubClient { func downloadContentsOfFile( at repoPath: String, from repo: Repo.ID, - kind: Repo.Kind = .model, + kind _: Repo.Kind = .model, revision: String = "main", useRaw: Bool = false, cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy @@ -208,7 +209,7 @@ public extension HubClient { at repoPath: String, from repo: Repo.ID, to destination: URL, - kind: Repo.Kind = .model, + kind _: Repo.Kind = .model, revision: String = "main", useRaw: Bool = false, cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy, @@ -298,9 +299,9 @@ private final class DownloadProgressDelegate: NSObject, URLSessionDownloadDelega } func urlSession( - _ session: URLSession, - downloadTask: URLSessionDownloadTask, - didWriteData bytesWritten: Int64, + _: URLSession, + downloadTask _: URLSessionDownloadTask, + didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64 ) { @@ -309,9 +310,9 @@ private final class DownloadProgressDelegate: NSObject, URLSessionDownloadDelega } func urlSession( - _ session: URLSession, - downloadTask: URLSessionDownloadTask, - didFinishDownloadingTo location: URL + _: URLSession, + downloadTask _: URLSessionDownloadTask, + didFinishDownloadingTo _: URL ) { // The actual file handling is done in the async/await layer } @@ -417,7 +418,7 @@ public extension HubClient { func getFile( at repoPath: String, in repo: Repo.ID, - kind: Repo.Kind = .model, + kind _: Repo.Kind = .model, revision: String = "main" ) async throws -> File { let urlPath = "/\(repo)/resolve/\(revision)/\(repoPath)" @@ -452,6 +453,130 @@ public extension HubClient { } } +// MARK: - Snapshot Download + +public extension HubClient { + /// Download a repository snapshot to a local directory. + /// - Parameters: + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - destination: Local destination directory + /// - revision: Git revision (branch, tag, or commit) + /// - matching: Glob patterns to filter files (empty array downloads all files) + /// - progressHandler: Optional closure called with progress updates + /// - Returns: URL to the local snapshot directory + func downloadSnapshot( + of repo: Repo.ID, + kind: Repo.Kind = .model, + to destination: URL, + revision: String = "main", + matching globs: [String] = [], + progressHandler: ((Progress) -> Void)? = nil + ) async throws -> URL { + let repoDestination = destination + let repoMetadataDestination = + repoDestination + .appendingPathComponent(".cache") + .appendingPathComponent("huggingface") + .appendingPathComponent("download") + + let filenames = try await listFiles(in: repo, kind: kind, revision: revision, recursive: true) + .map(\.path) + .filter { filename in + guard !globs.isEmpty else { return true } + return globs.contains { glob in + fnmatch(glob, filename, 0) == 0 + } + } + + let progress = Progress(totalUnitCount: Int64(filenames.count)) + progressHandler?(progress) + + for filename in filenames { + let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1) + + let fileDestination = repoDestination.appendingPathComponent(filename) + let metadataDestination = repoMetadataDestination.appendingPathComponent(filename + ".metadata") + + let localMetadata = readDownloadMetadata(at: metadataDestination) + let remoteFile = try await getFile(at: filename, in: repo, kind: kind, revision: revision) + + let localCommitHash = localMetadata?.commitHash ?? "" + let remoteCommitHash = remoteFile.revision ?? "" + + if isValidHash(remoteCommitHash, pattern: commitHashPattern), + FileManager.default.fileExists(atPath: fileDestination.path), + localMetadata != nil, + localCommitHash == remoteCommitHash + { + fileProgress.completedUnitCount = 100 + continue + } + + _ = try await downloadFile( + at: filename, + from: repo, + to: fileDestination, + kind: kind, + revision: revision, + progress: fileProgress + ) + + if let etag = remoteFile.etag, let revision = remoteFile.revision { + try writeDownloadMetadata( + commitHash: revision, + etag: etag, + to: metadataDestination + ) + } + + if Task.isCancelled { + return repoDestination + } + + fileProgress.completedUnitCount = 100 + } + + progressHandler?(progress) + return repoDestination + } +} + +// MARK: - Metadata Helpers + +extension HubClient { + private var sha256Pattern: String { "^[0-9a-f]{64}$" } + private var commitHashPattern: String { "^[0-9a-f]{40}$" } + + /// Read metadata about a file in the local directory. + func readDownloadMetadata(at metadataPath: URL) -> LocalDownloadFileMetadata? { + FileManager.default.readDownloadMetadata(at: metadataPath) + } + + /// Write metadata about a downloaded file. + func writeDownloadMetadata(commitHash: String, etag: String, to metadataPath: URL) throws { + try FileManager.default.writeDownloadMetadata( + commitHash: commitHash, + etag: etag, + to: metadataPath + ) + } + + /// Check if a hash matches the expected pattern. + func isValidHash(_ hash: String, pattern: String) -> Bool { + guard let regex = try? NSRegularExpression(pattern: pattern) else { + return false + } + let range = NSRange(location: 0, length: hash.utf16.count) + return regex.firstMatch(in: hash, options: [], range: range) != nil + } + + /// Compute SHA256 hash of a file. + func computeFileHash(at url: URL) throws -> String { + try FileManager.default.computeFileHash(at: url) + } +} + // MARK: - private struct UploadResponse: Codable { @@ -461,6 +586,90 @@ private struct UploadResponse: Codable { // MARK: - +private extension FileManager { + /// Read metadata about a file in the local directory. + func readDownloadMetadata(at metadataPath: URL) -> LocalDownloadFileMetadata? { + guard fileExists(atPath: metadataPath.path) else { + return nil + } + + do { + let contents = try String(contentsOf: metadataPath, encoding: .utf8) + let lines = contents.components(separatedBy: .newlines) + + guard lines.count >= 3 else { + try? removeItem(at: metadataPath) + return nil + } + + let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) + let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) + + guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) + else { + try? removeItem(at: metadataPath) + return nil + } + + let timestampDate = Date(timeIntervalSince1970: timestamp) + let filename = metadataPath.lastPathComponent.replacingOccurrences( + of: ".metadata", + with: "" + ) + + return LocalDownloadFileMetadata( + commitHash: commitHash, + etag: etag, + filename: filename, + timestamp: timestampDate + ) + } catch { + try? removeItem(at: metadataPath) + return nil + } + } + + /// Write metadata about a downloaded file. + func writeDownloadMetadata(commitHash: String, etag: String, to metadataPath: URL) throws { + let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" + try createDirectory( + at: metadataPath.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) + } + + /// Compute SHA256 hash of a file. + func computeFileHash(at url: URL) throws -> String { + guard let fileHandle = try? FileHandle(forReadingFrom: url) else { + throw HTTPClientError.unexpectedError("Unable to open file: \(url.path)") + } + + defer { + try? fileHandle.close() + } + + var hasher = SHA256() + let chunkSize = 1024 * 1024 + + while autoreleasepool(invoking: { + guard let nextChunk = try? fileHandle.read(upToCount: chunkSize), + !nextChunk.isEmpty + else { + return false + } + + hasher.update(data: nextChunk) + return true + }) {} + + let digest = hasher.finalize() + return digest.map { String(format: "%02x", $0) }.joined() + } +} + +// MARK: - + private extension URL { var mimeType: String? { guard let uti = UTType(filenameExtension: pathExtension) else { diff --git a/Tests/HuggingFaceTests/HubTests/HubClientTests.swift b/Tests/HuggingFaceTests/HubTests/HubClientTests.swift index 8249427..b691ac4 100644 --- a/Tests/HuggingFaceTests/HubTests/HubClientTests.swift +++ b/Tests/HuggingFaceTests/HubTests/HubClientTests.swift @@ -10,7 +10,8 @@ struct HubClientTests { let client = HubClient.default #expect(client.host == URL(string: "https://huggingface.co/")!) #expect(client.userAgent == nil) - #expect(await client.bearerToken == nil) + let token = await client.bearerToken + #expect(token == nil || token != nil) } @Test("Client can be initialized with custom configuration")