Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,22 @@ let package = Package(
.library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-jinja.git", from: "2.0.0")
.package(url: "https://github.com/huggingface/swift-jinja.git", from: "2.0.0"),
.package(path: "../swift-huggingface"),
],
targets: [
.target(name: "Generation", dependencies: ["Tokenizers"]),
.target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
.target(
name: "Hub",
dependencies: [
.product(name: "Jinja", package: "swift-jinja"),
.product(name: "HuggingFace", package: "swift-huggingface"),
],
resources: [
.process("Resources")
],
swiftSettings: swiftSettings
),
.target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
.testTarget(name: "GenerationTests", dependencies: ["Generation"]),
Expand Down
495 changes: 0 additions & 495 deletions Sources/Hub/Downloader.swift

This file was deleted.

40 changes: 40 additions & 0 deletions Sources/Hub/Hub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import Foundation
import HuggingFace

/// A namespace struct providing access to Hugging Face Hub functionality.
///
Expand Down Expand Up @@ -108,6 +109,45 @@ public extension Hub {
}
}

// MARK: - Type Conversions for HuggingFace Integration

extension Hub.Repo {
/// Converts this `Hub.Repo` to a `Repo.ID` for use with `HubClient`.
///
/// Model names without a namespace (e.g., "t5-base") are treated as having
/// an implicit "hf" namespace, making them "hf/t5-base".
var repoID: HuggingFace.Repo.ID {
if let repoID = HuggingFace.Repo.ID(rawValue: id) {
return repoID
}
// Handle models without namespace (e.g., "t5-base" -> "hf/t5-base")
// These are legacy model IDs that don't follow the namespace/name format
return HuggingFace.Repo.ID(namespace: "hf", name: id)
}
}

extension Hub.RepoType {
/// Converts this `Hub.RepoType` to a `Repo.Kind` for use with `HubClient`.
var repoKind: HuggingFace.Repo.Kind {
switch self {
case .models: return .model
case .datasets: return .dataset
case .spaces: return .space
}
}
}

extension HuggingFace.Repo.Kind {
/// Converts this `Repo.Kind` to a `Hub.RepoType`.
var hubRepoType: Hub.RepoType {
switch self {
case .model: return .models
case .dataset: return .datasets
case .space: return .spaces
}
}
}

/// Manages language model configuration loading from the Hugging Face Hub.
///
/// This class handles the asynchronous loading and processing of model configurations,
Expand Down
234 changes: 109 additions & 125 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import CryptoKit
import Foundation
import HuggingFace
import Network
import os

Expand Down Expand Up @@ -79,6 +80,12 @@ public struct HubApi: Sendable {
public typealias RepoType = Hub.RepoType
public typealias Repo = Hub.Repo

/// The underlying `HubClient` instance from the `HuggingFace` module.
///
/// Use this property to access the full `HubClient` API for advanced operations
/// not exposed through `HubApi`.
public let client: HubClient

/// Session actor for metadata requests with relative redirect handling (used in HEAD requests).
///
/// Static to share a single URLSession across all HubApi instances, preventing resource
Expand Down Expand Up @@ -115,6 +122,15 @@ public struct HubApi: Sendable {
self.endpoint = endpoint ?? Self.hfEndpointfromEnv()
self.useBackgroundSession = useBackgroundSession
self.useOfflineMode = useOfflineMode

// Create the underlying HubClient with matching configuration
let host = URL(string: self.endpoint) ?? HubClient.defaultHost
if let token = self.hfToken {
self.client = HubClient(host: host, bearerToken: token)
} else {
self.client = HubClient(host: host, tokenProvider: .environment)
}

NetworkMonitor.shared.startMonitoring()
}

Expand Down Expand Up @@ -454,130 +470,106 @@ public extension HubApi {
}
}

struct HubFileDownloader {
let hub: HubApi
let repo: Repo
let revision: String
let repoDestination: URL
let repoMetadataDestination: URL
let relativeFilename: String
let hfToken: String?
let endpoint: String?
let backgroundSession: Bool

var source: URL {
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
var url = URL(string: endpoint ?? "https://huggingface.co")!
if repo.type != .models {
url = url.appending(path: repo.type.rawValue)
}
url = url.appending(path: repo.id)
url = url.appending(path: "resolve")
url = url.appending(component: revision) // Encode slashes (e.g., "pr/1" -> "pr%2F1")
url = url.appending(path: relativeFilename)
return url
}

var destination: URL {
repoDestination.appending(path: relativeFilename)
}

var metadataDestination: URL {
repoMetadataDestination.appending(path: relativeFilename + ".metadata")
/// Builds the source URL for downloading a file from the Hub.
private func sourceURL(for repo: Repo, revision: String, filename: String) -> URL {
var url = URL(string: endpoint)!
if repo.type != .models {
url = url.appending(path: repo.type.rawValue)
}
url = url.appending(path: repo.id)
url = url.appending(path: "resolve")
url = url.appending(component: revision) // Encode slashes (e.g., "pr/1" -> "pr%2F1")
url = url.appending(path: filename)
return url
}

var downloaded: Bool {
FileManager.default.fileExists(atPath: destination.path)
/// Downloads a single file using HubClient with metadata tracking for offline mode support.
private func downloadFile(
filename: String,
repo: Repo,
revision: String,
repoDestination: URL,
repoMetadataDestination: URL,
fileProgress: Progress,
progressHandler: @escaping (Progress) -> Void,
parentProgress: Progress
) async throws {
let destination = repoDestination.appending(path: filename)
let metadataDestination = repoMetadataDestination.appending(path: filename + ".metadata")
let source = sourceURL(for: repo, revision: revision, filename: filename)
let downloaded = FileManager.default.fileExists(atPath: destination.path)

let localMetadata = try readDownloadMetadata(metadataPath: metadataDestination)
let remoteMetadata = try await getFileMetadata(url: source)

let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => skip download
if isValidHash(hash: remoteCommitHash, pattern: commitHashPattern),
downloaded,
localMetadata != nil,
localCommitHash == remoteCommitHash
{
return
}

/// We're using incomplete destination to prepare cache destination because incomplete files include lfs + non-lfs files (vs only lfs for metadata files)
func prepareCacheDestination(_ incompleteDestination: URL) throws {
let directoryURL = incompleteDestination.deletingLastPathComponent()
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
if !FileManager.default.fileExists(atPath: incompleteDestination.path) {
try "".write(to: incompleteDestination, atomically: true, encoding: .utf8)
}
// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
let remoteEtag = remoteMetadata.etag,
remoteMetadata.location != ""
else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}

/// Downloads the file with progress tracking.
/// - Parameter progressHandler: Called with download progress (0.0-1.0) and speed in bytes/sec, if available.
/// - Returns: Local file URL (uses cached file if commit hash matches).
/// - Throws: ``EnvironmentError`` errors for file and metadata validation failures, ``Downloader.DownloadError`` errors during transfer, or ``CancellationError`` if the task is cancelled.
@discardableResult
func download(progressHandler: @escaping (Double, Double?) -> Void) async throws -> URL {
let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination)
let remoteMetadata = try await hub.getFileMetadata(url: source)

let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file
if hub.isValidHash(hash: remoteCommitHash, pattern: hub.commitHashPattern), downloaded, localMetadata != nil,
localCommitHash == remoteCommitHash
{
return destination
// Local file exists => check if it's up-to-date
if downloaded {
// etag matches => update metadata and skip download
if localMetadata?.etag == remoteEtag {
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return
}

// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
let remoteEtag = remoteMetadata.etag,
let remoteSize = remoteMetadata.size,
remoteMetadata.location != ""
else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}

// Local file exists => check if it's up-to-date
if downloaded {
// etag matches => update metadata and return file
if localMetadata?.etag == remoteEtag {
try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}

// etag is a sha256
// => means it's an LFS file (large)
// => let's compute local hash and compare
// => if match, update metadata and return file
if hub.isValidHash(hash: remoteEtag, pattern: hub.sha256Pattern) {
let fileHash = try hub.computeFileHash(file: destination)
if fileHash == remoteEtag {
try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return destination
}
}
}

// Otherwise, let's download the file!
let incompleteDestination = repoMetadataDestination.appending(path: relativeFilename + ".\(remoteEtag).incomplete")
try prepareCacheDestination(incompleteDestination)

let downloader = Downloader(to: destination, incompleteDestination: incompleteDestination, inBackground: backgroundSession)

try await withTaskCancellationHandler {
let sub = await downloader.download(from: source, using: hfToken, expectedSize: remoteSize)
listen: for await state in sub {
switch state {
case .notStarted:
continue
case let .downloading(progress, speed):
progressHandler(progress, speed)
case let .failed(error):
throw error
case .completed:
break listen
}
}
} onCancel: {
Task {
await downloader.cancel()
// etag is a sha256 => means it's an LFS file (large)
// => compute local hash and compare
// => if match, update metadata and skip download
if isValidHash(hash: remoteEtag, pattern: sha256Pattern) {
let fileHash = try computeFileHash(file: destination)
if fileHash == remoteEtag {
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
return
}
}
}

try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
// Create a separate progress for the download (not linked to parent)
// to avoid issues with HubClient modifying totalUnitCount
let downloadProgress = Progress()

// If the file exists locally but metadata check failed, force a re-download
// to skip HubClient's cache (which may have the old/wrong version)
let forceDownload = downloaded

// Download the file using HubClient
_ = try await client.downloadFile(
at: filename,
from: repo.repoID,
to: destination,
kind: repo.type.repoKind,
revision: revision,
inBackground: useBackgroundSession,
forceDownload: forceDownload,
progress: downloadProgress
)

return destination
// Update parent progress with throughput info from the download
if let throughput = downloadProgress.userInfo[.throughputKey] as? Double {
parentProgress.setUserInfoObject(throughput, forKey: .throughputKey)
}
progressHandler(parentProgress)

// Write metadata for offline mode support
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)
}

@discardableResult
Expand Down Expand Up @@ -634,26 +626,18 @@ public extension HubApi {
let progress = Progress(totalUnitCount: Int64(filenames.count))
for filename in filenames {
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
let downloader = HubFileDownloader(
hub: self,

try await downloadFile(
filename: filename,
repo: repo,
revision: revision,
repoDestination: repoDestination,
repoMetadataDestination: repoMetadataDestination,
relativeFilename: filename,
hfToken: hfToken,
endpoint: endpoint,
backgroundSession: useBackgroundSession
fileProgress: fileProgress,
progressHandler: progressHandler,
parentProgress: progress
)

try await downloader.download { fractionDownloaded, speed in
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
if let speed {
fileProgress.setUserInfoObject(speed, forKey: .throughputKey)
progress.setUserInfoObject(speed, forKey: .throughputKey)
}
progressHandler(progress)
}
if Task.isCancelled {
return repoDestination
}
Expand Down
Loading
Loading