Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
467 changes: 221 additions & 246 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion kernel-builder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ clap-markdown = "0.1.5"
clap_complete = "4"
eyre = "0.6.12"
git2 = "0.20"
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub_rust.git", rev = "6084c0f911026b4fec2961742c611520d7eb3d27", package = "huggingface-hub", features = ["blocking", "xet"] }
hf-hub = { version = "1.0.0-rc.0", features = ["blocking"] }
itertools = "0.13"
minijinja = "2.5"
minijinja-embed = "2.5"
Expand Down
28 changes: 16 additions & 12 deletions kernel-builder/src/hf.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
use eyre::{Context, Result};
use huggingface_hub::{HFClientSync, HFRepositorySync, RepoType};
use hf_hub::{HFClientSync, HFRepositorySync, RepoType};

/// Build a sync HF API client.
pub fn api() -> Result<huggingface_hub::HFClientSync> {
huggingface_hub::HFClientSync::new().context("Cannot create Hugging Face API client")
pub fn api() -> Result<hf_hub::HFClientSync> {
hf_hub::HFClientSync::new().context("Cannot create Hugging Face API client")
}

/// Get a repo handle.
pub fn repo_handle(api: &HFClientSync, repo_type: RepoType, repo_id: &str) -> HFRepositorySync {
pub fn repo_handle<T: RepoType>(api: &HFClientSync, repo_id: &str) -> HFRepositorySync<T> {
let parts: Vec<&str> = repo_id.splitn(2, '/').collect();
if parts.len() == 2 {
api.repo(repo_type, parts[0], parts[1])
api.repository::<T>(parts[0], parts[1])
} else {
api.repo(repo_type, "", repo_id)
api.repository::<T>("", repo_id)
}
}

/// Resolve the HF username of the currently logged-in user via `whoami`.
/// Requires a valid HF token to be configured.
pub fn whoami_username() -> Result<String> {
api()?.whoami().map(|user| user.username).map_err(|_| {
eyre::eyre!(
"Not logged in to Hugging Face. Run `hf auth login` first, \
or use --name <owner/repo> to skip auto-detection."
)
})
api()?
.whoami()
.send()
.map(|user| user.username)
.map_err(|_| {
eyre::eyre!(
"Not logged in to Hugging Face. Run `hf auth login` first, \
or use --name <owner/repo> to skip auto-detection."
)
})
}
78 changes: 38 additions & 40 deletions kernel-builder/src/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use std::{

use clap::Args;
use eyre::{bail, Context, Result};
use huggingface_hub::{
AddSource, CommitOperation, CreateRepoParams, RepoCreateBranchParams, RepoCreateCommitParams,
RepoListFilesParams, RepoListRefsParams, RepoType,
use hf_hub::{
repository::{AddSource, CommitOperation},
RepoType, RepoTypeKernel, RepoTypeModel,
};
use kernels_data::metadata::Metadata;
use walkdir::WalkDir;
Expand All @@ -29,15 +29,6 @@ pub enum RepoTypeArg {
Kernel,
}

impl From<RepoTypeArg> for RepoType {
fn from(arg: RepoTypeArg) -> Self {
match arg {
RepoTypeArg::Model => RepoType::Model,
RepoTypeArg::Kernel => RepoType::Kernel,
}
}
}

#[derive(Debug, Args)]
pub struct UploadArgs {
/// Directory of the kernel build (defaults to current directory).
Expand Down Expand Up @@ -96,8 +87,14 @@ fn get_repo_and_branch(
}

pub fn run_upload(args: UploadArgs) -> Result<()> {
match args.repo_type {
RepoTypeArg::Model => run_upload_typed::<RepoTypeModel>(args),
RepoTypeArg::Kernel => run_upload_typed::<RepoTypeKernel>(args),
}
}

fn run_upload_typed<T: RepoType>(args: UploadArgs) -> Result<()> {
let api = hf::api()?;
let repo_type: RepoType = args.repo_type.into();
let kernel_dir = check_or_infer_kernel_dir(args.kernel_dir)?;
let kernel_dir = fs::canonicalize(&kernel_dir)
.wrap_err_with(|| format!("Cannot resolve kernel directory `{}`", kernel_dir.display()))?;
Expand All @@ -111,14 +108,13 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {

let (repo_id, branch) = get_repo_and_branch(&kernel_dir, args.repo_id, args.branch, &variants)?;

let params = CreateRepoParams::builder()
let repo_url = api
.create_repository()
.repo_id(&repo_id)
.repo_type(repo_type)
.repo_type(T::default())
.private(args.private)
.exist_ok(true)
.build();
let repo_url = api
.create_repo(&params)
.send()
.wrap_err("Cannot create repository")?;
// Extract repo_id from URL, stripping "kernels/" prefix for kernel repos
let repo_id = repo_url
Expand All @@ -129,18 +125,19 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
.unwrap_or(&repo_id)
.to_owned();

let repo = repo_handle(&api, repo_type, &repo_id);
let repo = repo_handle::<T>(&api, &repo_id);

let is_new_version_branch = if let Some(ref branch) = branch {
let refs_params = RepoListRefsParams::builder().build();
let refs = repo
.list_refs(&refs_params)
.list_refs()
.send()
.wrap_err("Cannot list repository refs")?;
let exists = refs.branches.iter().any(|r| r.name == *branch);

if !exists {
let params = RepoCreateBranchParams::builder().branch(branch).build();
repo.create_branch(&params)
repo.create_branch()
.branch(branch)
.send()
.wrap_err_with(|| format!("Cannot create branch `{branch}`"))?;
}
eprintln!(
Expand All @@ -163,10 +160,18 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
);

if let Some(ref branch) = branch {
let params = RepoListFilesParams {
revision: Some(branch.clone()),
};
let version_existing_files: Vec<String> = repo.list_files(&params).unwrap_or_default();
let version_existing_files: Vec<String> = repo
.list_tree()
.revision(branch.clone())
.recursive(true)
.send()
.unwrap_or_default()
.into_iter()
.map(|entry| match entry {
hf_hub::repository::RepoTreeEntry::File { path, .. } => path,
hf_hub::repository::RepoTreeEntry::Directory { path, .. } => path,
})
.collect();

let version_ops = operations_by_branch.entry(branch.clone()).or_default();

Expand Down Expand Up @@ -215,15 +220,11 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
"Uploaded using `kernel-builder`.".to_owned()
};

let params = RepoCreateCommitParams {
operations: chunk.to_vec(),
commit_message,
commit_description: None,
revision: Some(branch.clone()),
create_pr: None,
parent_commit: None,
};
repo.create_commit(&params)
repo.create_commit()
.operations(chunk.to_vec())
.commit_message(&commit_message)
.revision(branch.clone())
.send()
.wrap_err_with(|| format!("Cannot create commit on branch `{branch}`"))?;

if batch_count > 1 {
Expand All @@ -236,10 +237,7 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
if total_ops == 0 {
eprintln!("No changes to upload.");
} else {
let type_prefix = match repo_type {
RepoType::Kernel => "kernels/",
_ => "",
};
let type_prefix = T::default().url_prefix();
let tree_path = branch
.as_ref()
.map_or(String::new(), |b| format!("/tree/{b}"));
Expand Down
3 changes: 0 additions & 3 deletions nix-builder/pkgs/kernel-abi-check/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ rustPlatform.buildRustPackage {

cargoLock = {
lockFile = ../../../Cargo.lock;
outputHashes = {
"huggingface-hub-0.0.1" = "sha256-By8b1NUPWu+XF3Om1NcEO+o2qdZUco+FxvrJGNRqxWs=";
};
};

cargoBuildFlags = cargoFlags;
Expand Down
3 changes: 0 additions & 3 deletions nix-builder/pkgs/kernel-builder/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ rustPlatform.buildRustPackage {

cargoLock = {
lockFile = ../../../Cargo.lock;
outputHashes = {
"huggingface-hub-0.0.1" = "sha256-By8b1NUPWu+XF3Om1NcEO+o2qdZUco+FxvrJGNRqxWs=";
};
};

cargoBuildFlags = cargoFlags;
Expand Down
3 changes: 0 additions & 3 deletions nix-builder/pkgs/python-modules/kernel-abi-check/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ buildPythonPackage {

cargoDeps = rustPlatform.importCargoLock {
lockFile = ../../../../Cargo.lock;
outputHashes = {
"huggingface-hub-0.0.1" = "sha256-By8b1NUPWu+XF3Om1NcEO+o2qdZUco+FxvrJGNRqxWs=";
};
};

maturinBuildFlags = cargoFlags;
Expand Down
3 changes: 0 additions & 3 deletions nix-builder/pkgs/python-modules/kernels-data/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ buildPythonPackage {

cargoDeps = rustPlatform.importCargoLock {
lockFile = ../../../../Cargo.lock;
outputHashes = {
"huggingface-hub-0.0.1" = "sha256-By8b1NUPWu+XF3Om1NcEO+o2qdZUco+FxvrJGNRqxWs=";
};
};

maturinBuildFlags = cargoFlags;
Expand Down
Loading