In [1]:
import torch

In [2]:
# torch.long is an alias: evaluating it echoes torch.int64.
torch.long

torch.int64

## basics

- references:
    - https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype


| **dtype**    | **等价形式**   | **Data type**           | **comment** |
|--------------|----------------|-------------------------|-------------|
| torch.half   | torch.float16  | 16-bit floating point   | binary16: 1 sign, 5 exponent, 10 significand bits |
|              | torch.bfloat16 | 16-bit floating point   | Brain Floating Point: 1 sign, 8 exponent, 7 significand bits |
| torch.float  | torch.float32  | 32-bit floating point   |             |
| torch.double | torch.float64  | 64-bit floating point   |             |
| torch.short  | torch.int16    | 16-bit integer (signed) |             |
| torch.int    | torch.int32    | 32-bit integer (signed) |             |
|              | torch.int8     | 8-bit integer (signed)  |             |
| torch.long   | torch.int64    | 64-bit integer (signed) |             |

## torch_dtype 与 load_in_8bit

In [2]:
from transformers import AutoModelForCausalLM

In [12]:
# Baseline: load GPT-2 without requesting any particular dtype.
model = AutoModelForCausalLM.from_pretrained('gpt2')

# Footprint scaled by 2**20 (MiB, presuming get_memory_footprint() counts bytes).
footprint_mib = model.get_memory_footprint() / 2**20
print(footprint_mib)

# One line per parameter: dtype, qualified name, shape.
for param_name, param in model.named_parameters():
    print(param.dtype, param_name, param.shape)

486.7002410888672
torch.float32 transformer.wte.weight torch.Size([50257, 768])
torch.float32 transformer.wpe.weight torch.Size([1024, 768])
torch.float32 transformer.h.0.ln_1.weight torch.Size([768])
torch.float32 transformer.h.0.ln_1.bias torch.Size([768])
torch.float32 transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
torch.float32 transformer.h.0.attn.c_attn.bias torch.Size([2304])
torch.float32 transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
torch.float32 transformer.h.0.attn.c_proj.bias torch.Size([768])
torch.float32 transformer.h.0.ln_2.weight torch.Size([768])
torch.float32 transformer.h.0.ln_2.bias torch.Size([768])
torch.float32 transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
torch.float32 transformer.h.0.mlp.c_fc.bias torch.Size([3072])
torch.float32 transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
torch.float32 transformer.h.0.mlp.c_proj.bias torch.Size([768])
torch.float32 transformer.h.1.ln_1.weight torch.Size([768])
torch.float32 tran

In [13]:
# Same checkpoint, this time requesting torch.float16 via torch_dtype.
model = AutoModelForCausalLM.from_pretrained(
    'gpt2',
    torch_dtype=torch.float16,
)

# Footprint scaled by 2**20 (MiB, presuming get_memory_footprint() counts bytes).
footprint_mib = model.get_memory_footprint() / 2**20
print(footprint_mib)

# One line per parameter: dtype, qualified name, shape.
for param_name, param in model.named_parameters():
    print(param.dtype, param_name, param.shape)

249.3501205444336
torch.float16 transformer.wte.weight torch.Size([50257, 768])
torch.float16 transformer.wpe.weight torch.Size([1024, 768])
torch.float16 transformer.h.0.ln_1.weight torch.Size([768])
torch.float16 transformer.h.0.ln_1.bias torch.Size([768])
torch.float16 transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
torch.float16 transformer.h.0.attn.c_attn.bias torch.Size([2304])
torch.float16 transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
torch.float16 transformer.h.0.attn.c_proj.bias torch.Size([768])
torch.float16 transformer.h.0.ln_2.weight torch.Size([768])
torch.float16 transformer.h.0.ln_2.bias torch.Size([768])
torch.float16 transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
torch.float16 transformer.h.0.mlp.c_fc.bias torch.Size([3072])
torch.float16 transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
torch.float16 transformer.h.0.mlp.c_proj.bias torch.Size([768])
torch.float16 transformer.h.1.ln_1.weight torch.Size([768])
torch.float16 tran

In [14]:
# Same checkpoint with torch_dtype=torch.float16 plus load_in_8bit=True.
# NOTE(review): newer transformers releases may expect
# quantization_config=BitsAndBytesConfig(load_in_8bit=True) instead of the
# bare keyword — confirm against the installed version.
model = AutoModelForCausalLM.from_pretrained(
    'gpt2',
    torch_dtype=torch.float16,
    load_in_8bit=True,
)

# Footprint scaled by 2**20 (MiB, presuming get_memory_footprint() counts bytes).
footprint_mib = model.get_memory_footprint() / 2**20
print(footprint_mib)

# One line per parameter: dtype, qualified name, shape.
for param_name, param in model.named_parameters():
    print(param.dtype, param_name, param.shape)

168.3501205444336
torch.float16 transformer.wte.weight torch.Size([50257, 768])
torch.float16 transformer.wpe.weight torch.Size([1024, 768])
torch.float16 transformer.h.0.ln_1.weight torch.Size([768])
torch.float16 transformer.h.0.ln_1.bias torch.Size([768])
torch.int8 transformer.h.0.attn.c_attn.weight torch.Size([2304, 768])
torch.float16 transformer.h.0.attn.c_attn.bias torch.Size([2304])
torch.int8 transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
torch.float16 transformer.h.0.attn.c_proj.bias torch.Size([768])
torch.float16 transformer.h.0.ln_2.weight torch.Size([768])
torch.float16 transformer.h.0.ln_2.bias torch.Size([768])
torch.int8 transformer.h.0.mlp.c_fc.weight torch.Size([3072, 768])
torch.float16 transformer.h.0.mlp.c_fc.bias torch.Size([3072])
torch.int8 transformer.h.0.mlp.c_proj.weight torch.Size([768, 3072])
torch.float16 transformer.h.0.mlp.c_proj.bias torch.Size([768])
torch.float16 transformer.h.1.ln_1.weight torch.Size([768])
torch.float16 transformer.h.1.

- transformer.wte.weight, transformer.wpe.weight: torch.float16
- h.0 - h.11
    - ln_1.weight, ln_1.bias, ln_2.weight, ln_2.bias: torch.float16
    - attn
        - c_attn.weight: torch.int8
            - bias: torch.float16
        - c_proj.weight: torch.int8
            - bias: torch.float16
    - mlp
        - c_fc.weight: torch.int8
            - bias: torch.float16
        - c_proj.weight: torch.int8
            - bias: torch.float16
- ln_f.weight, ln_f.bias: torch.float16