diff --git a/README.md b/README.md index d2c4a57..8520c3f 100644 --- a/README.md +++ b/README.md @@ -378,6 +378,122 @@ _ = try await client.updateDiscussionStatus( ) ``` +#### File Operations + +```swift +// List files in a repository +let files = try await client.listFiles( + in: "facebook/bart-large", + kind: .model, + revision: "main", + recursive: true +) + +for file in files { + if file.type == .file { + print("\(file.path) - \(file.size ?? 0) bytes") + } +} + +// Check if a file exists +let exists = await client.fileExists( + at: "README.md", + in: "facebook/bart-large" +) + +// Get file information +let file = try await client.getFile( + at: "pytorch_model.bin", + in: "facebook/bart-large" +) +print("File size: \(file.size ?? 0)") +print("Is LFS: \(file.isLFS)") + +// Download file data +let data = try await client.downloadContentsOfFile( + at: "config.json", + from: "openai-community/gpt2" +) +let config = try JSONDecoder().decode(ModelConfig.self, from: data) + +// Download file to disk +let destination = FileManager.default.temporaryDirectory + .appendingPathComponent("model.safetensors") + +let fileURL = try await client.downloadFile( + at: "model.safetensors", + from: "openai-community/gpt2", + to: destination +) + +// Download with progress tracking +let progress = Progress(totalUnitCount: 0) +Task { + for await _ in progress.values(forKeyPath: \.fractionCompleted) { + print("Download progress: \(progress.fractionCompleted * 100)%") + } +} + +let fileURL = try await client.downloadFile( + at: "pytorch_model.bin", + from: "facebook/bart-large", + to: destination, + progress: progress +) + +// Resume a download +let resumeData: Data = // ... 
from previous download +let fileURL = try await client.resumeDownloadFile( + resumeData: resumeData, + to: destination, + progress: progress +) + +// Upload a file +let result = try await client.uploadFile( + URL(fileURLWithPath: "/path/to/local/file.csv"), + to: "data/new_dataset.csv", + in: "username/my-dataset", + kind: .dataset, + branch: "main", + message: "Add new dataset" +) +print("Uploaded to: \(result.path)") + +// Upload multiple files in a batch +let results = try await client.uploadFiles( + [ + "README.md": .path("/path/to/readme.md"), + "data.json": .path("/path/to/data.json"), + ], + to: "username/my-repo", + message: "Initial commit", + maxConcurrent: 3 +) + +// Or build a batch programmatically +var batch = FileBatch() +batch["config.json"] = .path("/path/to/config.json") +batch["model.safetensors"] = .url( + URL(fileURLWithPath: "/path/to/model.safetensors"), + mimeType: "application/octet-stream" +) + +// Delete a file +try await client.deleteFile( + at: "old_file.txt", + from: "username/my-repo", + message: "Remove old file" +) + +// Delete multiple files +try await client.deleteFiles( + at: ["file1.txt", "file2.txt", "old_dir/file3.txt"], + from: "username/my-repo", + message: "Cleanup old files" +) +``` + #### User Access Management ```swift diff --git a/Sources/HuggingFace/Hub/File.swift b/Sources/HuggingFace/Hub/File.swift new file mode 100644 index 0000000..e854241 --- /dev/null +++ b/Sources/HuggingFace/Hub/File.swift @@ -0,0 +1,137 @@ +import Foundation + +/// Information about a file in a repository. +public struct File: Hashable, Codable, Sendable { + /// A Boolean value indicating whether the file exists in the repository. + public let exists: Bool + + /// The size of the file in bytes. + public let size: Int64? + + /// The entity tag (ETag) for the file, used for caching and change detection. + public let etag: String? + + /// The Git revision (commit SHA) at which this file information was retrieved. + public let revision: String? 
+ + /// A Boolean value indicating whether the file is stored using Git Large File Storage (LFS). + public let isLFS: Bool + + init( + exists: Bool, + size: Int64? = nil, + etag: String? = nil, + revision: String? = nil, + isLFS: Bool = false + ) { + self.exists = exists + self.size = size + self.etag = etag + self.revision = revision + self.isLFS = isLFS + } +} + +// MARK: - + +/// A collection of files to upload in a batch operation. +/// +/// Use `FileBatch` to prepare multiple files for uploading to a repository in a single operation. +/// You can add files using subscript notation or dictionary literal syntax. +/// +/// ```swift +/// var batch = FileBatch() +/// batch["config.json"] = .path("/path/to/config.json") +/// batch["model.safetensors"] = .url( +/// URL(fileURLWithPath: "/path/to/model.safetensors"), +/// mimeType: "application/octet-stream" +/// ) +/// let _ = try await client.uploadFiles(batch, to: "username/my-repo", message: "Initial commit") +/// ``` +/// - SeeAlso: `HubClient.uploadFiles(_:to:kind:branch:message:maxConcurrent:)` +public struct FileBatch: Hashable, Codable, Sendable { + /// An entry representing a file to upload. + public struct Entry: Hashable, Codable, Sendable { + /// The file URL pointing to the local file to upload. + public var url: URL + + /// The MIME type of the file. + public var mimeType: String? + + private init(url: URL, mimeType: String? = nil) { + self.url = url + self.mimeType = mimeType + } + + /// Creates a file entry from a file system path. + /// - Parameters: + /// - path: The file system path to the local file. + /// - mimeType: The MIME type of the file. If not provided, the MIME type is inferred from the file extension. + /// - Returns: A file entry for the specified path. + public static func path(_ path: String, mimeType: String? = nil) -> Self { + return Self(url: URL(fileURLWithPath: path), mimeType: mimeType) + } + + /// Creates a file entry from a URL. + /// - Parameters: + /// - url: The file URL. 
Must be a file URL (e.g., `file:///path/to/file`), not a remote URL. + /// - mimeType: Optional MIME type for the file. + /// - Returns: A file entry, or `nil` if the URL is not a file URL. + /// - Note: Only file URLs are accepted because this API requires local file access for upload. + /// Remote URLs (http, https, etc.) are not supported and will return `nil`. + public static func url(_ url: URL, mimeType: String? = nil) -> Self? { + guard url.isFileURL else { + return nil + } + return Self(url: url, mimeType: mimeType) + } + } + + private var entries: [String: Entry] + + /// Creates an empty file batch. + public init() { + self.entries = [:] + } + + /// Creates a file batch with the specified entries. + /// - Parameter entries: A dictionary mapping repository paths to file entries. + public init(_ entries: [String: Entry]) { + self.entries = entries + } + + /// Accesses the file entry for the specified repository path. + /// - Parameter path: The path in the repository where the file will be uploaded. + /// - Returns: The file entry for the specified path, or `nil` if no entry exists. + public subscript(path: String) -> Entry? { + get { + return entries[path] + } + set { + entries[path] = newValue + } + } +} + +// MARK: - Collection + +extension FileBatch: Swift.Collection { + public typealias Index = Dictionary.Index + + public var startIndex: Index { entries.startIndex } + public var endIndex: Index { entries.endIndex } + public func index(after i: Index) -> Index { entries.index(after: i) } + public subscript(position: Index) -> (key: String, value: Entry) { entries[position] } + + public func makeIterator() -> Dictionary.Iterator { + return entries.makeIterator() + } +} + +// MARK: - ExpressibleByDictionaryLiteral + +extension FileBatch: ExpressibleByDictionaryLiteral { + public init(dictionaryLiteral elements: (String, Entry)...) 
{ + self.init(Dictionary(uniqueKeysWithValues: elements)) + } +} diff --git a/Sources/HuggingFace/Hub/HubClient+Files.swift b/Sources/HuggingFace/Hub/HubClient+Files.swift new file mode 100644 index 0000000..e67aecb --- /dev/null +++ b/Sources/HuggingFace/Hub/HubClient+Files.swift @@ -0,0 +1,471 @@ +import Foundation +import UniformTypeIdentifiers + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +// MARK: - Upload Operations + +public extension HubClient { + /// Upload a single file to a repository + /// - Parameters: + /// - filePath: Local file path to upload + /// - repoPath: Destination path in repository + /// - repo: Repository identifier + /// - kind: Kind of repository (model, dataset, or space) + /// - branch: Target branch (default: "main") + /// - message: Commit message + /// - Returns: Tuple of (path, commit) where commit may be nil + func uploadFile( + _ filePath: String, + to repoPath: String, + in repo: Repo.ID, + kind: Repo.Kind = .model, + branch: String = "main", + message: String? = nil + ) async throws -> (path: String, commit: String?) { + let fileURL = URL(fileURLWithPath: filePath) + return try await uploadFile(fileURL, to: repoPath, in: repo, kind: kind, branch: branch, message: message) + } + + /// Upload a single file to a repository + /// - Parameters: + /// - fileURL: Local file URL to upload + /// - repoPath: Destination path in repository + /// - repo: Repository identifier + /// - kind: Kind of repository (model, dataset, or space) + /// - branch: Target branch (default: "main") + /// - message: Commit message + /// - Returns: Tuple of (path, commit) where commit may be nil + func uploadFile( + _ fileURL: URL, + to repoPath: String, + in repo: Repo.ID, + kind: Repo.Kind = .model, + branch: String = "main", + message: String? = nil + ) async throws -> (path: String, commit: String?) 
{ + let urlPath = "/api/\(kind.pluralized)/\(repo)/upload/\(branch)" + var request = try await httpClient.createRequest(.post, urlPath) + + let boundary = "----hf-\(UUID().uuidString)" + request.setValue( + "multipart/form-data; boundary=\(boundary)", + forHTTPHeaderField: "Content-Type" + ) + + // Determine file size for streaming decision + let fileSize = try fileURL.resourceValues(forKeys: [.fileSizeKey]).fileSize ?? 0 + let threshold = 10 * 1024 * 1024 // 10MB + let shouldStream = fileSize >= threshold + + let mimeType = fileURL.mimeType + + if shouldStream { + // Large file: stream from disk using URLSession.uploadTask + request.setValue("100-continue", forHTTPHeaderField: "Expect") + let tempFile = try MultipartBuilder(boundary: boundary) + .addText(name: "path", value: repoPath) + .addOptionalText(name: "message", value: message) + .addFileStreamed(name: "file", fileURL: fileURL, mimeType: mimeType) + .buildToTempFile() + defer { try? FileManager.default.removeItem(at: tempFile) } + + let (data, response) = try await session.upload(for: request, fromFile: tempFile) + _ = try httpClient.validateResponse(response, data: data) + + if data.isEmpty { + return (path: repoPath, commit: nil) + } + + let result = try JSONDecoder().decode(UploadResponse.self, from: data) + return (path: result.path, commit: result.commit) + } else { + // Small file: build in memory + let body = try MultipartBuilder(boundary: boundary) + .addText(name: "path", value: repoPath) + .addOptionalText(name: "message", value: message) + .addFile(name: "file", fileURL: fileURL, mimeType: mimeType) + .buildInMemory() + + let (data, response) = try await session.upload(for: request, from: body) + _ = try httpClient.validateResponse(response, data: data) + + if data.isEmpty { + return (path: repoPath, commit: nil) + } + + let result = try JSONDecoder().decode(UploadResponse.self, from: data) + return (path: result.path, commit: result.commit) + } + } + + /// Upload multiple files to a repository 
+ /// - Parameters: + /// - batch: Batch of files to upload (path: URL dictionary) + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - branch: Target branch + /// - message: Commit message + /// - maxConcurrent: Maximum concurrent uploads; values below 1 are treated as 1 + /// - Returns: Array of (path, commit) tuples + func uploadFiles( + _ batch: FileBatch, + to repo: Repo.ID, + kind: Repo.Kind = .model, + branch: String = "main", + message: String, + maxConcurrent: Int = 3 + ) async throws -> [(path: String, commit: String?)] { + let entries = Array(batch) + + return try await withThrowingTaskGroup( + of: (Int, (path: String, commit: String?)).self + ) { group in + var results: [(path: String, commit: String?)?] = Array( + repeating: nil, + count: entries.count + ) + var activeCount = 0 + + for (index, (path, entry)) in entries.enumerated() { + // Limit concurrency; the limit is clamped to 1 because next() returns nil on an empty group, so a limit <= 0 would spin forever + while activeCount >= Swift.max(1, maxConcurrent) { + if let (idx, result) = try await group.next() { + results[idx] = result + activeCount -= 1 + } + } + + group.addTask { + let result = try await self.uploadFile( + entry.url, + to: path, + in: repo, + kind: kind, + branch: branch, + message: message + ) + return (index, result) + } + activeCount += 1 + } + + // Collect remaining results + for try await (index, result) in group { + results[index] = result + } + + return results.compactMap { $0 } + } + } +} + +// MARK: - Download Operations + +public extension HubClient { + /// Download file data using URLSession.dataTask + /// - Parameters: + /// - repoPath: Path to file in repository + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - revision: Git revision (branch, tag, or commit) + /// - useRaw: Use raw endpoint instead of resolve + /// - cachePolicy: Cache policy for the request + /// - Returns: File data + func downloadContentsOfFile( + at repoPath: String, + from repo: Repo.ID, + kind: Repo.Kind = .model, + revision: String = "main", + useRaw: Bool = false, + cachePolicy: 
URLRequest.CachePolicy = .useProtocolCachePolicy + ) async throws -> Data { + let endpoint = useRaw ? "raw" : "resolve" + let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)" + var request = try await httpClient.createRequest(.get, urlPath) + request.cachePolicy = cachePolicy + + let (data, response) = try await session.data(for: request) + _ = try httpClient.validateResponse(response, data: data) + + return data + } + + /// Download file to a destination URL using URLSession.downloadTask + /// - Parameters: + /// - repoPath: Path to file in repository + /// - repo: Repository identifier + /// - destination: Destination URL for downloaded file + /// - kind: Kind of repository + /// - revision: Git revision + /// - useRaw: Use raw endpoint + /// - cachePolicy: Cache policy for the request + /// - progress: Optional Progress object to track download progress + /// - Returns: Final destination URL + func downloadFile( + at repoPath: String, + from repo: Repo.ID, + to destination: URL, + kind: Repo.Kind = .model, + revision: String = "main", + useRaw: Bool = false, + cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy, + progress: Progress? = nil + ) async throws -> URL { + let endpoint = useRaw ? "raw" : "resolve" + let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)" + var request = try await httpClient.createRequest(.get, urlPath) + request.cachePolicy = cachePolicy + + let (tempURL, response) = try await session.download( + for: request, + delegate: progress.map { DownloadProgressDelegate(progress: $0) } + ) + _ = try httpClient.validateResponse(response, data: nil) + + // Move from temporary location to final destination + try? 
FileManager.default.removeItem(at: destination) + try FileManager.default.moveItem(at: tempURL, to: destination) + + return destination + } + + /// Download file with resume capability + /// - Parameters: + /// - resumeData: Resume data from a previous download attempt + /// - destination: Destination URL for downloaded file + /// - progress: Optional Progress object to track download progress + /// - Returns: Final destination URL + func resumeDownloadFile( + resumeData: Data, + to destination: URL, + progress: Progress? = nil + ) async throws -> URL { + let (tempURL, response) = try await session.download( + resumeFrom: resumeData, + delegate: progress.map { DownloadProgressDelegate(progress: $0) } + ) + _ = try httpClient.validateResponse(response, data: nil) + + // Move from temporary location to final destination + try? FileManager.default.removeItem(at: destination) + try FileManager.default.moveItem(at: tempURL, to: destination) + + return destination + } + + /// Download file to a destination URL (convenience method without progress tracking) + /// - Parameters: + /// - repoPath: Path to file in repository + /// - repo: Repository identifier + /// - destination: Destination URL for downloaded file + /// - kind: Kind of repository + /// - revision: Git revision + /// - useRaw: Use raw endpoint + /// - cachePolicy: Cache policy for the request + /// - Returns: Final destination URL + func downloadContentsOfFile( + at repoPath: String, + from repo: Repo.ID, + to destination: URL, + kind: Repo.Kind = .model, + revision: String = "main", + useRaw: Bool = false, + cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy + ) async throws -> URL { + return try await downloadFile( + at: repoPath, + from: repo, + to: destination, + kind: kind, + revision: revision, + useRaw: useRaw, + cachePolicy: cachePolicy, + progress: nil + ) + } +} + +// MARK: - Progress Delegate + +private final class DownloadProgressDelegate: NSObject, URLSessionDownloadDelegate, 
@unchecked Sendable { + private let progress: Progress + + init(progress: Progress) { + self.progress = progress + } + + func urlSession( + _ session: URLSession, + downloadTask: URLSessionDownloadTask, + didWriteData bytesWritten: Int64, + totalBytesWritten: Int64, + totalBytesExpectedToWrite: Int64 + ) { + progress.totalUnitCount = totalBytesExpectedToWrite + progress.completedUnitCount = totalBytesWritten + } + + func urlSession( + _ session: URLSession, + downloadTask: URLSessionDownloadTask, + didFinishDownloadingTo location: URL + ) { + // The actual file handling is done in the async/await layer + } +} + +// MARK: - Delete Operations + +public extension HubClient { + /// Delete a file from a repository + /// - Parameters: + /// - repoPath: Path to file to delete + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - branch: Target branch + /// - message: Commit message + func deleteFile( + at repoPath: String, + from repo: Repo.ID, + kind: Repo.Kind = .model, + branch: String = "main", + message: String + ) async throws { + try await deleteFiles(at: [repoPath], from: repo, kind: kind, branch: branch, message: message) + } + + /// Delete multiple files from a repository + /// - Parameters: + /// - repoPaths: Paths to files to delete + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - branch: Target branch + /// - message: Commit message + func deleteFiles( + at repoPaths: [String], + from repo: Repo.ID, + kind: Repo.Kind = .model, + branch: String = "main", + message: String + ) async throws { + let urlPath = "/api/\(kind.pluralized)/\(repo)/commit/\(branch)" + let operations = repoPaths.map { path in + Value.object(["op": .string("delete"), "path": .string(path)]) + } + let params: [String: Value] = [ + "title": .string(message), + "operations": .array(operations), + ] + + let _: Bool = try await httpClient.fetch(.post, urlPath, params: params) + } +} + +// MARK: - Query Operations + +public extension HubClient { + 
/// Check if a file exists in a repository + /// - Parameters: + /// - repoPath: Path to file + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - revision: Git revision + /// - Returns: True if file exists + func fileExists( + at repoPath: String, + in repo: Repo.ID, + kind: Repo.Kind = .model, + revision: String = "main" + ) async -> Bool { + do { + let info = try await getFile(at: repoPath, in: repo, kind: kind, revision: revision) + return info.exists + } catch { + return false + } + } + + /// List files in a repository + /// - Parameters: + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - revision: Git revision + /// - recursive: List files recursively + /// - Returns: Array of tree entries + func listFiles( + in repo: Repo.ID, + kind: Repo.Kind = .model, + revision: String = "main", + recursive: Bool = true + ) async throws -> [Git.TreeEntry] { + let urlPath = "/api/\(kind.pluralized)/\(repo)/tree/\(revision)" + let params: [String: Value]? = recursive ? ["recursive": .bool(true)] : nil + + return try await httpClient.fetch(.get, urlPath, params: params) + } + + /// Get file information + /// - Parameters: + /// - repoPath: Path to file + /// - repo: Repository identifier + /// - kind: Kind of repository + /// - revision: Git revision + /// - Returns: File information + func getFile( + at repoPath: String, + in repo: Repo.ID, + kind: Repo.Kind = .model, + revision: String = "main" + ) async throws -> File { + let urlPath = "/\(repo)/resolve/\(revision)/\(repoPath)" + var request = try await httpClient.createRequest(.head, urlPath) + request.setValue("bytes=0-0", forHTTPHeaderField: "Range") + + do { + let (_, response) = try await session.data(for: request) + guard let httpResponse = response as? 
HTTPURLResponse else { + return File(exists: false) + } + + let exists = httpResponse.statusCode == 200 || httpResponse.statusCode == 206 + let size = httpResponse.value(forHTTPHeaderField: "Content-Length") + .flatMap { Int64($0) } + let etag = httpResponse.value(forHTTPHeaderField: "ETag") + let revision = httpResponse.value(forHTTPHeaderField: "X-Repo-Commit") + let isLFS = + httpResponse.value(forHTTPHeaderField: "X-Linked-Size") != nil + || httpResponse.value(forHTTPHeaderField: "Link")?.contains("lfs") == true + + return File( + exists: exists, + size: size, + etag: etag, + revision: revision, + isLFS: isLFS + ) + } catch { + return File(exists: false) + } + } +} + +// MARK: - + +private struct UploadResponse: Codable { + let path: String + let commit: String? +} + +// MARK: - + +private extension URL { + var mimeType: String? { + guard let uti = UTType(filenameExtension: pathExtension) else { + return nil + } + return uti.preferredMIMEType + } +} diff --git a/Sources/HuggingFace/Shared/HTTPClient.swift b/Sources/HuggingFace/Shared/HTTPClient.swift index edecba5..66801f2 100644 --- a/Sources/HuggingFace/Shared/HTTPClient.swift +++ b/Sources/HuggingFace/Shared/HTTPClient.swift @@ -12,6 +12,7 @@ enum HTTPMethod: String, Hashable, Sendable { case put = "PUT" case delete = "DELETE" case patch = "PATCH" + case head = "HEAD" } /// Base HTTP client with common functionality for all Hugging Face API clients. @@ -182,9 +183,7 @@ final class HTTPClient: @unchecked Sendable { return data } - // MARK: - Private Methods - - private func createRequest( + func createRequest( _ method: HTTPMethod, _ path: String, params: [String: Value]? = nil, @@ -195,7 +194,7 @@ final class HTTPClient: @unchecked Sendable { var httpBody: Data? 
= nil switch method { - case .get: + case .get, .head: if let params { var queryItems: [URLQueryItem] = [] for (key, value) in params { @@ -250,7 +249,7 @@ final class HTTPClient: @unchecked Sendable { return request } - private func validateResponse(_ response: URLResponse, data: Data? = nil) throws -> HTTPURLResponse { + func validateResponse(_ response: URLResponse, data: Data? = nil) throws -> HTTPURLResponse { guard let httpResponse = response as? HTTPURLResponse else { throw HTTPClientError.unexpectedError("Invalid response from server: \(response)") } diff --git a/Sources/HuggingFace/Shared/MultipartBuilder.swift b/Sources/HuggingFace/Shared/MultipartBuilder.swift new file mode 100644 index 0000000..2524d62 --- /dev/null +++ b/Sources/HuggingFace/Shared/MultipartBuilder.swift @@ -0,0 +1,199 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Builder for constructing multipart/form-data payloads +struct MultipartBuilder { + private let boundary: String + private var parts: [Part] = [] + + enum Part { + case text(name: String, value: String) + case file(name: String, fileURL: URL, mimeType: String?) + } + + init(boundary: String) { + self.boundary = boundary + } + + /// Add a text field to the multipart payload + func addText(name: String, value: String) -> MultipartBuilder { + var builder = self + builder.parts.append(.text(name: name, value: value)) + return builder + } + + /// Add an optional text field (only adds if value is non-nil) + func addOptionalText(name: String, value: String?) -> MultipartBuilder { + if let value = value { + return addText(name: name, value: value) + } + return self + } + + /// Add a file field to the multipart payload (loads entire file into memory) + func addFile(name: String, fileURL: URL, mimeType: String?) 
-> MultipartBuilder { + var builder = self + builder.parts.append(.file(name: name, fileURL: fileURL, mimeType: mimeType)) + return builder + } + + /// Add a file field to the multipart payload (for streamed output) + func addFileStreamed(name: String, fileURL: URL, mimeType: String?) -> MultipartBuilder { + // For now, same as addFile - streaming happens in buildToTempFile + var builder = self + builder.parts.append(.file(name: name, fileURL: fileURL, mimeType: mimeType)) + return builder + } + + /// Build the multipart payload in memory + /// - Note: Only suitable for small payloads. Use `buildToTempFile()` for large files. + /// - Returns: Complete multipart body data + func buildInMemory() throws -> Data { + var body = Data() + + for part in parts { + switch part { + case .text(let name, let value): + body.append(textPart(name: name, value: value)) + case .file(let name, let fileURL, let mimeType): + body.append(try filePart(name: name, fileURL: fileURL, mimeType: mimeType)) + } + } + + body.append(closingBoundary()) + return body + } + + /// Build the multipart payload to a temporary file + /// - Note: Streams file contents to avoid memory pressure on large files + /// - Returns: URL of temporary file containing multipart body + func buildToTempFile() throws -> URL { + let tempDir = FileManager.default.temporaryDirectory + let tempFile = tempDir.appendingPathComponent(UUID().uuidString) + + // Create temp file + FileManager.default.createFile(atPath: tempFile.path, contents: nil) + + guard let handle = FileHandle(forWritingAtPath: tempFile.path) else { + throw NSError( + domain: "MultipartBuilder", + code: 1, + userInfo: [NSLocalizedDescriptionKey: "Failed to create temp file"] + ) + } + + defer { try? 
handle.close() } + + // Write parts + for part in parts { + switch part { + case .text(let name, let value): + let data = textPart(name: name, value: value) + try handle.write(contentsOf: data) + + case .file(let name, let fileURL, let mimeType): + // Write file part header + let header = filePartHeader( + name: name, + fileName: fileURL.lastPathComponent, + mimeType: mimeType + ) + try handle.write(contentsOf: header) + + // Stream file contents in chunks + try streamFile(from: fileURL, to: handle) + + // Write trailing newline + try handle.write(contentsOf: Data("\r\n".utf8)) + } + } + + // Write closing boundary + try handle.write(contentsOf: closingBoundary()) + + return tempFile + } + + // MARK: - Private Helpers + + private func textPart(name: String, value: String) -> Data { + var data = Data() + data.append(Data("--\(boundary)\r\n".utf8)) + data.append(Data("Content-Disposition: form-data; name=\"\(name)\"\r\n".utf8)) + data.append(Data("\r\n".utf8)) + data.append(Data("\(value)\r\n".utf8)) + return data + } + + private func filePart(name: String, fileURL: URL, mimeType: String?) throws -> Data { + var data = Data() + + // Header + data.append(filePartHeader(name: name, fileName: fileURL.lastPathComponent, mimeType: mimeType)) + + // File content + let fileData = try Data(contentsOf: fileURL) + data.append(fileData) + + // Trailing newline + data.append(Data("\r\n".utf8)) + + return data + } + + private func filePartHeader(name: String, fileName: String, mimeType: String?) 
-> Data { + var header = Data() + header.append(Data("--\(boundary)\r\n".utf8)) + header.append( + Data( + "Content-Disposition: form-data; name=\"\(name)\"; filename=\"\(fileName)\"\r\n" + .utf8 + ) + ) + + if let mimeType = mimeType { + header.append(Data("Content-Type: \(mimeType)\r\n".utf8)) + } + + header.append(Data("\r\n".utf8)) + return header + } + + private func closingBoundary() -> Data { + return Data("--\(boundary)--\r\n".utf8) + } + + private func streamFile(from url: URL, to handle: FileHandle) throws { + guard let input = InputStream(url: url) else { + throw NSError( + domain: "MultipartBuilder", + code: 2, + userInfo: [NSLocalizedDescriptionKey: "Failed to open file for reading"] + ) + } + + input.open() + defer { input.close() } + + let bufferSize = 64 * 1024 // 64KB chunks + var buffer = [UInt8](repeating: 0, count: bufferSize) + + while input.hasBytesAvailable { + let bytesRead = input.read(&buffer, maxLength: bufferSize) + if bytesRead > 0 { + let data = Data(bytes: buffer, count: bytesRead) + try handle.write(contentsOf: data) + } else if bytesRead < 0 { + throw input.streamError + ?? NSError( + domain: "MultipartBuilder", + code: 3, + userInfo: [NSLocalizedDescriptionKey: "Error reading file"] + ) + } + } + } +} diff --git a/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift b/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift new file mode 100644 index 0000000..d5d4150 --- /dev/null +++ b/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift @@ -0,0 +1,406 @@ +import Foundation +import Testing + +@testable import HuggingFace + +#if swift(>=6.1) + @Suite("File Operations Tests", .serialized) + struct FileOperationsTests { + func createMockClient(bearerToken: String? 
= "test_token") -> HubClient {
+            let configuration = URLSessionConfiguration.ephemeral
+            configuration.protocolClasses = [MockURLProtocol.self]
+            let session = URLSession(configuration: configuration)
+            return HubClient(
+                session: session,
+                host: URL(string: "https://huggingface.co")!,
+                userAgent: "TestClient/1.0",
+                bearerToken: bearerToken
+            )
+        }
+
+        // MARK: - List Files Tests
+
+        /// Lists a repository tree recursively and checks that the three mocked
+        /// entries (two files, one directory) come back in order.
+        ///
+        /// The mock handler asserts the request is a `GET` to the model tree
+        /// endpoint for revision `main` with `recursive=true` in the query.
+        @Test("List files in repository", .mockURLSession)
+        func testListFiles() async throws {
+            let mockResponse = """
+                [
+                    {
+                        "path": "README.md",
+                        "type": "file",
+                        "oid": "abc123",
+                        "size": 1234
+                    },
+                    {
+                        "path": "config.json",
+                        "type": "file",
+                        "oid": "def456",
+                        "size": 567
+                    },
+                    {
+                        "path": "model",
+                        "type": "directory"
+                    }
+                ]
+                """
+
+            await MockURLProtocol.setHandler { request in
+                #expect(request.url?.path == "/api/models/facebook/bart-large/tree/main")
+                #expect(request.url?.query?.contains("recursive=true") == true)
+                #expect(request.httpMethod == "GET")
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: ["Content-Type": "application/json"]
+                )!
+
+                return (response, Data(mockResponse.utf8))
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "facebook/bart-large"
+            // Request recursion explicitly: the handler above asserts
+            // `recursive=true`, so the test must not depend on the parameter's default.
+            let files = try await client.listFiles(
+                in: repoID,
+                kind: .model,
+                revision: "main",
+                recursive: true
+            )
+
+            #expect(files.count == 3)
+            #expect(files[0].path == "README.md")
+            #expect(files[0].type == .file)
+            #expect(files[1].path == "config.json")
+            #expect(files[2].path == "model")
+            #expect(files[2].type == .directory)
+        }
+
+        /// Lists files with `recursive: false` and checks that the request
+        /// carries no `recursive` query item and that the single mocked entry is returned.
+        @Test("List files without recursive", .mockURLSession)
+        func testListFilesNonRecursive() async throws {
+            let mockResponse = """
+                [
+                    {
+                        "path": "README.md",
+                        "type": "file",
+                        "oid": "abc123",
+                        "size": 1234
+                    }
+                ]
+                """
+
+            await MockURLProtocol.setHandler { request in
+                // Verify recursive is NOT in query
+                #expect(
+                    request.url?.query?.contains("recursive") == false
+                        || request.url?.query == nil
+                )
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: ["Content-Type": "application/json"]
+                )!
+
+                return (response, Data(mockResponse.utf8))
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/repo"
+            let files = try await client.listFiles(in: repoID, kind: .model, recursive: false)
+
+            #expect(files.count == 1)
+        }
+
+        // MARK: - File Info Tests
+
+        /// Fetches file info for an existing file.
+        ///
+        /// The handler asserts a `HEAD` request to the resolve endpoint with a
+        /// `Range: bytes=0-0` header and answers 206 with `Content-Length`, `ETag`
+        /// and `X-Repo-Commit` headers. The test expects those values to surface
+        /// as `size`, `etag` and `revision`, with `isLFS` false.
+        @Test("Get file info - file exists", .mockURLSession)
+        func testFileInfoExists() async throws {
+            await MockURLProtocol.setHandler { request in
+                #expect(request.url?.path == "/facebook/bart-large/resolve/main/README.md")
+                #expect(request.httpMethod == "HEAD")
+                #expect(request.value(forHTTPHeaderField: "Range") == "bytes=0-0")
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 206,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [
+                        "Content-Length": "12345",
+                        "ETag": "\"abc123def\"",
+                        "X-Repo-Commit": "commit-sha-123",
+                    ]
+                )!
+
+                return (response, Data())
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "facebook/bart-large"
+            let info = try await client.getFile(
+                at: "README.md",
+                in: repoID,
+                kind: .model,
+                revision: "main"
+            )
+
+            #expect(info.exists == true)
+            #expect(info.size == 12345)
+            #expect(info.etag == "\"abc123def\"")
+            #expect(info.revision == "commit-sha-123")
+            #expect(info.isLFS == false)
+        }
+
+        /// Fetches file info when the response carries an `X-Linked-Size` header
+        /// and expects the file to be reported as existing and LFS-backed.
+        @Test("Get file info - LFS file", .mockURLSession)
+        func testFileInfoLFS() async throws {
+            await MockURLProtocol.setHandler { request in
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [
+                        "Content-Length": "100000000",
+                        "X-Linked-Size": "100000000",
+                    ]
+                )!
+
+                return (response, Data())
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            let info = try await client.getFile(at: "pytorch_model.bin", in: repoID)
+
+            #expect(info.exists == true)
+            #expect(info.isLFS == true)
+        }
+
+        /// Fetches file info when the server answers 404 and expects a
+        /// non-throwing result whose `exists` flag is false.
+        @Test("Get file info - file does not exist", .mockURLSession)
+        func testFileInfoNotExists() async throws {
+            await MockURLProtocol.setHandler { request in
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 404,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data())
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            let info = try await client.getFile(at: "nonexistent.txt", in: repoID)
+
+            #expect(info.exists == false)
+        }
+
+        // MARK: - File Exists Tests
+
+        /// Expects `fileExists` to return true when the server answers 200.
+        @Test("Check if file exists - true", .mockURLSession)
+        func testFileExists() async {
+            await MockURLProtocol.setHandler { request in
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data())
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            let exists = await client.fileExists(at: "README.md", in: repoID)
+
+            #expect(exists == true)
+        }
+
+        /// Expects `fileExists` to return false when the server answers 404.
+        @Test("Check if file exists - false", .mockURLSession)
+        func testFileNotExists() async {
+            await MockURLProtocol.setHandler { request in
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 404,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data())
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            let exists = await client.fileExists(at: "nonexistent.txt", in: repoID)
+
+            #expect(exists == false)
+        }
+
+        // MARK: - Download Tests
+
+        /// Downloads a file's contents and expects the exact bytes served by the
+        /// mock; the handler asserts a `GET` to the resolve endpoint on `main`.
+        @Test("Download file data", .mockURLSession)
+        func testDownloadData() async throws {
+            // Built from the UTF-8 view, which cannot fail, so no force-unwrap is needed.
+            let expectedData = Data("Hello, World!".utf8)
+
+            await MockURLProtocol.setHandler { request in
+                #expect(request.url?.path == "/user/model/resolve/main/test.txt")
+                #expect(request.httpMethod == "GET")
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: ["Content-Type": "text/plain"]
+                )!
+
+                return (response, expectedData)
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            let data = try await client.downloadContentsOfFile(at: "test.txt", from: repoID)
+
+            #expect(data == expectedData)
+        }
+
+        /// Downloads with `useRaw: true` and checks, inside the handler, that the
+        /// request targets the `raw` endpoint rather than `resolve`.
+        @Test("Download with raw endpoint", .mockURLSession)
+        func testDownloadRaw() async throws {
+            await MockURLProtocol.setHandler { request in
+                #expect(request.url?.path == "/user/model/raw/main/test.txt")
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data())
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            _ = try await client.downloadContentsOfFile(at: "test.txt", from: repoID, useRaw: true)
+        }
+
+        // MARK: - Delete Tests
+
+        /// Deletes one file and checks, inside the handler, that the client sends
+        /// a JSON `POST` to the model commit endpoint for `main`.
+        @Test("Delete single file", .mockURLSession)
+        func testDeleteFile() async throws {
+            await MockURLProtocol.setHandler { request in
+                #expect(request.url?.path == "/api/models/user/model/commit/main")
+                #expect(request.httpMethod == "POST")
+                #expect(
+                    request.value(forHTTPHeaderField: "Content-Type") == "application/json"
+                )
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data("true".utf8))
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+            try await client.deleteFile(at: "test.txt", from: repoID, message: "Delete test file")
+        }
+
+        /// Deletes several files from a dataset repository on branch `dev` and
+        /// checks, inside the handler, that the dataset commit endpoint for that
+        /// branch is used.
+        @Test("Delete multiple files", .mockURLSession)
+        func testDeleteBatch() async throws {
+            await MockURLProtocol.setHandler { request in
+                #expect(request.url?.path == "/api/datasets/org/dataset/commit/dev")
+
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 200,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data("true".utf8))
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "org/dataset"
+            try await client.deleteFiles(
+                at: ["file1.txt", "file2.txt", "dir/file3.txt"],
+                from: repoID,
+                kind: .dataset,
+                branch: "dev",
+                message: "Delete old files"
+            )
+        }
+
+        // MARK: - FileBatch Tests
+
+        /// Builds a `FileBatch` from a dictionary literal and expects both
+        /// entries to be present with their local paths.
+        @Test("FileBatch dictionary literal initialization")
+        func testFileBatchDictionaryLiteral() {
+            let batch: FileBatch = [
+                "README.md": .path("/tmp/readme.md"),
+                "config.json": .path("/tmp/config.json"),
+            ]
+
+            // Membership is checked with `contains` rather than by index, so the
+            // test does not depend on iteration order.
+            let items = Array(batch)
+            #expect(items.count == 2)
+            #expect(items.contains { $0.key == "README.md" && $0.value.url.path == "/tmp/readme.md" })
+            #expect(items.contains { $0.key == "config.json" && $0.value.url.path == "/tmp/config.json" })
+        }
+
+        /// Adds and removes entries through the `FileBatch` subscript and tracks
+        /// `count` after each mutation; assigning `nil` removes the entry.
+        @Test("FileBatch add and remove")
+        func testFileBatchMutations() {
+            var batch = FileBatch()
+            #expect(batch.count == 0)
+
+            batch["file1.txt"] = .path("/tmp/file1.txt")
+            #expect(batch.count == 1)
+
+            batch["file2.txt"] = .path("/tmp/file2.txt")
+            #expect(batch.count == 2)
+
+            batch["file1.txt"] = nil
+            #expect(batch.count == 1)
+            #expect(batch["file2.txt"]?.url.path == "/tmp/file2.txt")
+        }
+
+        // MARK: - Error Handling Tests
+
+        /// Makes the mock handler throw `NSURLErrorNotConnectedToInternet` and
+        /// expects the download call to throw.
+        @Test("Handle network error", .mockURLSession)
+        func testNetworkError() async throws {
+            await MockURLProtocol.setHandler { request in
+                throw NSError(
+                    domain: NSURLErrorDomain,
+                    code: NSURLErrorNotConnectedToInternet
+                )
+            }
+
+            let client = createMockClient()
+            let repoID: Repo.ID = "user/model"
+
+            await #expect(throws: Error.self) {
+                _ = try await client.downloadContentsOfFile(at: "test.txt", from: repoID)
+            }
+        }
+
+        /// Uses a client without a bearer token against a 401 response and
+        /// expects the download call to throw an `HTTPClientError`.
+        @Test("Handle unauthorized access", .mockURLSession)
+        func testUnauthorized() async throws {
+            await MockURLProtocol.setHandler { request in
+                let response = HTTPURLResponse(
+                    url: request.url!,
+                    statusCode: 401,
+                    httpVersion: "HTTP/1.1",
+                    headerFields: [:]
+                )!
+
+                return (response, Data("{\"error\": \"Unauthorized\"}".utf8))
+            }
+
+            let client = createMockClient(bearerToken: nil)
+            let repoID: Repo.ID = "user/private-model"
+
+            await #expect(throws: HTTPClientError.self) {
+                _ = try await client.downloadContentsOfFile(at: "test.txt", from: repoID)
+            }
+        }
+    }
+
+#endif // swift(>=6.1)