[Community] OPT Inference in HF Transformers #88

Closed
patrickvonplaten opened this issue May 10, 2022 · 28 comments
Labels
question Further information is requested

Comments

@patrickvonplaten
Contributor

patrickvonplaten commented May 10, 2022

You can now use the OPT models in Hugging Face Transformers

Go here for details: https://twitter.com/huggingface/status/1524783489593360385
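For readers landing here, a minimal sketch of what generation looks like in Transformers, using the small facebook/opt-350m checkpoint as an example (the larger checkpoints use the same API):

# Minimal sketch: load a small OPT checkpoint and sample a continuation.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
generated_ids = model.generate(input_ids, do_sample=True, max_new_tokens=20)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))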

(Edited by admin. Original post below)


We're working hard at Hugging Face on adding all the checkpoints to Transformers. Thanks to @stephenroller and co., we've now managed to correctly convert the checkpoints. They are all uploaded here: https://huggingface.co/models?other=opt_metasq

If you go into a specific repo, you'll find a detailed explanation on how to run them.

@Mrs-Hudson

Thanks @patrickvonplaten! Do you plan to add code for running generation with those checkpoints? I have been trying to do this in #89, borrowing from your next-token prediction scripts.

@patrickvonplaten
Contributor Author

We'll have the checkpoints added to Transformers by the end of the week, then it should be quite easy to run generation on them :-)

@stephenroller stephenroller pinned this issue May 10, 2022
@stephenroller
Contributor

Thank you Patrick! This is huge for accessibility of the models. Metaseq is notoriously unfriendly.

@hunterlang

@patrickvonplaten that's awesome! Will we just be able to call model.parallelize() to use the big checkpoints like with the big T5/GPT2 models? That will make it super easy...

@zhisbug
Contributor

zhisbug commented May 11, 2022

We'll have the checkpoints added to Transformers by the end of the week, then it should be quite easy to run generation on them :-)

Mind sharing your conversion script?

@sanxchep

sanxchep commented May 11, 2022

@patrickvonplaten that's awesome! Will we just be able to call model.parallelize() to use the big checkpoints like with the big T5/GPT2 models? That will make it super easy...

Hope this happens, because distributing the layers across every GPU by hand is another headache.
Also, @patrickvonplaten, please try to get this model supported by an inference engine like NVIDIA Triton over TensorRT. There is a long-standing problem with large models like T5-11B and GPT-2: there are hardly any resources on parallel or batch inference, which makes the models usable but not practical.

@patrickvonplaten
Contributor Author

@sanxchep We'll open-source something for this tomorrow (hopefully :-))

@patrickvonplaten
Contributor Author

All other checkpoints are available now here: https://huggingface.co/models?other=opt_metasq

@patrickvonplaten
Contributor Author

patrickvonplaten commented May 11, 2022

@zhisbug yes:

#!/usr/bin/env python3
"""
Script for backing out of the MP-resharded (reshard.pt) files and getting back
a non-flattened state dict.
Particularly useful for converting our models to other repositories.
Usage:
    $ ls 125m
    dict.txt
    gpt2-merges.txt
    gpt2-vocab.json
    reshard-model_part-0.pt
    reshard-model_part-1.pt
    $ python -m metaseq.scripts.convert_to_singleton 125m
    $ ls 125m
    dict.txt
    gpt2-merges.txt
    gpt2-vocab.json
    reshard-model_part-0.pt
    reshard-model_part-1.pt
    restored.pt
"""

import torch
import argparse
import glob

from metaseq import options, tasks, checkpoint_utils, utils
from metaseq.dataclass.configs import MetaseqConfig
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import utils as dist_utils
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap
from metaseq.distributed.stitch_fsdp_ckpt import glue_megatron_parts


def worker_main(cfg: MetaseqConfig):
    """
    Load up the model on all workers for Model Parallelism, then
    unflatten, move to cpu, and save to "restored.pt".
    """
    task = tasks.setup_task(cfg.task)

    def _build_model(cfg, task):
        # hardcoded to fp16 on GPU (cast with .half(), moved with .cuda())
        model = task.build_model(cfg.model).half().cuda()
        return fsdp_wrap(model)

    with fsdp_enable_wrap(
        cfg.distributed_training,
        use_sharded_state=cfg.distributed_training.use_sharded_state,
    ):
        models, _model_args, _task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=None,
            task=task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=True,
            num_shards=cfg.checkpoint.checkpoint_shard_count,
            build_model_hook=_build_model,
        )
        model = models[0]

    # consolidate everything on rank0
    mp_size = dist_utils.get_model_parallel_world_size()
    model_parts = [{} for _ in range(mp_size)]

    with model.summon_full_params():
        for name, p in model.named_parameters():
            gathered = [torch.zeros_like(p) for _ in range(mp_size)]
            torch.distributed.all_gather(
                gathered, p, group=dist_utils.get_global_group()
            )
            for r, t in enumerate(gathered):
                model_parts[r][name] = t.cpu()

    glued = glue_megatron_parts(model_parts)

    if "decoder.output_projection.weight" in glued:
        del glued["decoder.output_projection.weight"]

    _model_args['model'] = vars(_model_args['model'])
    _model_args['model']['_name'] = 'transformer_lm'
    _model_args['model']['decoder.version'] = torch.tensor([3])
    _model_args['criterion'] = vars(_model_args['criterion'])
    glued = {'cfg': _model_args, 'model': glued}

    if dist_utils.get_global_rank() == 0:
        with open(cfg.task.data + "/restored.pt", "wb") as f:
            torch.save(glued, f)


def main():
    # parser to be used like docstring shows
    real_parser = argparse.ArgumentParser()
    real_parser.add_argument("location")
    args = real_parser.parse_args()
    files = glob.glob(f"{args.location}/reshard*.pt")

    MP = len(files)
    BPE_MERGES = args.location + "/gpt2-merges.txt"
    BPE_VOCAB = args.location + "/gpt2-vocab.json"

    # Skeleton out all the annoying command line args we can infer
    ARGS = [
        "--model-parallel-size",
        str(MP),
        "--distributed-world-size",
        str(MP),
        "--task",
        "language_modeling",
        "--bpe-merges",
        BPE_MERGES,
        "--bpe-vocab",
        BPE_VOCAB,
        "--bpe",
        "hf_byte_bpe",
        "--path",
        args.location + "/reshard.pt",
        "--checkpoint-shard-count",
        "1",
        "--use-sharded-state",
        args.location,
    ]

    # build up the config file
    parser = options.get_generation_parser()
    # dumb defaults overriding
    parser.set_defaults(lr_scheduler=None, criterion=None)
    args = options.parse_args_and_arch(parser, input_args=ARGS)
    cfg = convert_namespace_to_omegaconf(args)
    cfg.distributed_training.distributed_world_size = MP
    dist_utils.call_main(cfg, worker_main)


if __name__ == "__main__":
    main()

Note you should run this on this branch here: #60
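As a quick sanity check (a sketch only, assuming the 125m directory layout from the docstring above), the consolidated file can be inspected like this:

import torch

# Load the consolidated checkpoint written by worker_main on rank 0.
state = torch.load("125m/restored.pt", map_location="cpu")
print(state["cfg"]["model"]["_name"])  # 'transformer_lm', as set by the script
print(len(state["model"]))             # number of unflattened parameter tensors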

@KastanDay

All other checkpoints are available now here: https://huggingface.co/models?other=opt_metasq

Thanks for publishing these! Any chance this error looks familiar?

  File "/home/kastanday/githubs/Megatron-LM/megatron/mpu/initialize.py", line 222, in get_pipeline_model_parallel_group
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
AssertionError: pipeline_model parallel group is not initialized

I’m using the default arguments. Here's a larger stack trace:

setting number of micro-batches to constant 1
> initializing torch distributed ...
> initializing tensor model parallel with size 1
> initializing pipeline model parallel with size 1
> setting random seeds to 1234 ...
Traceback (most recent call last):
  File "run_model.py", line 13, in <module>
    initialize_megatron(args_defaults={
  File "/home/kastanday/githubs/Megatron-LM/megatron/initialize.py", line 82, in initialize_megatron
    finish_mpu_init()
  File "/home/kastanday/githubs/Megatron-LM/megatron/initialize.py", line 65, in finish_mpu_init
    _set_random_seed(args.seed)
  File "/home/kastanday/githubs/Megatron-LM/megatron/initialize.py", line 210, in _set_random_seed
    seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
  File "/home/kastanday/githubs/Megatron-LM/megatron/mpu/initialize.py", line 294, in get_pipeline_model_parallel_rank
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
  File "/home/kastanday/githubs/Megatron-LM/megatron/mpu/initialize.py", line 222, in get_pipeline_model_parallel_group
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
AssertionError: pipeline_model parallel group is not initialized

Thanks for any help.

@YvesZumbach

I end up with the exact same stacktrace:

$ bash run.sh
using world size: 1, data-parallel-size: 1, tensor-model-parallel size: 1, pipeline-model-parallel size: 1
setting global batch size to 1
using torch.float32 for parameters ...
------------------------ arguments ------------------------
  accumulate_allreduce_grads_in_fp32 .............. False
  activations_checkpoint_method ................... None
  activations_checkpoint_num_layers ............... 1
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.999
  adam_eps ........................................ 1e-08
  adlr_autoresume ................................. False
  adlr_autoresume_interval ........................ 1000
  apply_query_key_layer_scaling ................... True
  apply_residual_connection_post_layernorm ........ False
  attention_dropout ............................... 0.1
  attention_softmax_in_fp32 ....................... False
  bert_binary_head ................................ True
  bert_load ....................................... None
  bf16 ............................................ False
  bias_dropout_fusion ............................. True
  bias_gelu_fusion ................................ True
  biencoder_projection_dim ........................ 0
  biencoder_shared_query_context_model ............ False
  block_data_path ................................. None
  clip_grad ....................................... 1.0
  consumed_train_samples .......................... 0
  consumed_valid_samples .......................... 0
  data_impl ....................................... infer
  data_parallel_size .............................. 1
  data_path ....................................... None
  dataloader_type ................................. single
  DDP_impl ........................................ local
  decoder_seq_length .............................. None
  distribute_checkpointed_activations ............. False
  distributed_backend ............................. nccl
  embedding_path .................................. None
  empty_unused_memory_level ....................... 0
  encoder_seq_length .............................. 2048
  eod_mask_loss ................................... False
  eval_interval ................................... 1000
  eval_iters ...................................... 100
  evidence_data_path .............................. None
  exit_duration_in_mins ........................... None
  exit_interval ................................... None
  ffn_hidden_size ................................. 3072
  finetune ........................................ False
  fp16 ............................................ False
  fp16_lm_cross_entropy ........................... False
  fp32_residual_connection ........................ False
  global_batch_size ............................... 1
  hidden_dropout .................................. 0.1
  hidden_size ..................................... 768
  hysteresis ...................................... 2
  ict_head_size ................................... None
  ict_load ........................................ None
  img_dim ......................................... 224
  indexer_batch_size .............................. 128
  indexer_log_interval ............................ 1000
  init_method_std ................................. 0.02
  init_method_xavier_uniform ...................... False
  initial_loss_scale .............................. 4294967296
  kv_channels ..................................... 64
  layernorm_epsilon ............................... 1e-05
  lazy_mpu_init ................................... None
  load ............................................ None
  local_rank ...................................... None
  log_batch_size_to_tensorboard ................... False
  log_interval .................................... 100
  log_learning_rate_to_tensorboard ................ True
  log_loss_scale_to_tensorboard ................... True
  log_memory_to_tensorboard ....................... False
  log_num_zeros_in_grad ........................... False
  log_params_norm ................................. False
  log_timers_to_tensorboard ....................... False
  log_validation_ppl_to_tensorboard ............... False
  loss_scale ...................................... None
  loss_scale_window ............................... 1000
  lr .............................................. None
  lr_decay_iters .................................. None
  lr_decay_samples ................................ None
  lr_decay_style .................................. linear
  lr_warmup_fraction .............................. None
  lr_warmup_iters ................................. 0
  lr_warmup_samples ............................... 0
  make_vocab_size_divisible_by .................... 128
  mask_prob ....................................... 0.15
  masked_softmax_fusion ........................... True
  max_position_embeddings ......................... 2048
  merge_file ...................................... None
  micro_batch_size ................................ 1
  min_loss_scale .................................. 1.0
  min_lr .......................................... 0.0
  mmap_warmup ..................................... False
  no_async_tensor_model_parallel_allreduce ........ False
  no_load_optim ................................... None
  no_load_rng ..................................... None
  no_save_optim ................................... None
  no_save_rng ..................................... None
  num_attention_heads ............................. 12
  num_channels .................................... 3
  num_classes ..................................... 1000
  num_layers ...................................... 12
  num_layers_per_virtual_pipeline_stage ........... None
  num_workers ..................................... 2
  onnx_safe ....................................... None
  openai_gelu ..................................... False
  optimizer ....................................... adam
  override_lr_scheduler ........................... False
  params_dtype .................................... torch.float32
  patch_dim ....................................... 16
  pipeline_model_parallel_size .................... 1
  pipeline_model_parallel_split_rank .............. None
  query_in_block_prob ............................. 0.1
  rampup_batch_size ............................... None
  rank ............................................ 0
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  retriever_report_topk_accuracies ................ []
  retriever_score_scaling ......................... False
  retriever_seq_length ............................ 256
  sample_rate ..................................... 1.0
  save ............................................ None
  save_interval ................................... None
  scatter_gather_tensors_in_pipeline .............. True
  seed ............................................ 1234
  seq_length ...................................... 2048
  sgd_momentum .................................... 0.9
  short_seq_prob .................................. 0.1
  split ........................................... 969, 30, 1
  tensor_model_parallel_size ...................... 1
  tensorboard_dir ................................. None
  tensorboard_log_interval ........................ 1
  tensorboard_queue_size .......................... 1000
  titles_data_path ................................ None
  tokenizer_type .................................. None
  train_iters ..................................... None
  train_samples ................................... None
  use_checkpoint_lr_scheduler ..................... False
  use_contiguous_buffers_in_local_ddp ............. True
  use_cpu_initialization .......................... None
  use_one_sent_docs ............................... False
  virtual_pipeline_model_parallel_size ............ None
  vocab_extra_ids ................................. 0
  vocab_file ...................................... None
  weight_decay .................................... 0.01
  world_size ...................................... 1
-------------------- end of arguments ---------------------
setting number of micro-batches to constant 1
> initializing torch distributed ...
> initializing tensor model parallel with size 1
> initializing pipeline model parallel with size 1
> setting random seeds to 1234 ...
Traceback (most recent call last):
  File "run_model.py", line 19, in <module>
    "encoder_seq_length": 2048
  File "/home/yves/Megatron-LM/megatron/initialize.py", line 82, in initialize_megatron
    finish_mpu_init()
  File "/home/yves/Megatron-LM/megatron/initialize.py", line 65, in finish_mpu_init
    _set_random_seed(args.seed)
  File "/home/yves/Megatron-LM/megatron/initialize.py", line 210, in _set_random_seed
    seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
  File "/home/yves/Megatron-LM/megatron/mpu/initialize.py", line 294, in get_pipeline_model_parallel_rank
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
  File "/home/yves/Megatron-LM/megatron/mpu/initialize.py", line 223, in get_pipeline_model_parallel_group
    'pipeline_model parallel group is not initialized'
AssertionError: pipeline_model parallel group is not initialized
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 31376) of binary: /home/yves/opt/bin/python3.7
Traceback (most recent call last):
  File "/home/yves/opt/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/yves/opt/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/home/yves/opt/lib/python3.7/site-packages/torch/distributed/run.py", line 719, in main
    run(args)
  File "/home/yves/opt/lib/python3.7/site-packages/torch/distributed/run.py", line 713, in run
    )(*cmd_args)
  File "/home/yves/opt/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/yves/opt/lib/python3.7/site-packages/torch/distributed/launcher/api.py", line 261, in launch_agent
    failures=result.failures,
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
run_model.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2022-05-12_12:20:21
  host      : yves-zumbach-3-tcp.tenant-chairesearch-test.svc.cluster.local
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 31376)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

@baiyuting

baiyuting commented May 12, 2022

When trying to run opt_metaseq_1300m, I also got errors:

Traceback (most recent call last):
  File "run_model.py", line 28, in <module>
    checkpoint = checkpoint_utils.load_model_ensemble_and_task(
  File "/root/metaseq/metaseq/checkpoint_utils.py", line 473, in load_model_ensemble_and_task
    state = load_checkpoint_to_cpu(filename, arg_overrides)
  File "/root/metaseq/metaseq/checkpoint_utils.py", line 440, in load_checkpoint_to_cpu
    state = _upgrade_state_dict(state)
  File "/root/metaseq/metaseq/checkpoint_utils.py", line 579, in _upgrade_state_dict
    {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
KeyError: 'best_loss'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 54948) of binary: /opt/anaconda3/envs/metaseq/bin/python
Traceback (most recent call last):
  File "/opt/anaconda3/envs/metaseq/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/anaconda3/envs/metaseq/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/opt/anaconda3/envs/metaseq/lib/python3.8/site-packages/torch/distributed/run.py", line 719, in main
    run(args)
  File "/opt/anaconda3/envs/metaseq/lib/python3.8/site-packages/torch/distributed/run.py", line 710, in run
    elastic_launch(
  File "/opt/anaconda3/envs/metaseq/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/anaconda3/envs/metaseq/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 259, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

@KastanDay

I end up with the exact same stacktrace:

I just found HF's published colab notebook https://twitter.com/huggingface/status/1524783493489774592?s=20&t=DZLKFh3FrVadi2zmMs62aA. You might want to try this method. Shoutout @suchenzang for the great community engagement.
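For reference, the approach in that notebook is (roughly) the accelerate init_empty_weights / load_checkpoint_and_dispatch pattern. A sketch, where the weights path is a placeholder and the no_split_module_classes choice is an assumption rather than copied from the notebook:

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton on the meta device (no weights allocated) ...
config = AutoConfig.from_pretrained("facebook/opt-30b")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()

# ... then stream the downloaded shards onto whatever devices are available.
model = load_checkpoint_and_dispatch(
    model,
    "/path/to/opt-30b-weights",  # placeholder: directory containing the downloaded shards
    device_map="auto",
    no_split_module_classes=["OPTDecoderLayer"],
)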

@stephenroller stephenroller changed the title [Community] Running models in inference [Community] OPT Inference in HF Transformers May 12, 2022
@Mrs-Hudson

Mrs-Hudson commented May 12, 2022

@patrickvonplaten Thanks for the great work!
I am trying to run generation using the Hugging Face checkpoint for OPT-30B, but I hit a CUDA out-of-memory error.
My config: an Azure compute node with 8 GPUs, virtual machine size Standard_ND40rs_v2 (40 cores, 672 GB RAM, 2900 GB disk).

Code:
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch

model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b", torch_dtype=torch.float16).cuda()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)

prompt = "Hello, I'm am conscious and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

set_seed(32)
generated_ids = model.generate(input_ids, do_sample=True)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

Stacktrace:

Downloading: 100%|██████████| 9.12G/9.12G [01:30<00:00, 109MB/s]
Downloading: 100%|██████████| 9.19G/9.19G [02:56<00:00, 55.8MB/s]
Downloading: 100%|██████████| 9.19G/9.19G [02:20<00:00, 70.2MB/s]
Downloading: 100%|██████████| 9.19G/9.19G [02:32<00:00, 64.6MB/s]
Downloading: 100%|██████████| 9.19G/9.19G [03:33<00:00, 46.2MB/s]
Downloading: 100%|██████████| 9.19G/9.19G [03:49<00:00, 43.0MB/s]
Downloading: 100%|██████████| 784M/784M [00:08<00:00, 94.4MB/s]

Traceback (most recent call last):
  File "hfGenerateScript.py", line 5, in <module>
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b", torch_dtype=torch.float16).cuda()
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 680, in cuda
    return self._apply(lambda t: t.cuda(device))
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 570, in _apply
    module._apply(fn)
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 570, in _apply
    module._apply(fn)
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 570, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 593, in _apply
    param_applied = fn(param)
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 680, in <lambda>
    return self._apply(lambda t: t.cuda(device))
RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 31.75 GiB total capacity; 30.18 GiB already allocated; 9.75 MiB free; 30.18 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Do you know how we can distribute the inference across all available GPUs?

@Mrs-Hudson
Copy link

When I use DataParallel, I see the error below:

Traceback (most recent call last):
  File "hfGenerateScript.py", line 19, in <module>
    generated_ids = model.generate(input_ids, do_sample=True)
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1177, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DataParallel' object has no attribute 'generate'
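For what it's worth, that AttributeError happens because generate() is defined on the wrapped model, not on the nn.DataParallel wrapper, so it has to be called through model.module. A small sketch using a small checkpoint as a stand-in; note that DataParallel replicates the full model on every GPU, so it does not help with the 30B memory problem:

from torch.nn import DataParallel
from transformers import AutoModelForCausalLM, AutoTokenizer

model = DataParallel(AutoModelForCausalLM.from_pretrained("facebook/opt-350m"))
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

# generate() lives on the wrapped module, not on the DataParallel wrapper.
generated_ids = model.module.generate(input_ids, do_sample=True)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))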

@xhluca
Contributor

xhluca commented May 19, 2022

@sanxchep We'll open-source something for this tomorrow (hopefully :-))

@patrickvonplaten I saw the announcement about running opt-30b in a Colab notebook by loading the weights from disk. This is pretty cool, but I think some of us (and I believe @sanxchep as well) would be more interested in splitting the weights across multiple GPUs (e.g. 8x16GB GPUs) and running them with tensor parallelism or ZeRO-3, to achieve real-time inference speed.

@patrickvonplaten
Contributor Author

@xhlulu, fully agree! We're working on it with @sgugger in transformers! Will announce the "Big model inference feature" by next week probably.
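(For later readers: this shipped as the accelerate-backed device_map="auto" argument to from_pretrained. A sketch of the usage, assuming a multi-GPU node with accelerate installed:)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Weights are sharded across the available GPUs (and offloaded to CPU/disk if needed).
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-30b", device_map="auto", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.to(0)
generated_ids = model.generate(input_ids, do_sample=True, max_new_tokens=30)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))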

@todpole3 todpole3 unpinned this issue May 26, 2022
@getao

getao commented May 27, 2022

Hi,

I followed the scripts in https://colab.research.google.com/drive/14wnxMvD9zsiBQo2FtTpxn6w2cpXCcb-7#scrollTo=y8Ne7jJdaF9F&uniqifier=1 on my local machine. However, I always hit a problem when my code executes:

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

The error message is:

  File "/home/root/anaconda3/envs/metaseq/lib/python3.8/site-packages/accelerate/big_modeling.py", line 68, in register_empty_parameter
    module._parameters[name] = nn.Parameter(module._parameters[name].to(torch.device("meta")))
RuntimeError: Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: meta

My environment is as follows. Could you please help me pinpoint the cause of the problem? Thanks:
Package Version Location


absl-py 1.0.0
accelerate 0.9.0
antlr4-python3-runtime 4.9.3
apex 0.1
asttokens 2.0.5
attrs 21.4.0
azure-core 1.24.0
azure-storage-blob 12.12.0
backcall 0.2.0
black 22.1.0
boto3 1.23.6
botocore 1.26.6
cachetools 5.1.0
certifi 2022.5.18.1
cffi 1.15.0
cfgv 3.3.1
charset-normalizer 2.0.12
click 8.0.4
colorama 0.4.4
cryptography 37.0.2
Cython 0.29.30
decorator 5.1.1
distlib 0.3.4
editdistance 0.6.0
executing 0.8.3
fairscale 0.4.1 /home/root/fairscale
filelock 3.7.0
fire 0.4.0
Flask 2.1.1
google-auth 2.6.6
google-auth-oauthlib 0.4.6
grpcio 1.46.3
huggingface-hub 0.7.0
hydra-core 1.2.0
identify 2.5.1
idna 3.3
importlib-metadata 4.11.4
importlib-resources 5.7.1
iniconfig 1.1.1
iopath 0.1.9
ipdb 0.13.9
ipython 8.3.0
isodate 0.6.1
itsdangerous 2.1.2
jedi 0.18.1
Jinja2 3.1.1
jmespath 1.0.0
joblib 1.1.0
Markdown 3.3.7
MarkupSafe 2.1.1
matplotlib-inline 0.1.3
megatron-lm 1.1.5 /home/root/Megatron-LM
metaseq 0.0.1 /home/root/metaseq
mkl-fft 1.3.1
mkl-random 1.2.2
mkl-service 2.4.0
more-itertools 8.13.0
msrest 0.6.21
mypy 0.950
mypy-extensions 0.4.3
ninja 1.10.2.3
nodeenv 1.6.0
numpy 1.22.3
oauthlib 3.2.0
omegaconf 2.2.1
packaging 21.3
parso 0.8.3
pathspec 0.9.0
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.0.1
pip 21.2.4
platformdirs 2.5.2
pluggy 1.0.0
portalocker 2.4.0
pre-commit 2.19.0
prompt-toolkit 3.0.29
protobuf 3.20.1
ptyprocess 0.7.0
pure-eval 0.2.2
py 1.11.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybind11 2.9.2
pycparser 2.21
Pygments 2.12.0
pyparsing 3.0.9
pytest 7.1.2
python-dateutil 2.8.2
PyYAML 6.0
regex 2022.4.24
requests 2.27.1
requests-oauthlib 1.3.1
rsa 4.8
s3transfer 0.5.2
sacrebleu 2.1.0
scikit-learn 1.1.1
scipy 1.8.1
setuptools 61.2.0
six 1.16.0
sklearn 0.0
stack-data 0.2.0
tabulate 0.8.9
tensorboard 2.9.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
termcolor 1.1.0
threadpoolctl 3.1.0
timeout-decorator 0.5.0
tokenizers 0.12.1
toml 0.10.2
tomli 2.0.1
torch 1.8.1
torchaudio 0.8.0a0+e4e171a
torchvision 0.9.1
tqdm 4.64.0
traitlets 5.2.1.post0
transformers 4.19.2
typing_extensions 4.1.1
urllib3 1.26.9
virtualenv 20.14.1
wcwidth 0.2.5
Werkzeug 2.1.2
wheel 0.37.1
zipp 3.8.0

@stephenroller
Contributor

Your PyTorch version isn't new enough for meta tensors.
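A quick way to check this locally (a sketch; meta tensors landed in newer PyTorch releases, and 1.11 is confirmed to work in the follow-up below):

import torch

print(torch.__version__)        # 1.8.1 in the environment above
torch.empty(1, device="meta")   # raises the "device string: meta" error on old versions, succeeds on newer ones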

@getao

getao commented May 29, 2022

Your PyTorch version isn't new enough for meta tensors.

Thanks, it works after I upgraded to PyTorch 1.11.

@yeswanthkuruba

@xhlulu, fully agree! We're working on it with @sgugger in transformers! Will announce the "Big model inference feature" by next week probably.

Hi @patrickvonplaten, do you have any update on the "Big model inference feature"?

@patrickvonplaten
Contributor Author

Also see: #164

@patrickvonplaten
Contributor Author

After fixing the conversion script in #164, I re-converted all singleton metaseq checkpoints here: https://huggingface.co/models?other=opt_metasq

@suchenzang
Contributor

Marking as done - any additional issues on this front should be tracked in a new issue. Thanks for all the hard work here!

@xhluca
Contributor

xhluca commented Aug 22, 2022

Sorry if I missed an announcement that would justify this issue being closed, but where can I find more information about the Big model inference feature? I know this comment is unrelated to the initial issue, so I'm happy to open a new one if that's more appropriate.

Right now, unless I'm mistaken, the Hugging Face implementation splits the model sequentially (layer-wise) across the GPUs, which means only one GPU is active at any time. Out of 8 GPUs, 7 are used strictly for memory and only 1 does actual compute.

So in theory we could see a 5-7x improvement on the same hardware if pipeline parallelism were used. I believe that is where the Big model inference feature would be most beneficial: abstracting the hard part (running DeepSpeed PP and Megatron-LM MP) while benefiting from much higher hardware utilization.

Happy to hear your thoughts on that!

cc: @suchenzang @patrickvonplaten
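To make the 5-7x estimate concrete, a back-of-envelope utilization sketch (the micro-batch count is an assumption, not a benchmark):

# Naive layer-wise placement over g GPUs keeps roughly one GPU busy at a time.
# Pipeline parallelism with p stages and m micro-batches has the classic
# bubble estimate: utilization ≈ m / (m + p - 1).
g, p, m = 8, 8, 32
naive_util = 1 / g
pipelined_util = m / (m + p - 1)
print(f"naive: {naive_util:.2f}, pipelined: {pipelined_util:.2f}, "
      f"relative speedup ≈ {pipelined_util / naive_util:.1f}x")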

@sanxchep

Right now, unless I'm mistaken, the Hugging Face implementation splits the model sequentially (layer-wise) across the GPUs, which means only one GPU is active at any time. [...]

I agree; I was going to write the same thing. OPT cannot be used efficiently if we don't have parallelism.

@suchenzang
Contributor

@xhluca @sanxchep For issues with HF Transformers, I think it's best to follow up with the HF team on how they implemented inference for OPT-175B (unless there are specific code changes here that you think are necessary for that to happen). There are also other community integrations, like alpa, which may help address some of the concerns here: https://opt.alpa.ai/

@patrickvonplaten
Contributor Author

We'll share a post on Twitter and add a model to the Hub once we've transferred the HF weights to Meta :-)
I'll try to remember to post it here too!
