Skip to content

Commit aa22467

Browse files
committed
remote: refactor HuggingFaceApiService; implement download feature in HuggingFaceRemoteDataSource
1 parent 5138cb6 commit aa22467

File tree

4 files changed

+187
-145
lines changed

4 files changed

+187
-145
lines changed

examples/llama.android/app/src/main/java/com/example/llama/data/remote/HuggingFaceApiService.kt

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package com.example.llama.data.remote
22

33
import okhttp3.ResponseBody
4+
import retrofit2.Response
45
import retrofit2.http.GET
6+
import retrofit2.http.HEAD
57
import retrofit2.http.Path
68
import retrofit2.http.Query
79
import retrofit2.http.Streaming
@@ -21,10 +23,17 @@ interface HuggingFaceApiService {
2123
@GET("api/models/{modelId}")
2224
suspend fun getModelDetails(@Path("modelId") modelId: String): HuggingFaceModelDetails
2325

24-
@GET("{modelId}/resolve/main/{filePath}")
26+
@HEAD("{modelId}/resolve/main/{filename}")
27+
suspend fun getModelFileHeader(
28+
@Path("modelId", encoded = true) modelId: String,
29+
@Path("filename", encoded = true) filename: String
30+
): Response<Void>
31+
32+
@Deprecated("Use DownloadManager instead!")
33+
@GET("{modelId}/resolve/main/{filename}")
2534
@Streaming
2635
suspend fun downloadModelFile(
2736
@Path("modelId") modelId: String,
28-
@Path("filePath") filePath: String
37+
@Path("filename") filename: String
2938
): ResponseBody
3039
}
Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,50 @@
11
package com.example.llama.data.remote
22

3+
import android.net.Uri
4+
import androidx.core.net.toUri
5+
import com.example.llama.di.HUGGINGFACE_HOST
36
import java.util.Date
47

8+
private const val FILE_EXTENSION_GGUF = ".gguf"
9+
510
data class HuggingFaceModel(
611
val _id: String,
712
val id: String,
813
val modelId: String,
914

1015
val author: String,
11-
val createdAt: Date?,
12-
val lastModified: Date?,
16+
val createdAt: Date,
17+
val lastModified: Date,
1318

14-
val library_name: String?,
15-
val pipeline_tag: String?,
16-
val tags: List<String>?,
19+
val pipeline_tag: String,
20+
val tags: List<String>,
1721

18-
val private: Boolean?,
19-
val gated: Boolean?,
22+
val private: Boolean,
23+
val gated: Boolean,
2024

21-
val likes: Int?,
22-
val downloads: Int?,
25+
val likes: Int,
26+
val downloads: Int,
2327

24-
val sha: String?,
28+
val sha: String,
29+
val siblings: List<Sibling>,
2530

26-
val siblings: List<Sibling>?,
31+
val library_name: String?,
2732
) {
2833
data class Sibling(
2934
val rfilename: String,
3035
)
36+
37+
fun getGgufFilename(): String? =
38+
siblings.map { it.rfilename }.first { it.endsWith(FILE_EXTENSION_GGUF) }
39+
40+
fun toDownloadInfo() = getGgufFilename()?.let { HuggingFaceDownloadInfo(_id, modelId, it) }
41+
}
42+
43+
data class HuggingFaceDownloadInfo(
44+
val _id: String,
45+
val modelId: String,
46+
val filename: String,
47+
) {
48+
val uri: Uri
49+
get() = "$HUGGINGFACE_HOST${modelId}/resolve/main/$filename".toUri()
3150
}
Lines changed: 109 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,47 @@
11
package com.example.llama.data.remote
22

3+
import android.app.DownloadManager
4+
import android.content.Context
5+
import android.os.Environment
36
import android.util.Log
47
import kotlinx.coroutines.Dispatchers
8+
import kotlinx.coroutines.delay
59
import kotlinx.coroutines.withContext
6-
import java.io.File
710
import javax.inject.Inject
811
import javax.inject.Singleton
912

13+
private const val QUERY_Q4_0_GGUF = "gguf q4_0"
14+
private const val FILTER_TEXT_GENERATION = "text-generation"
15+
private const val SORT_BY_DOWNLOADS = "downloads"
16+
private const val SEARCH_RESULT_LIMIT = 20
17+
1018
interface HuggingFaceRemoteDataSource {
19+
/**
20+
* Query openly available Q4_0 GGUF models on HuggingFace
21+
*/
1122
suspend fun searchModels(
12-
query: String? = "gguf q4_0",
13-
filter: String? = "text-generation", // Only generative models,
14-
sort: String? = "downloads",
23+
query: String? = QUERY_Q4_0_GGUF,
24+
filter: String? = FILTER_TEXT_GENERATION,
25+
sort: String? = SORT_BY_DOWNLOADS,
1526
direction: String? = "-1",
16-
limit: Int? = 20,
27+
limit: Int? = SEARCH_RESULT_LIMIT,
1728
full: Boolean = true,
1829
): List<HuggingFaceModel>
1930

2031
suspend fun getModelDetails(modelId: String): HuggingFaceModelDetails
2132

22-
suspend fun downloadModelFile(modelId: String, filePath: String, outputFile: File): Result<File>
33+
/**
34+
* Obtain selected HuggingFace model's GGUF file size from HTTP header
35+
*/
36+
suspend fun getFileSize(modelId: String, filePath: String): Long?
37+
38+
/**
39+
* Download selected HuggingFace model's GGUF file via DownloadManager
40+
*/
41+
suspend fun downloadModelFile(
42+
context: Context,
43+
downloadInfo: HuggingFaceDownloadInfo,
44+
): Result<Unit>
2345
}
2446

2547
@Singleton
@@ -42,7 +64,7 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
4264
direction = direction,
4365
limit = limit,
4466
full = full,
45-
)
67+
).filter { it.gated != true && it.private != true }
4668
}
4769

4870
override suspend fun getModelDetails(
@@ -51,32 +73,98 @@ class HuggingFaceRemoteDataSourceImpl @Inject constructor(
5173
apiService.getModelDetails(modelId)
5274
}
5375

54-
override suspend fun downloadModelFile(
76+
override suspend fun getFileSize(
5577
modelId: String,
56-
filePath: String,
57-
outputFile: File
58-
): Result<File> = withContext(Dispatchers.IO) {
78+
filePath: String
79+
): Long? = withContext(Dispatchers.IO) {
5980
try {
60-
val response = apiService.downloadModelFile(modelId, filePath)
81+
apiService.getModelFileHeader(modelId, filePath).let {
82+
if (it.isSuccessful) {
83+
it.headers()[HTTP_HEADER_CONTENT_LENGTH]?.toLongOrNull()
84+
} else {
85+
null
86+
}
87+
}
88+
} catch (e: Exception) {
89+
Log.e(TAG, "Error getting file size for $modelId: ${e.message}")
90+
null
91+
}
92+
}
6193

62-
// Create parent directories if needed
63-
outputFile.parentFile?.mkdirs()
94+
override suspend fun downloadModelFile(
95+
context: Context,
96+
downloadInfo: HuggingFaceDownloadInfo,
97+
): Result<Unit> = withContext(Dispatchers.IO) {
98+
try {
99+
val downloadManager =
100+
context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
101+
val request = DownloadManager.Request(downloadInfo.uri).apply {
102+
setTitle("HuggingFace model download")
103+
setDescription("Downloading ${downloadInfo.filename}")
104+
setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED)
105+
setDestinationInExternalPublicDir(
106+
Environment.DIRECTORY_DOWNLOADS,
107+
downloadInfo.filename
108+
)
109+
setAllowedNetworkTypes(
110+
DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE
111+
)
112+
setAllowedOverMetered(true)
113+
setAllowedOverRoaming(false)
114+
}
115+
Log.d(TAG, "Enqueuing download request for: ${downloadInfo.modelId}")
116+
val downloadId = downloadManager.enqueue(request)
117+
118+
delay(DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY)
64119

65-
// Save the file
66-
response.byteStream().use { input ->
67-
outputFile.outputStream().use { output ->
68-
input.copyTo(output)
120+
val cursor = downloadManager.query(DownloadManager.Query().setFilterById(downloadId))
121+
if (cursor != null && cursor.moveToFirst()) {
122+
val statusIndex = cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)
123+
if (statusIndex >= 0) {
124+
val status = cursor.getInt(statusIndex)
125+
cursor.close()
126+
127+
when (status) {
128+
DownloadManager.STATUS_FAILED -> {
129+
// Get failure reason if available
130+
val reasonIndex = cursor.getColumnIndex(DownloadManager.COLUMN_REASON)
131+
val reason = if (reasonIndex >= 0) cursor.getInt(reasonIndex) else -1
132+
val errorMessage = when (reason) {
133+
DownloadManager.ERROR_HTTP_DATA_ERROR -> "HTTP error"
134+
DownloadManager.ERROR_INSUFFICIENT_SPACE -> "Insufficient storage"
135+
DownloadManager.ERROR_TOO_MANY_REDIRECTS -> "Too many redirects"
136+
DownloadManager.ERROR_UNHANDLED_HTTP_CODE -> "Unhandled HTTP code"
137+
DownloadManager.ERROR_CANNOT_RESUME -> "Cannot resume download"
138+
DownloadManager.ERROR_FILE_ERROR -> "File error"
139+
else -> "Unknown error"
140+
}
141+
Result.failure(Exception(errorMessage))
142+
}
143+
else -> {
144+
// Download is pending, paused, or running
145+
Result.success(Unit)
146+
}
147+
}
148+
} else {
149+
// Assume success if we can't check status
150+
cursor.close()
151+
Result.success(Unit)
69152
}
153+
} else {
154+
// Assume success if cursor is empty
155+
cursor?.close()
156+
Result.success(Unit)
70157
}
71-
72-
Result.success(outputFile)
73158
} catch (e: Exception) {
74-
Log.e(TAG, "Error downloading file $filePath: ${e.message}")
159+
Log.e(TAG, "Failed to enqueue download: ${e.message}")
75160
Result.failure(e)
76161
}
77162
}
78163

79164
companion object {
80165
private val TAG = HuggingFaceRemoteDataSourceImpl::class.java.simpleName
166+
167+
private const val HTTP_HEADER_CONTENT_LENGTH = "content-length"
168+
private const val DOWNLOAD_MANAGER_DOUBLE_CHECK_DELAY = 500L
81169
}
82170
}

0 commit comments

Comments
 (0)