Type Error in GPTLMHeadModel #3

Open
axelmagn opened this issue Mar 4, 2024 · 7 comments

@axelmagn (Contributor) commented Mar 4, 2024

I am having a go at running inference and evaluation for this model, and running into a TypeError in GPTLMHeadModel:

In [1]: import torch
   ...: from transformers import AutoTokenizer
   ...: from based.models.gpt import GPTLMHeadModel
   ...: 
   ...: tokenizer = AutoTokenizer.from_pretrained("gpt2")
   ...: model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)
tokenizer_config.json: 100%|███████████████████████████████████████████| 26.0/26.0 [00:00<00:00, 260kB/s]
config.json: 100%|██████████████████████████████████████████████████████| 665/665 [00:00<00:00, 8.64MB/s]
vocab.json: 100%|███████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 12.1MB/s]
merges.txt: 100%|█████████████████████████████████████████████████████| 456k/456k [00:00<00:00, 8.99MB/s]
tokenizer.json: 100%|███████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 17.8MB/s]
config.json: 100%|██████████████████████████████████████████████████| 2.86k/2.86k [00:00<00:00, 36.7MB/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 6
      3 from based.models.gpt import GPTLMHeadModel
      5 tokenizer = AutoTokenizer.from_pretrained("gpt2")
----> 6 model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)

File /based/models/gpt.py:468, in GPTPreTrainedModel.from_pretrained_hf(cls, pretrained_model_name, device, **kwargs)
    466 config_data = load_config_hf(pretrained_model_name)
    467 config = GPT2Config(**config_data)
--> 468 model = cls(config, device=device, **kwargs)
    469 state_dict = load_state_dict_hf(pretrained_model_name, device=device)
    471 # remove the 'model.' prefix from the keys

File /based/models/gpt.py:741, in GPTLMHeadModel.__init__(self, config, process_group, device, dtype)
    739 super().__init__(config)
    740 self.process_group = process_group
--> 741 self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
    742 self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
    743 lm_head_bias = getattr(config, "lm_head_bias", False)

File /based/models/gpt.py:585, in GPTModel.__init__(self, config, process_group, device, dtype)
    569     self.embeddings = ParallelGPT2Embeddings(
    570         config.hidden_size,
    571         vocab_size,
   (...)
    575         **factory_kwargs,
    576     )
    578 # We change the order of dropout, residual and layer norm:
    579 # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
    580 # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
    581 # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
    582 # nn.Dropout probabilities are changed.
    583 # This is for performance reason: we can fuse dropout + add + layer_norm.
    584 self.layers = nn.ModuleList(
--> 585     [
    586         create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
    587         for i in range(config.num_hidden_layers)
    588     ]
    589 )
    590 self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
    591 if self.fused_dropout_add_ln:

File /based/models/gpt.py:586, in <listcomp>(.0)
    569     self.embeddings = ParallelGPT2Embeddings(
    570         config.hidden_size,
    571         vocab_size,
   (...)
    575         **factory_kwargs,
    576     )
    578 # We change the order of dropout, residual and layer norm:
    579 # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
    580 # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
    581 # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
    582 # nn.Dropout probabilities are changed.
    583 # This is for performance reason: we can fuse dropout + add + layer_norm.
    584 self.layers = nn.ModuleList(
    585     [
--> 586         create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
    587         for i in range(config.num_hidden_layers)
    588     ]
    589 )
    590 self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
    591 if self.fused_dropout_add_ln:

File /based/models/gpt.py:371, in create_block(config, layer_idx, process_group, device, dtype, **kwargs)
    369 mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    370 use_rms_norm = getattr(config, "rms_norm", False)
--> 371 norm_cls = partial(
    372     nn.LayerNorm if not use_rms_norm else RMSNorm,
    373     eps=config.layer_norm_epsilon,
    374     **factory_kwargs,
    375 )
    376 # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    377 residual_in_fp32 = getattr(config, "residual_in_fp32", False)

TypeError: the first argument must be callable

For reproducibility, I have been running this in a docker container:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && apt-get install -y \
    apt-utils \
    python3.10 \
    python3-pip \
    git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
RUN pip install \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 # due to observed causal-conv1d dependency

RUN pip install \
    jupyter==1.0.0 \
    hydra-core==1.3.2

RUN pip install jupyter
COPY . .
RUN pip install .

Any idea what could be going wrong here?

@simran-arora (Collaborator)

Hi,
I think it's because this RMSNorm is being set to None, due to the import structure here:

norm_cls = partial(

One option is to install the RMSNorm kernel from flash-attn so that the import resolves.

Sorry for the difficulty -- we will fix the install / instructions for this
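
To make the failure mode concrete: partial() raises "the first argument must be callable" as soon as it is handed None, which is what happens when the guarded RMSNorm import falls back to None. The sketch below reproduces this; the try/except fallback and the flash_attn.ops.rms_norm import path are assumptions based on the traceback, not the exact code in based/models/gpt.py.

# Minimal sketch: why create_block fails with "the first argument must be callable".
# Assumption: based guards the flash-attn RMSNorm import and falls back to None
# when the layer_norm extension isn't built; the import path below is illustrative.
from functools import partial
import torch.nn as nn

try:
    from flash_attn.ops.rms_norm import RMSNorm  # needs flash-attn's csrc/layer_norm extension
except ImportError:
    RMSNorm = None

use_rms_norm = True        # the based-360m config takes the RMSNorm branch
layer_norm_epsilon = 1e-5

# With RMSNorm == None, partial(None, ...) raises TypeError at construction time,
# matching the traceback above.
norm_cls = partial(
    nn.LayerNorm if not use_rms_norm else RMSNorm,
    eps=layer_norm_epsilon,
)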

@axelmagn (Contributor, Author) commented Mar 4, 2024

No worries, and thanks for the speedy reply. Your guidance helped me get past the above error by installing the norm from flash-attn, but there seem to be more undocumented dependency issues:

root@d75213223120:/app# python3 test_script.py 
tokenizer_config.json: 100%|██████████████████████████████████████████████| 26.0/26.0 [00:00<00:00, 149kB/s]
config.json: 100%|█████████████████████████████████████████████████████████| 665/665 [00:00<00:00, 8.40MB/s]
vocab.json: 100%|██████████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 9.97MB/s]
merges.txt: 100%|████████████████████████████████████████████████████████| 456k/456k [00:00<00:00, 28.5MB/s]
tokenizer.json: 100%|██████████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 9.81MB/s]
config.json: 100%|█████████████████████████████████████████████████████| 2.86k/2.86k [00:00<00:00, 35.0MB/s]
No module named 'causal_attention_cuda'
Traceback (most recent call last):
  File "/app/test_script.py", line 6, in <module>
    model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)
  File "/app/based/models/gpt.py", line 468, in from_pretrained_hf
    model = cls(config, device=device, **kwargs)
  File "/app/based/models/gpt.py", line 741, in __init__
    self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
  File "/app/based/models/gpt.py", line 585, in __init__
    [
  File "/app/based/models/gpt.py", line 586, in <listcomp>
    create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
  File "/app/based/models/gpt.py", line 382, in create_block
    block = Block(
  File "/app/based/models/block.py", line 86, in __init__
    self.mixer = mixer_cls(dim)
  File "/app/based/models/mixers/slide_attention.py", line 357, in __init__
    if fused_bias_fc and FusedDense is None: raise ImportError("fused_dense is not installed")
ImportError: fused_dense is not installed

I'm a little baffled, since it seems like FusedDense is being imported from flash_attn here:

from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear

Are there additional subpackages within flash-attn that need to be installed?

For reference, here is my updated dockerfile:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && apt-get install -y \
    apt-utils \
    python3.10 \
    python3-pip \
    git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
RUN pip install \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 # due to observed causal-conv1d dependency

RUN pip install \
    jupyter==1.0.0 \
    hydra-core==1.3.2 \
    packaging==23.2 \
    ninja==1.11.1.1 

# RUN pip install 'git+https://github.com/Dao-AILab/flash-attention.git@6c9e60d' 
RUN pip install 'git+https://github.com/Dao-AILab/flash-attention.git@6c9e60d#subdirectory=csrc/layer_norm'

# install apex
RUN pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    'git+https://github.com/NVIDIA/apex@b496d85'

# install based
RUN mkdir -p /app
WORKDIR /app
COPY . .
RUN pip install .

CMD python3 test_script.py

@simran-arora (Collaborator)

That line you pointed out requires this to be installed: https://github.com/Dao-AILab/flash-attention/tree/main/csrc/fused_dense_lib

Would recommend cloning flash-attention and running python setup.py install within that directory.
An alternative workaround, without the install, is to set fused_bias_fc = False in the config.
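
A quick way to check whether the extension is actually present in an environment, using the same import path that slide_attention.py quotes above (the surrounding try/except is an illustrative sketch, not based's code):

# Probe for flash-attn's optional fused_dense CUDA extension (built from
# csrc/fused_dense_lib). If this import fails, based's fused_bias_fc path
# will report "fused_dense is not installed".
try:
    from flash_attn.ops.fused_dense import FusedDense
    print("fused_dense extension available")
except ImportError as err:
    print("fused_dense extension missing:", err)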

@axelmagn (Contributor, Author) commented Mar 5, 2024

After some tweaking, I think I've got it working. I ended up using the HazyResearch/flash-attention fork. For others trying via Docker, this is the Dockerfile I used:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

ARG TORCH_CUDA_ARCH_LIST="8.0+PTX"

RUN apt-get update && apt-get install -y \
    build-essential \
    apt-utils \
    python3.10 \
    python3-pip \
    git \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip
RUN pip install \
    torch==2.1.2 \
    torchvision==0.16.2 \
    torchaudio==2.1.2 \
    --index-url https://download.pytorch.org/whl/cu118 # due to observed causal-conv1d dependency

RUN pip install \
    jupyter==1.0.0 \
    hydra-core==1.3.2 \
    packaging==23.2 \
    ninja==1.11.1.1 


# install apex
RUN pip install -v \
    --disable-pip-version-check \
    --no-cache-dir \
    --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" \
    --config-settings "--build-option=--cuda_ext" \
    'git+https://github.com/NVIDIA/apex@b496d85'

RUN pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2' --no-build-isolation
RUN pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/fused_dense_lib'  --no-build-isolation
RUN pip install 'git+https://github.com/HazyResearch/flash-attention@v2.5.2#subdirectory=csrc/layer_norm' --no-build-isolation

# install based
RUN mkdir -p /app
WORKDIR /app
COPY . .
RUN pip install .

CMD python3 test_script.py

It requires the NVIDIA Docker toolkit to run, with the command:

docker run --rm --runtime=nvidia --gpus all based
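
The thread doesn't include test_script.py itself; a minimal version consistent with the tracebacks above (the prompt and max_length are illustrative, borrowed from the sample code further down in this issue) would be:

# Hypothetical reconstruction of test_script.py; the load lines match the
# tracebacks in this thread, the generate call is illustrative.
import torch
from transformers import AutoTokenizer
from based.models.gpt import GPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)

input_ids = tokenizer.encode("If I take one more step, it will be", return_tensors="pt").to("cuda")
output = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output[0]))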

@melisa-writer

Hi! I got a similar problem while running the sample code:

import torch
from transformers import AutoTokenizer
from based.models.gpt import GPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)

input = tokenizer.encode("If I take one more step, it will be", return_tensors="pt").to("cuda")
output = model.generate(input, max_length=20)
print(tokenizer.decode(output[0]))

Error:

Traceback (most recent call last):
  File "/home/melisarussak/based/inference_test.py", line 6, in <module>
    model = GPTLMHeadModel.from_pretrained_hf("hazyresearch/based-360m").to("cuda", dtype=torch.float16)
  File "/home/melisarussak/based/based/models/gpt.py", line 470, in from_pretrained_hf
    model = cls(config, device=device, **kwargs)
  File "/home/melisarussak/based/based/models/gpt.py", line 743, in __init__
    self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
  File "/home/melisarussak/based/based/models/gpt.py", line 587, in __init__
    [
  File "/home/melisarussak/based/based/models/gpt.py", line 588, in <listcomp>
    create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
  File "/home/melisarussak/based/based/models/gpt.py", line 373, in create_block
    norm_cls = partial(
TypeError: the first argument must be callable

so I used the Dockerfile given by @axelmagn and now I get:

No module named 'causal_attention_cuda'
Successfully imported the causal dot product kernel!
Could not import the FLA triton kernels...
Traceback (most recent call last):
  File "/app/inference_test.py", line 9, in <module>
    output = model.generate(input, max_length=20)
  File "/app/based/generation.py", line 573, in generate
    output = decode(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/app/based/generation.py", line 194, in decode
    scores.append(get_logits(sequences[-1], inference_params))
  File "/app/based/generation.py", line 155, in get_logits
    logits = model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/gpt.py", line 806, in forward
    hidden_states = self.transformer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/gpt.py", line 674, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/block.py", line 189, in forward
    hidden_states = self.mixer(hidden_states, position_ids=position_ids, decay=decay, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/based/models/mixers/linear_attention.py", line 127, in forward
    return self.recurrent_forward(hidden_states, kv_state, k_state, q, k, v)
  File "/app/based/models/mixers/linear_attention.py", line 195, in recurrent_forward
    kv_state += k[:, :, -1:] * v[:, :, -1:]
RuntimeError: The size of tensor a (16) must match the size of tensor b (273) at non-singleton dimension 4

Is this due to the code changes 2 days ago, or am I missing some steps?

@simran-arora (Collaborator)

Yes, that was due to the changes -- please try again and let me know if you run into issues.

@melisa-writer

it works now! 🎉 thank you!
