Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import PackageDescription
let package = Package(
name: "swift-huggingface",
platforms: [
.macOS(.v14),
.macCatalyst(.v14),
.macOS(.v13),
.macCatalyst(.v16),
.iOS(.v16),
.watchOS(.v10),
.tvOS(.v17),
.watchOS(.v9),
.tvOS(.v16),
.visionOS(.v1),
],
products: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
/// A manager for handling Hugging Face OAuth authentication.
///
/// - SeeAlso: [Hugging Face OAuth Documentation](https://huggingface.co/docs/api-inference/authentication)
@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
@Observable
@MainActor
public final class HuggingFaceAuthenticationManager: Sendable {
Expand Down Expand Up @@ -184,6 +185,7 @@

// MARK: -

@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
extension HuggingFaceAuthenticationManager {
/// OAuth scopes supported by HuggingFace
public enum Scope: Hashable, Sendable {
Expand Down Expand Up @@ -247,6 +249,7 @@
}
}

@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
extension HuggingFaceAuthenticationManager.Scope: RawRepresentable {
public init(rawValue: String) {
switch rawValue {
Expand Down Expand Up @@ -299,6 +302,7 @@
}
}

@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
extension HuggingFaceAuthenticationManager.Scope: Codable {
public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer()
Expand All @@ -312,12 +316,14 @@
}
}

@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
extension HuggingFaceAuthenticationManager.Scope: ExpressibleByStringLiteral {
public init(stringLiteral value: String) {
self = Self(rawValue: value)
}
}

@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
extension Set<HuggingFaceAuthenticationManager.Scope> {
public static var basic: Self { [.openid, .profile, .email] }
public static var readAccess: Self { [.openid, .profile, .email, .readRepos] }
Expand All @@ -329,6 +335,7 @@

// MARK: -

@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
extension HuggingFaceAuthenticationManager {
/// A mechanism for storing and retrieving OAuth tokens.
public struct TokenStorage: Sendable {
Expand Down
54 changes: 45 additions & 9 deletions Sources/HuggingFace/Shared/TokenProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ import Foundation
///
/// ## OAuth Authentication
///
/// For OAuth-based authentication, use the `.oauth` case with an authentication manager:
/// For OAuth-based authentication (requires macOS 14+, iOS 17+), use the `.oauth(manager:)` factory method:
///
/// ```swift
/// let authManager = HuggingFaceAuthenticationManager(
/// let authManager = try HuggingFaceAuthenticationManager(
/// clientID: "your-client-id",
/// redirectURL: URL(string: "myapp://oauth")!
/// redirectURL: URL(string: "myapp://oauth")!,
/// scope: .basic,
/// keychainService: "com.example.app",
/// keychainAccount: "huggingface"
/// )
/// let client = HubClient(tokenProvider: .oauth(manager: authManager))
/// ```
Expand Down Expand Up @@ -117,13 +120,13 @@ public indirect enum TokenProvider: Sendable {
/// the same token detection logic as the Hugging Face CLI.
case environment

/// An OAuth token provider that uses HuggingFaceAuthenticationManager.
/// An OAuth token provider that retrieves tokens asynchronously.
///
/// Use this case for OAuth-based authentication flows. The authentication
/// manager handles the complete OAuth flow including token refresh.
/// Use this case for OAuth-based authentication flows. Create instances using
/// the `TokenProvider.oauth(manager:)` factory method when using `HuggingFaceAuthenticationManager`.
///
/// - Parameter manager: The OAuth authentication manager that handles token retrieval and refresh.
case oauth(manager: HuggingFaceAuthenticationManager)
/// - Parameter getToken: A closure that retrieves a valid OAuth token.
case oauth(getToken: @Sendable () async throws -> String)

/// A composite token provider that tries multiple providers in order.
///
Expand Down Expand Up @@ -185,7 +188,7 @@ public indirect enum TokenProvider: Sendable {
case .environment:
return try getTokenFromEnvironment()

case .oauth(let manager):
case .oauth:
fatalError(
"OAuth token provider requires async context. Use getToken() in an async context or switch to a synchronous provider."
)
Comment on lines +191 to 194
Copy link

Copilot AI Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The synchronous getToken() method handles the .oauth case by throwing a fatal error, but there doesn't appear to be an async overload func getToken() async throws -> String? that properly implements the OAuth case by calling the closure. Without the async overload, OAuth functionality will be broken as the closure cannot be invoked from synchronous contexts. You need to add an async version of getToken() that handles .oauth(let getToken): by calling try await getToken().

Copilot uses AI. Check for mistakes.
Expand All @@ -209,6 +212,39 @@ public indirect enum TokenProvider: Sendable {
}
}

// MARK: - OAuth Factory

#if canImport(AuthenticationServices)
import Observation

extension TokenProvider {
/// Creates an OAuth token provider using HuggingFaceAuthenticationManager.
///
/// Use this factory method for OAuth-based authentication flows. The authentication
/// manager handles the complete OAuth flow including token refresh.
///
/// ```swift
/// let authManager = try HuggingFaceAuthenticationManager(
/// clientID: "your-client-id",
/// redirectURL: URL(string: "myapp://oauth")!,
/// scope: .basic,
/// keychainService: "com.example.app",
/// keychainAccount: "huggingface"
/// )
/// let client = HubClient(tokenProvider: .oauth(manager: authManager))
/// ```
///
/// - Parameter manager: The OAuth authentication manager that handles token retrieval and refresh.
/// - Returns: A token provider that retrieves tokens from the authentication manager.
@available(macOS 14.0, macCatalyst 17.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *)
public static func oauth(manager: HuggingFaceAuthenticationManager) -> TokenProvider {
return .oauth(getToken: { @MainActor in
try await manager.getValidToken()
})
}
}
#endif

// MARK: - ExpressibleByStringLiteral & ExpressibleByStringInterpolation

extension TokenProvider: ExpressibleByStringLiteral, ExpressibleByStringInterpolation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,100 +89,100 @@ import Testing
}
}
}
#endif // swift(>=6.1)

@Suite("Hugging Face OAuth Scope Tests", .serialized)
struct HuggingFaceScopeTests {
typealias Scope = HuggingFaceAuthenticationManager.Scope

@Test("OAuth Scope sets work correctly")
func testScopeSets() {
// Test basic scope set
let basicScopes = Set<Scope>.basic
#expect(basicScopes.contains(.openid))
#expect(basicScopes.contains(.profile))
#expect(basicScopes.contains(.email))

// Test read access scope set
let readScopes = Set<Scope>.readAccess
#expect(readScopes.contains(.readRepos))

// Test write access scope set
let writeScopes = Set<Scope>.writeAccess
#expect(writeScopes.contains(.writeRepos))

// Test full access scope set
let fullScopes = Set<Scope>.fullAccess
#expect(fullScopes.contains(.manageRepos))
#expect(fullScopes.contains(.inferenceAPI))

// Test inference only scope set
let inferenceScopes = Set<Scope>.inferenceOnly
#expect(inferenceScopes.contains(.openid))
#expect(inferenceScopes.contains(.inferenceAPI))

// Test discussions scope set
let discussionScopes = Set<Scope>.discussions
#expect(discussionScopes.contains(.writeDiscussions))
}
@Suite("Hugging Face OAuth Scope Tests", .serialized)
struct HuggingFaceScopeTests {
typealias Scope = HuggingFaceAuthenticationManager.Scope

@Test("OAuth Scope sets work correctly")
func testScopeSets() {
// Test basic scope set
let basicScopes = Set<Scope>.basic
#expect(basicScopes.contains(.openid))
#expect(basicScopes.contains(.profile))
#expect(basicScopes.contains(.email))

// Test read access scope set
let readScopes = Set<Scope>.readAccess
#expect(readScopes.contains(.readRepos))

// Test write access scope set
let writeScopes = Set<Scope>.writeAccess
#expect(writeScopes.contains(.writeRepos))

// Test full access scope set
let fullScopes = Set<Scope>.fullAccess
#expect(fullScopes.contains(.manageRepos))
#expect(fullScopes.contains(.inferenceAPI))

// Test inference only scope set
let inferenceScopes = Set<Scope>.inferenceOnly
#expect(inferenceScopes.contains(.openid))
#expect(inferenceScopes.contains(.inferenceAPI))

// Test discussions scope set
let discussionScopes = Set<Scope>.discussions
#expect(discussionScopes.contains(.writeDiscussions))
}

@Test("OAuth Scope raw values are correct")
func testScopeRawValues() {
#expect(Scope.openid.rawValue == "openid")
#expect(Scope.profile.rawValue == "profile")
#expect(Scope.email.rawValue == "email")
#expect(Scope.readBilling.rawValue == "read-billing")
#expect(Scope.readRepos.rawValue == "read-repos")
#expect(Scope.writeRepos.rawValue == "write-repos")
#expect(Scope.manageRepos.rawValue == "manage-repos")
#expect(Scope.inferenceAPI.rawValue == "inference-api")
#expect(Scope.writeDiscussions.rawValue == "write-discussions")

// Test custom scope
let customScope = Scope.other("custom-scope")
#expect(customScope.rawValue == "custom-scope")
}
@Test("OAuth Scope raw values are correct")
func testScopeRawValues() {
#expect(Scope.openid.rawValue == "openid")
#expect(Scope.profile.rawValue == "profile")
#expect(Scope.email.rawValue == "email")
#expect(Scope.readBilling.rawValue == "read-billing")
#expect(Scope.readRepos.rawValue == "read-repos")
#expect(Scope.writeRepos.rawValue == "write-repos")
#expect(Scope.manageRepos.rawValue == "manage-repos")
#expect(Scope.inferenceAPI.rawValue == "inference-api")
#expect(Scope.writeDiscussions.rawValue == "write-discussions")

// Test custom scope
let customScope = Scope.other("custom-scope")
#expect(customScope.rawValue == "custom-scope")
}

@Test("OAuth Scope initialization from raw values")
func testScopeInitializationFromRawValue() {
#expect(Scope(rawValue: "openid") == .openid)
#expect(Scope(rawValue: "profile") == .profile)
#expect(Scope(rawValue: "email") == .email)
#expect(Scope(rawValue: "read-billing") == .readBilling)
#expect(Scope(rawValue: "read-repos") == .readRepos)
#expect(Scope(rawValue: "write-repos") == .writeRepos)
#expect(Scope(rawValue: "manage-repos") == .manageRepos)
#expect(Scope(rawValue: "inference-api") == .inferenceAPI)
#expect(Scope(rawValue: "write-discussions") == .writeDiscussions)

// Test custom scope
let customScope = Scope(rawValue: "custom-scope")
#expect(customScope == .other("custom-scope"))
}
@Test("OAuth Scope initialization from raw values")
func testScopeInitializationFromRawValue() {
#expect(Scope(rawValue: "openid") == .openid)
#expect(Scope(rawValue: "profile") == .profile)
#expect(Scope(rawValue: "email") == .email)
#expect(Scope(rawValue: "read-billing") == .readBilling)
#expect(Scope(rawValue: "read-repos") == .readRepos)
#expect(Scope(rawValue: "write-repos") == .writeRepos)
#expect(Scope(rawValue: "manage-repos") == .manageRepos)
#expect(Scope(rawValue: "inference-api") == .inferenceAPI)
#expect(Scope(rawValue: "write-discussions") == .writeDiscussions)

// Test custom scope
let customScope = Scope(rawValue: "custom-scope")
#expect(customScope == .other("custom-scope"))
}

@Test("OAuth Scope descriptions are correct")
func testScopeDescriptions() {
#expect(Scope.openid.description.contains("ID token"))
#expect(Scope.profile.description.contains("profile information"))
#expect(Scope.email.description.contains("email address"))
#expect(Scope.readBilling.description.contains("payment method"))
#expect(Scope.readRepos.description.contains("read access"))
#expect(Scope.writeRepos.description.contains("write/read access"))
#expect(Scope.manageRepos.description.contains("full access"))
#expect(Scope.inferenceAPI.description.contains("Inference API"))
#expect(Scope.writeDiscussions.description.contains("discussions"))

// Test custom scope description
let customScope = Scope.other("custom-scope")
#expect(customScope.description == "custom-scope")
}
@Test("OAuth Scope descriptions are correct")
func testScopeDescriptions() {
#expect(Scope.openid.description.contains("ID token"))
#expect(Scope.profile.description.contains("profile information"))
#expect(Scope.email.description.contains("email address"))
#expect(Scope.readBilling.description.contains("payment method"))
#expect(Scope.readRepos.description.contains("read access"))
#expect(Scope.writeRepos.description.contains("write/read access"))
#expect(Scope.manageRepos.description.contains("full access"))
#expect(Scope.inferenceAPI.description.contains("Inference API"))
#expect(Scope.writeDiscussions.description.contains("discussions"))

// Test custom scope description
let customScope = Scope.other("custom-scope")
#expect(customScope.description == "custom-scope")
}

@Test("OAuth Scope string literal support")
func testScopeStringLiteral() {
let scope: Scope = "openid"
#expect(scope == .openid)
@Test("OAuth Scope string literal support")
func testScopeStringLiteral() {
let scope: Scope = "openid"
#expect(scope == .openid)

let customScope: Scope = "custom-scope"
#expect(customScope == .other("custom-scope"))
let customScope: Scope = "custom-scope"
#expect(customScope == .other("custom-scope"))
}
}
}
#endif // swift(>=6.1)