Error: DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory") with multiple GPU #2046

Open · evgenyigumnov opened this issue on Apr 12, 2024 · 2 comments

@evgenyigumnov (Contributor)

I have 4x RTX 3080 = 40 GB of total GPU memory (10 GB per card).

I am trying to load the Mistral 7B model, about 15 GB of weight files, but I get this error:

root@C.10529376:~/ai-server$ cargo run
    Finished dev [unoptimized + debuginfo] target(s) in 0.20s
     Running `target/debug/ai-server`
retrieved the files in 27.070873ms
Error: DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")
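If my back-of-the-envelope math is right (assuming the weights load as BF16, i.e. 2 bytes per parameter), the model can never fit on a single card:

~7.2B parameters × 2 bytes ≈ 14.5 GB of weights, but each RTX 3080 has only 10 GB

so the weights would have to be split across the GPUs (or quantized) to fit.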

Is it possible to run in multi-GPU mode?

@tomsanbear (Collaborator)

Definitely possible, but it won't just work with the standard examples. Please take a look at this example to see how to run across multiple GPUs:
https://github.com/huggingface/candle/tree/main/candle-examples/examples/llama_multiprocess
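Roughly, the pattern in that example (just a sketch, not a drop-in; the exact helper names below are the ones the example uses at the time of writing and may differ between candle versions) is: launch one OS process per GPU, have each rank open its own CUDA device, join a shared NCCL communicator, and load only its shard of the weights through a sharded VarBuilder. The model itself is a tensor-parallel reimplementation that lives in the example's own model.rs:

use std::rc::Rc;
use candle_core::{DType, Device};
use cudarc::nccl::safe::{Comm, Id};

// One process per GPU; `rank` and `num_shards` come from the command line, and the
// nccl `Id` is exchanged between processes out of band (the example writes it to a file).
fn setup_rank(
    rank: usize,
    num_shards: usize,
    id: Id,
    filenames: &[std::path::PathBuf],
) -> anyhow::Result<()> {
    // Each rank owns exactly one GPU.
    let cuda_dev = cudarc::driver::CudaDevice::new(rank)?;
    let device = Device::new_cuda(rank)?;

    // Join the NCCL communicator shared by all ranks.
    let comm = match Comm::from_rank(cuda_dev, rank, num_shards, id) {
        Ok(comm) => Rc::new(comm),
        Err(err) => anyhow::bail!("nccl error {:?}", err.0),
    };

    // Memory-map the safetensors and hand out only this rank's shard of each tensor.
    let vb = unsafe {
        candle_nn::var_builder::ShardedSafeTensors::var_builder(filenames, DType::BF16, &device)?
    };

    // From here the example builds its own tensor-parallel model from (vb, comm, config);
    // that model is defined in the example's model.rs, not in candle-transformers.
    let _ = (vb, comm, device);
    Ok(())
}

The important part is the last comment: the models published in candle-transformers don't take a Comm, which is why the example ships its own model.rs.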

@evgenyigumnov (Contributor, Author)

Unfortunately, it didn't help:

PS C:\Users\igumn\ai-server> cargo run
   Compiling ai-server v0.1.0 (C:\Users\igumn\ai-server)
error[E0616]: field `hidden_size` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:19:29
   |
19 |         let n_elem = config.hidden_size / config.num_attention_heads;
   |                             ^^^^^^^^^^^ private field

error[E0616]: field `num_attention_heads` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:19:50
   |
19 |         let n_elem = config.hidden_size / config.num_attention_heads;
   |                                                  ^^^^^^^^^^^^^^^^^^^ private field

error[E0616]: field `rope_theta` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:22:36
   |
22 |             .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
   |                                    ^^^^^^^^^^ private field

error[E0616]: field `num_hidden_layers` of struct `candle_transformers::models::mistral::Config` is private
  --> src\model.rs:34:56
   |
34 |             kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
   |                                                        ^^^^^^^^^^^^^^^^^ private field

error[E0061]: this function takes 2 arguments but 4 arguments were supplied
   --> src\main.rs:379:25
    |
379 |             let model = Mistral::new(vb, &cache, &config, comm)?;
    |                         ^^^^^^^^^^^^ --  ------ unexpected argument of type `&model::Cache`
    |                                      |
    |                                      unexpected argument of type `VarBuilderArgs<'_, ShardedSafeTensors>`
    |
note: expected `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`, found `Rc<Comm>`
   --> src\main.rs:379:59
    |
379 |             let model = Mistral::new(vb, &cache, &config, comm)?;
    |                                                           ^^^^
    = note: expected struct `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`
               found struct `Rc<Comm>`
note: associated function defined here
   --> C:\Users\igumn\.cargo\registry\src\index.crates.io-6f17d22bba15001f\candle-transformers-0.4.1\src\models\mistral.rs:383:12
    |
383 |     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
    |            ^^^
help: remove the extra arguments
    |
379 -             let model = Mistral::new(vb, &cache, &config, comm)?;
379 +             let model = Mistral::new(, &config, /* VarBuilderArgs<'_, Box<dyn SimpleBackend>> */)?;
    |

warning: unused import: `candle_core::backend::BackendDevice`
  --> src\main.rs:18:5
   |
18 | use candle_core::backend::BackendDevice;
   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   |
   = note: `#[warn(unused_imports)]` on by default

warning: unused import: `candle_core::backend::BackendStorage`
 --> src\model.rs:2:5
  |
2 | use candle_core::backend::BackendStorage;
  |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some errors have detailed explanations: E0061, E0616.
For more information about an error, try `rustc --explain E0061`.
warning: `ai-server` (bin "ai-server") generated 2 warnings
error: could not compile `ai-server` (bin "ai-server") due to 5 previous errors; 2 warnings emitted

Cargo.toml

[package]
name = "ai-server"
version = "0.1.0"
edition = "2021"

[dependencies]

candle-nn = "0.4.1"
candle-core = "0.4.1"
candle-datasets = "0.4.1"
candle-transformers = "0.4.1"
candle-examples = "0.4.1"
hf-hub = "0.3.2"
tokenizers = "0.15.2"
anyhow = "1.0.81"
clap = { version = "4.5.3", features = ["derive"] }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
axum = "0.7.5"
serde = { version = "1.0.197", features = ["derive"] }
tokio = "1.36.0"
once_cell = "1.19.0"
futures = "0.3.30"
cudarc = "0.10.0"

[build-dependencies]
bindgen_cuda = { version = "0.1.1", optional = true }


[features]
default = ["cuda"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda", "cudarc/nccl"]


model.rs

use candle_core::backend::BackendStorage;
use candle_core::{ DType, Device, Result, Tensor};
use std::sync::{Arc, Mutex};
const MAX_SEQ_LEN: usize = 4096;
pub type Config = candle_transformers::models::mistral::Config;

#[derive(Clone)]
pub struct Cache {
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
}

impl Cache {
    pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
        // precompute freqs_cis
        let n_elem = config.hidden_size / config.num_attention_heads;
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        // This is different from the paper, see:
        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;
        Ok(Self {
            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
            cos,
            sin,
        })
    }
}

main.rs

use std::io::Write;
use std::rc::Rc;
use axum::{
    routing::post,
    routing::get,
    Json, Router,
};
use serde::{Deserialize, Serialize};
use axum::{response::Html};


use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mistral::{Config, Model as Mistral};

use candle_core::{DType, Device, Tensor};
use candle_core::backend::BackendDevice;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use once_cell::sync::Lazy;
use futures::lock::Mutex;
use cudarc::nccl::safe::{Comm, Id};

static AI_SERVER: Lazy<Mutex<Option<TextGeneration>>> = Lazy::new(|| Mutex::new(None));

mod model;

#[tokio::main]
async fn main() -> anyhow::Result<()> {

    let args = Args::parse();
    let ai = init(&args)?;
    AI_SERVER.lock().await.replace(ai);

    let app = Router::new().route("/", post(handle_request)).route("/", get(handler));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:8181")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();

    Ok(())

}

#[derive(Deserialize)]
struct Input {
    doc1: String,
    doc2: String,
}

#[derive(Serialize)]
struct Response {
    result: String,
}

async fn handle_request(Json(payload): Json<Input>) -> Json<Response> {


    let prompt = format!(r#"
{}

Which category does the text above belong to? Here is the list of categories:

{}

Give me a short answer. Just a category number. Category number:"#, payload.doc1, payload.doc2);

    let mut ai_server_mut = AI_SERVER.lock().await;
    let ai_server_opt = ai_server_mut.as_mut();
    match ai_server_opt {
        Some(ai_server) => {
            let result_opt = ai_server.run(&prompt, 3);
            match result_opt {
                Ok(result) => {
                    Json(Response { result: result.replace("%","").replace("\n","") })
                }
                Err(e) => {
                    Json(Response { result: e.to_string() })
                }
            }
        }
        None => {
            Json(Response { result: "AI server not initialized".to_string() })
        }
    }
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Server status: online</h1>")
}


struct TextGeneration {
    model: Mistral,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Mistral,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
        let mut result_string = String::new();
        self.tokenizer.clear();
        self.model.clear_kv_cache();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        println!("{}", prompt);
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                result_string.push_str(&t);
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            result_string.push_str(&rest);
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );


        Ok(result_string)
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(long)]
    num_shards: usize,

    #[arg(long)]
    rank: Option<usize>,

    #[arg(long)]
    start_port: Option<usize>,

    #[arg(long)]
    cpu: bool,
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,


    #[arg(long)]
    temperature: Option<f64>,

    #[arg(long)]
    top_p: Option<f64>,

    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,


    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn init(args: &Args) -> Result<TextGeneration> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    // let args = Args {
    //     cpu: false,
    //     tracing: false,
    //     use_flash_attn: false,
    //     temperature: None,
    //     top_p: None,
    //     seed: 299792458,
    //     model_id: None,
    //     revision: "main".to_string(),
    //     tokenizer_file: None,
    //     weight_files: None,
    //     repeat_penalty: 1.1,
    //     repeat_last_n: 3,
    // };



    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };


    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => {
            "mistralai/Mistral-7B-v0.1".to_string()
        }
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision.to_string(),
    ));
    let tokenizer_filename = match &args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match &args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());

    if args.rank.is_none() && args.start_port.is_some() {
        let children: Vec<_> = (0..args.num_shards)
            .map(|rank| {
                let mut args: std::collections::VecDeque<_> = std::env::args().collect();
                args.push_back("--rank".to_string());
                args.push_back(format!("{rank}"));
                let name = args.pop_front().unwrap();
                std::process::Command::new(name).args(args).spawn().unwrap()
            })
            .collect();
        for mut child in children {
            child.wait().unwrap();
        }
        return Err(E::msg("all children have exited"));
    }


    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);

    if args.start_port.is_some() {

        let comm_file: String = "nccl_id.txt".to_string();

        let comm_file = std::path::PathBuf::from(comm_file);
        if comm_file.exists() {
            return Err(E::msg("comm file already exists, please remove it first"));
        }

        let (model, device) = {
            let num_shards = args.num_shards;
            // Primitive IPC
            let id = if args.rank.unwrap() == 0 {
                let id = Id::new().unwrap();
                let tmp_file = comm_file.with_extension(".comm.tgz");
                std::fs::File::create(&tmp_file)?
                    .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
                std::fs::rename(&tmp_file, &comm_file)?;
                id
            } else {
                while !comm_file.exists() {
                    std::thread::sleep(std::time::Duration::from_secs(1));
                }
                let data = std::fs::read(&comm_file)?;
                let internal: [i8; 128] = data
                    .into_iter()
                    .map(|i| i as i8)
                    .collect::<Vec<_>>()
                    .try_into()
                    .unwrap();
                let id: Id = Id::uninit(internal);
                id
            };

            let device_arc = cudarc::driver::CudaDevice::new(args.rank.unwrap())?;
            let device = Device::new_cuda(args.rank.unwrap())?;
            let dtype = if device.is_cuda() {
                DType::BF16
            } else {
                DType::F32
            };

            let cache = model::Cache::new(dtype, &config, &device)?;

            let vb = unsafe {
                candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
            };

            let comm = match Comm::from_rank(device_arc, args.rank.unwrap(), args.num_shards, id) {
                Ok(comm) => Rc::new(comm),
                Err(err) => anyhow::bail!("nccl error {:?}", err.0),
            };
            let model = Mistral::new(vb, &cache, &config, comm)?;

            (model, device)
        };

        println!("loaded the model in {:?}", start.elapsed());
        let pipeline = TextGeneration::new(
            model,
            tokenizer,
            299792458,
            Some(0.0),
            Some(00.0),
            1.1,
            64,
            &device,
        );

        Ok(pipeline)

    } else {
        let device = candle_examples::device(false)?;

        let (model, device) = {
            let dtype = if device.is_cuda() {
                DType::BF16
            } else {
                DType::F32
            };
            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
            let model = Mistral::new(&config, vb)?;
            (model, device)
        };

        println!("loaded the model in {:?}", start.elapsed());
        let pipeline = TextGeneration::new(
            model,
            tokenizer,
            299792458,
            Some(0.0),
            Some(00.0),
            1.1,
            64,
            &device,
        );

        Ok(pipeline)

    }

}

As I understand it, these errors will only go away if changes are made to the following files:

  1. candle_transformers::models::mistral::Config
  2. candle-transformers-0.4.1\src\models\mistral.rs

Is that right?
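To make sure I'm reading the compiler note correctly: the released crate only exposes the two-argument constructor, so the four-argument call copied from the multiprocess example can't type-check against it. For reference, the only call that compiles for me (assuming candle-transformers 0.4.1, the version in my Cargo.toml) is the one my single-GPU branch already uses:

// From the compiler note above, the constructor in candle-transformers 0.4.1 is:
//     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self>
// There is no parameter for a Cache or an nccl Comm, so the sharded setup cannot be
// passed in without patching (or vendoring) mistral.rs itself.
let model = Mistral::new(&config, vb)?;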
