Closed
Labels: bug (Something isn't working)
Description
🐛 Describe the bug
When I try to load a pre-trained Pegasus model through transformers inside ColossalAI's ZeroInitContext, an error is raised while one of its embedding layers (PegasusSinusoidalPositionalEmbedding) is being initialized.
Here is the traceback:
Building prefix dict from /home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/jieba/dict.txt ...
[08/03/22 10:38:47] DEBUG colossalai - jieba - DEBUG: Building prefix dict from /home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/jieba/dict.txt ...
Dumping model to file cache /tmp/jieba.cache
[08/03/22 10:38:49] DEBUG colossalai - jieba - DEBUG: Dumping model to file cache /tmp/jieba.cache
Dump cache file failed.
Traceback (most recent call last):
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/jieba/__init__.py", line 100, in initialize
    replace_file(fpath, cache_file)
PermissionError: [Errno 1] Operation not permitted: '/tmp/tmpp3dplm6z' -> '/tmp/jieba.cache'
[08/03/22 10:38:50] ERROR colossalai - jieba - ERROR: Dump cache file failed.
Loading model cost 3.289815902709961 seconds.
DEBUG colossalai - jieba - DEBUG: Loading model cost 3.289815902709961 seconds.
Prefix dict has been built succesfully.
DEBUG colossalai - jieba - DEBUG: Prefix dict has been built succesfully.
[08/03/22 10:38:51] INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:1 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:2 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:3 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:3 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:4 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:4 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:5 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:5 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:6 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:6 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:7 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:7 with 1 nodes.
INFO colossalai - torch.distributed.distributed_c10d - INFO: Added key: store_based_barrier_key:8 to store for rank: 0
INFO colossalai - torch.distributed.distributed_c10d - INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:8 with 1 nodes.
INFO colossalai - colossalai - INFO: /home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/context/parallel_context.py:521 set_device
INFO colossalai - colossalai - INFO: process rank 0 is bound to device 0
[08/03/22 10:39:00] INFO colossalai - colossalai - INFO: /home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/context/parallel_context.py:557 set_seed
INFO colossalai - colossalai - INFO: initialized seed on rank 0, numpy: 1024, python random: 1024, ParallelMode.DATA: 1024, ParallelMode.TENSOR: 1024,the default parallel seed is ParallelMode.DATA.
INFO colossalai - colossalai - INFO: /home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/initialize.py:117 launch
INFO colossalai - colossalai - INFO: Distributed environment is initialized, data parallel size: 1, pipeline parallel size: 1, tensor parallel size: 1
Traceback (most recent call last):
  File "/tmp/pycharm_project_471/pegasus_train.py", line 104, in <module>
    model = PegasusForConditionalGeneration.from_pretrained(hf_model_name)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1843, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/utils/model/utils.py", line 52, in wrapper
    f(module, *args, **kwargs)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/transformers/models/pegasus/modeling_pegasus.py", line 1294, in __init__
    self.model = PegasusModel(config)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/utils/model/utils.py", line 52, in wrapper
    f(module, *args, **kwargs)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/transformers/models/pegasus/modeling_pegasus.py", line 1140, in __init__
    self.encoder = PegasusEncoder(config, self.shared)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/utils/model/utils.py", line 52, in wrapper
    f(module, *args, **kwargs)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/transformers/models/pegasus/modeling_pegasus.py", line 653, in __init__
    self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/colossalai/utils/model/utils.py", line 52, in wrapper
    f(module, *args, **kwargs)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/transformers/models/pegasus/modeling_pegasus.py", line 114, in __init__
    self.weight = self._init_weight(self.weight)
  File "/home/liuzhaofeng/anaconda3/lib/python3.9/site-packages/transformers/models/pegasus/modeling_pegasus.py", line 122, in _init_weight
    n_pos, dim = out.shape
ValueError: not enough values to unpack (expected 2, got 1)
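One plausible reading of this traceback, not yet confirmed against the ColossalAI source: every frame passes through the `wrapper` in `colossalai/utils/model/utils.py`, which appears to run a post-init hook after each module's `__init__`. `PegasusSinusoidalPositionalEmbedding` subclasses `nn.Embedding`, so its `super().__init__(...)` call returns (and the hook may already have sharded `self.weight`) before `_init_weight` reads `out.shape`. Assuming sharding with `shard_param=True` leaves the weight 1-D (a flat shard or an empty placeholder), the unpack fails exactly as shown. The pure-Python sketch below only mimics that ordering; `FakeParam`, `shard_after_init`, `Embedding` and `SinusoidalPositionalEmbedding` are hypothetical stand-ins, not ColossalAI or transformers classes.

```python
import functools


class FakeParam:
    """Hypothetical stand-in for a tensor parameter: it only tracks a shape."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


def shard_after_init(cls):
    """Hypothetical stand-in for ZeroInitContext's post-init hook.

    Wraps cls.__init__ so that, as soon as it returns, every multi-dimensional
    FakeParam on the instance is replaced by a flat 1-D "shard".
    """
    orig_init = cls.__init__

    @functools.wraps(orig_init)
    def wrapper(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        for name, value in list(vars(self).items()):
            if isinstance(value, FakeParam) and len(value.shape) > 1:
                numel = 1
                for d in value.shape:
                    numel *= d
                setattr(self, name, FakeParam(numel))  # flatten to 1-D

    cls.__init__ = wrapper
    return cls


@shard_after_init
class Embedding:  # plays the role of nn.Embedding
    def __init__(self, num_embeddings, dim):
        self.weight = FakeParam(num_embeddings, dim)


@shard_after_init
class SinusoidalPositionalEmbedding(Embedding):  # plays the role of the Pegasus class
    def __init__(self, num_positions, dim):
        # The base __init__ is wrapped too, so the weight is already flat here.
        super().__init__(num_positions, dim)
        # Mirrors `n_pos, dim = out.shape` in _init_weight.
        n_pos, emb_dim = self.weight.shape
        self.n_pos, self.emb_dim = n_pos, emb_dim


if __name__ == "__main__":
    try:
        SinusoidalPositionalEmbedding(512, 768)
    except ValueError as exc:
        print(exc)  # prints: not enough values to unpack (expected 2, got 1)
```

If this reading is right, the failure depends on the subclass reading its weight's shape after the base-class constructor has already been post-processed, rather than on the pretrained checkpoint itself.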
Although the log also shows an error when the jieba package is loaded, I don't think it causes the ValueError at the end: the same jieba permission error also appears when I don't use ColossalAI, and in that case it does not stop my model from fine-tuning.
The model I want to fine-tune is the Pegasus model that IDEA open-sourced on Hugging Face: https://huggingface.co/IDEA-CCNL/Randeng-Pegasus-238M-Summary-Chinese
Note that they use a custom tokenizer: https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/pegasus
Here is my script:
import colossalai
import torch as th
from torch.utils.data import Dataset, DataLoader
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from colossalai.zero.init_ctx import ZeroInitContext
from transformers import AdamW, get_scheduler, PegasusForConditionalGeneration
from tokenizers_pegasus import PegasusTokenizer

max_input_length = 512
max_target_length = 256
train_batch_size = 4
test_batch_size = 4
learning_rate = 2e-5
epoch_num = 8
beam_size = 4
no_repeat_ngram_size = 2

colossalai.launch_from_torch(config="./configs/colossalai_zero.py")

hf_model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese"
tokenizer = PegasusTokenizer.from_pretrained(hf_model_name)


class SummaryDataset(Dataset):
    def __init__(self, data_file):
        self.data = None
        self.data_file = data_file
        self.load_data()

    def load_data(self):
        data_list = []
        with open(self.data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                items = line.strip().split(',')
                if len(items) == 2:
                    data_list.append({
                        'title': items[0],
                        'content': items[1]
                    })
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collote_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt"
        )["input_ids"]
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = th.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1:] = -100
        batch_data['labels'] = labels
    return batch_data


if __name__ == '__main__':
    train_data = SummaryDataset("dataset/summary/train.csv")
    valid_data = SummaryDataset("dataset/summary/test.csv")
    train_dataloader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True, collate_fn=collote_fn)
    valid_dataloader = DataLoader(valid_data, batch_size=test_batch_size, shuffle=False, collate_fn=collote_fn)

    with ZeroInitContext(target_device=get_current_device(),
                         shard_strategy=gpc.config.zero.model_config.shard_strategy,
                         shard_param=True):
        model = PegasusForConditionalGeneration.from_pretrained(hf_model_name)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=gpc.config.NUM_EPOCHS * len(train_dataloader),
    )
    engine, train_dataloader, eval_dataloader, lr_scheduler = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        test_dataloader=valid_dataloader,
        lr_scheduler=lr_scheduler,
    )

    for epoch in range(gpc.config.NUM_EPOCHS):
        engine.train()
        for batch, batch_data in enumerate(train_dataloader, start=1):
            engine.zero_grad()
            batch_data = batch_data.to(get_current_device())
            outputs = model(**batch_data)
            loss = outputs.loss
            engine.backward(loss)
            engine.step()
            lr_scheduler.step()

        total_loss = 0
        engine.eval()
        for batch, batch_data in enumerate(eval_dataloader, start=1):
            engine.zero_grad()
            batch_data = batch_data.to(get_current_device())
            with th.no_grad():
                outputs = model(**batch_data)
            loss = outputs.loss
            total_loss += loss.item()
        print(f"epoch {epoch} loss {total_loss / len(eval_dataloader)}")

Environment
$ python --version
Python 3.9.12
$ pip list
colossalai 0.1.8+torch1.10cu11.3
tokenizers 0.12.1
torch 1.10.0+cu113
transformers 4.18.0
$ nvidia-smi
Wed Aug 3 13:09:42 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.129.06   Driver Version: 470.129.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:04:00.0 Off |                  N/A |
| 23%   28C    P8     8W / 250W |   2695MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:05:00.0 Off |                  N/A |
| 23%   26C    P8     9W / 250W |   3589MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA GeForce ...  Off  | 00000000:08:00.0 Off |                  N/A |
| 23%   26C    P8     8W / 250W |   3371MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA GeForce ...  Off  | 00000000:09:00.0 Off |                  N/A |
| 23%   26C    P8     9W / 250W |   1227MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA GeForce ...  Off  | 00000000:84:00.0 Off |                  N/A |
| 23%   27C    P8     8W / 250W |      8MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA GeForce ...  Off  | 00000000:85:00.0 Off |                  N/A |
| 23%   25C    P8     8W / 250W |      8MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA GeForce ...  Off  | 00000000:88:00.0 Off |                  N/A |
| 23%   27C    P8     9W / 250W |      8MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA GeForce ...  Off  | 00000000:89:00.0 Off |                  N/A |
| 23%   24C    P8     8W / 250W |      8MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   1326351      C   /home/venv/bin/python            2687MiB |
|    0   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    1   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    1   N/A  N/A   3715225      C   /home/venv/bin/python            3581MiB |
|    2   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    2   N/A  N/A   3715763      C   /home/venv/bin/python            3363MiB |
|    3   N/A  N/A   1174978      C   /home/venv/bin/python            1219MiB |
|    3   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    4   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    5   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    6   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
|    7   N/A  N/A   2146382      G   /usr/lib/xorg/Xorg                  4MiB |
+-----------------------------------------------------------------------------+
binmakeswell and 1SAA