<a href="https://colab.research.google.com/github/fireindark707/BERT_code_example/blob/main/Question_Answering_with_Pretrained_LM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment setup

In [1]:
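!nvidia-smi  # Check the GPU. Colab assigns one automatically; 15 GB or more of GPU memory is preferable. Otherwise, terminate the session and reconnect (restarting the runtime does not change the GPU).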

Tue Jul 20 16:53:41 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   54C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Mount Google Drive and set the working path

run_path = "/content/gdrive/MyDrive/NLP/Example/QA_LM" 

from google.colab import drive
import os
drive.mount('/content/gdrive')
os.chdir(run_path)  # change the working directory
os.getcwd()

import sys
sys.path.append(run_path)  # add the working directory to sys.path so that local modules can be imported

Mounted at /content/gdrive


In [3]:
!pip install datasets transformers accelerate

Collecting datasets
  Downloading datasets-1.9.0-py3-none-any.whl (262 kB)

In [4]:
import transformers

print(transformers.__version__) # version check, at least 4.8.1

4.8.2


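# Fine-tuning a pretrained language model on a QA dataset

This notebook explains how to fine-tune a pretrained model such as BERT for question answering (QA). Note that the model here does not generate its answer as free text; it answers by extracting a text span from the given context. The most commonly used QA dataset is currently [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/); an example from SQuAD is shown below: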

<img src="https://i.imgur.com/sOKTl1Z.jpg" width="500"/>
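
As a minimal sketch of what "extracting a span" means: in SQuAD-style records, `answer_start` is a character offset into `context`, so the reference answer can be recovered by slicing. The record below is a made-up toy example rather than a row from the dataset, and `extract_answer` is a hypothetical helper.

```python
# Toy SQuAD-style record (made-up data, for illustration only).
example = {
    "context": "BERT was introduced by researchers at Google in 2018.",
    "question": "Who introduced BERT?",
    "answers": {"text": ["researchers at Google"], "answer_start": [23]},
}


def extract_answer(record, index=0):
    """Recover a reference answer by slicing the context at answer_start."""
    start = record["answers"]["answer_start"][index]
    length = len(record["answers"]["text"][index])
    return record["context"][start:start + length]


print(extract_answer(example))
```

The model's job is to predict this span, which is why the training features later contain `start_positions` and `end_positions`.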

## Hyperparameter settings

In [5]:
squad_v2 = False # Switch between SQUAD v1 or 2
model_checkpoint = "bert-base-cased" # from huggingface library
batch_size = 12

In [6]:
class arguments:
  def __init__(self,batch_size,model_checkpoint):
    self.dataset_name=None
    self.dataset_config_name=None
    self.train_file=None
    self.preprocessing_num_workers=4
    self.do_predict=False
    self.validation_file=None
    self.test_file=None
    self.max_seq_length=384
    self.pad_to_max_length=True
    self.model_name_or_path=model_checkpoint
    self.config_name=None
    self.tokenizer_name=None
    self.use_slow_tokenizer=False
    self.per_device_train_batch_size=batch_size
    self.per_device_eval_batch_size=batch_size
    self.learning_rate=3e-5
    self.weight_decay=0.01
    self.num_train_epochs=2 
    self.max_train_steps=None
    self.gradient_accumulation_steps=1 # increase when GPU memory is limited; gradient_accumulation_steps*batch_size >= 16 is recommended
    self.lr_scheduler_type="linear"
    self.num_warmup_steps=0
    self.output_dir="./output/"
    self.seed=None
    self.doc_stride=128
    self.n_best_size=20
    self.null_score_diff_threshold=0.0
    self.version_2_with_negative=False # set to True when squad_v2 is True
    self.max_answer_length=30
    self.max_train_samples=3000 # set this for quick tests, otherwise training takes a long time; use None for a full training run
    self.max_eval_samples=None
    self.overwrite_cache=True
    self.max_predict_samples=None
    self.model_type=None

args = arguments(batch_size,model_checkpoint)

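## Downloading the dataset

This part uses `load_dataset` from [Datasets](https://github.com/huggingface/datasets) to prepare the dataset. Downloading a QA dataset directly from its official website also works. `load_metric` provides ready-made evaluation metrics. `load_dataset` can also load your own dataset in JSON or CSV format; see the [Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files).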


In [7]:
from datasets import load_dataset, load_metric

datasets = load_dataset("squad_v2" if squad_v2 else "squad")


Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.63 MiB, post-processed: Unknown size, total: 119.14 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/6b6c4172d0119c74515f44ea0b8262efe4897f2ddb6613e5e915840fdc309c16...



Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/6b6c4172d0119c74515f44ea0b8262efe4897f2ddb6613e5e915840fdc309c16. Subsequent calls will reuse this data.


In [8]:
import pprint
pp = pprint.PrettyPrinter(indent=4) # only used to pretty-print dicts

print("Structure of datasets:\n")
pp.pprint(datasets)
print("\n\nTraining example:\n")
pp.pprint(datasets['train'][0])
print("\n\nValidation example:\n")
pp.pprint(datasets['validation'][0])

Structure of datasets:

{   'train': Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
}),
    'validation': Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 10570
})}


Training example:

{   'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
    'context': 'Architecturally, the school has a Catholic character. Atop the '
               "Main Building's gold dome is a golden statue of the Virgin "
               'Mary. Immediately in front of the Main Building and facing it, '
               'is a copper statue of Christ with arms upraised with the '
               'legend "Venite Ad Me Omnes". Next to the Main Building is the '
               'Basilica of the Sacred Heart. Immediately behind the basilica '
               'is the Grotto, a Marian place of prayer and reflection. It is '
               'a replica of the grotto at Lourdes, France where the Virgin '
               'Mary reputedly app

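## Preprocessing the training data

The tokenizers provided by Transformers tokenize the text and convert it into the form the model reads. The BERT input format is shown in the figure below: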

<img src="https://scontent.ftpe8-2.fna.fbcdn.net/v/t39.30808-6/217391420_361550902039881_726614731465651999_n.jpg?_nc_cat=100&ccb=1-3&_nc_sid=730e14&_nc_ohc=nOmpzbPWfvUAX8VepZy&tn=vEbaqoRjIt21kGFt&_nc_ht=scontent.ftpe8-2.fna&oh=1b467bb9f79300db49c36760fccc2933&oe=60F8ABAF" width="800" />
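
The input layout can also be sketched in plain Python. This is only an illustration under simplifying assumptions: tokens are whitespace-split words instead of WordPiece sub-tokens, and `build_bert_input` is a hypothetical helper, not part of Transformers. It shows how the question and context are joined with `[CLS]`/`[SEP]`, how `token_type_ids` mark the two segments, and how `attention_mask` marks padding.

```python
def build_bert_input(question_tokens, context_tokens, max_len):
    """Lay out a question/context pair the way BERT expects it (toy version)."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # Segment 0 covers [CLS] + question + [SEP]; segment 1 covers context + [SEP].
    token_type_ids = [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1)
    attention_mask = [1] * len(tokens)

    num_pad = max_len - len(tokens)
    if num_pad < 0:
        raise ValueError("sequence is longer than max_len")
    # Padding positions get mask 0 so the model does not attend to them.
    tokens += ["[PAD]"] * num_pad
    token_type_ids += [0] * num_pad
    attention_mask += [0] * num_pad
    return tokens, token_type_ids, attention_mask


tokens, token_type_ids, attention_mask = build_bert_input(
    ["Who", "won", "?"], ["Denver", "won", "."], max_len=10
)
print(tokens)
print(token_type_ids)
print(attention_mask)
```

The real tokenizer additionally maps tokens to `input_ids`, which is the column seen in the preprocessed datasets.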

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)





In [10]:
# Preprocessing

from dataset_preprocess import QAdataset

SQuAD_dataset = QAdataset(datasets,tokenizer,args)

train_dataset = SQuAD_dataset.generate_train_dataset()
eval_dataset = SQuAD_dataset.generate_eval_dataset()

print("train_dataset：\n")
pp.pprint(train_dataset)

print("\n\neval_dataset：\n")
pp.pprint(eval_dataset)

  






train_dataset：

Dataset({
    features: ['attention_mask', 'end_positions', 'input_ids', 'start_positions', 'token_type_ids'],
    num_rows: 3000
})


eval_dataset：

Dataset({
    features: ['attention_mask', 'example_id', 'input_ids', 'offset_mapping', 'token_type_ids'],
    num_rows: 10822
})
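
The evaluation set has 10822 rows although the validation split has 10570 questions, which is consistent with long contexts being split into several overlapping features. The sketch below illustrates that kind of windowing, assuming `doc_stride` is the number of tokens shared by consecutive windows (as with the `stride` argument of Hugging Face tokenizers). `split_into_windows` is a hypothetical helper; the actual splitting happens inside `dataset_preprocess.QAdataset`, whose code is not shown here.

```python
def split_into_windows(num_context_tokens, max_context_len, doc_stride):
    """Return (start, end) token ranges covering the context with overlap."""
    if doc_stride >= max_context_len:
        raise ValueError("doc_stride must be smaller than max_context_len")
    step = max_context_len - doc_stride  # how far each new window advances
    windows = []
    start = 0
    while True:
        end = min(start + max_context_len, num_context_tokens)
        windows.append((start, end))
        if end == num_context_tokens:
            break
        start += step
    return windows


# A short context fits into one feature; a long one yields two overlapping ones.
print(split_into_windows(100, max_context_len=370, doc_stride=128))
print(split_into_windows(500, max_context_len=370, doc_stride=128))
```

Here `max_context_len=370` is an arbitrary illustrative budget: with `max_seq_length=384`, the room left for the context depends on the question length and the special tokens.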


## Building the model

In [27]:
from transformers import BertPreTrainedModel,BertModel
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import QuestionAnsweringModelOutput

class BertForQuestionAnswering(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels # default 2

        self.bert = BertModel(config, add_pooling_layer=False) # BertModel is the BERT encoder; the pooling layer is not needed for non-classification tasks
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) # linear layer that outputs the start and end logits, which act as prediction scores

        self.init_weights() 

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        ) # outputs is the encoder result: last_hidden_state, pooler_output, past_key_values, hidden_states, attentions, cross_attentions

        sequence_output = outputs[0] # last_hidden_state: (batch_size, seq_len, hidden_size), hidden_size = 768 for bert-base

        logits = self.qa_outputs(sequence_output) # (batch_size, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1) # split into start and end logits, each (batch_size, seq_len, 1)
        start_logits = start_logits.squeeze(-1).contiguous() # (batch_size, seq_len)
        end_logits = end_logits.squeeze(-1).contiguous() # (batch_size, seq_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU, gathering adds a dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # start/end positions outside the model input are clamped to ignored_index and skipped by the loss
            ignored_index = start_logits.size(1) # seq_len
            start_positions = start_positions.clamp(0, ignored_index) # clamp into the range [min, max]
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
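
The loss used by the class above is the mean of two cross-entropy terms, one over the start position and one over the end position. A dependency-free sketch for a single example, where each list holds one logit per token position (`cross_entropy` and `span_loss` are hypothetical helper names, and the logits are made-up numbers):

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target position (numerically stable)."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def span_loss(start_logits, end_logits, start_position, end_position):
    """Average of the start and end cross-entropy losses."""
    start_loss = cross_entropy(start_logits, start_position)
    end_loss = cross_entropy(end_logits, end_position)
    return (start_loss + end_loss) / 2


# Uniform logits over 4 positions give a loss of log(4) for any target.
print(span_loss([0.0] * 4, [0.0] * 4, start_position=1, end_position=2))
# Confident, correct logits give a loss close to 0.
print(span_loss([0.0, 9.0, 0.0, 0.0], [0.0, 0.0, 9.0, 0.0], 1, 2))
```

In the model itself, `CrossEntropyLoss` does the same computation over the whole batch and skips targets equal to `ignored_index`.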


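## Fine-tuning the model

With the data ready, you can fine-tune either with the `AutoModelForQuestionAnswering` class that Transformers provides or with a model class you define yourself.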

In [20]:
from transformers import default_data_collator
from transformers import AdamW
from transformers import get_scheduler
from accelerate import Accelerator
from torch.utils.data.dataloader import DataLoader
import logging
import math
import os
import numpy as np
import torch
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)

In [21]:
model = BertForQuestionAnswering.from_pretrained(args.model_name_or_path) # load the model

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and a

In [22]:
# dataloader

data_collator = default_data_collator 

train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )

eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
eval_dataloader = DataLoader(
        eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
    )

In [23]:
# Optimizer
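# Split the parameters into two groups: one with weight decay and one without. Weight decay penalizes large weights to reduce overfitting; biases and LayerNorm weights are excluded from it.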
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
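
To see which parameters land in which group, the same substring filter can be applied to a few parameter names. This sketch uses a hand-written list of typical BERT-style names for illustration, not the output of `model.named_parameters()`, and `split_by_weight_decay` is a hypothetical helper.

```python
no_decay = ["bias", "LayerNorm.weight"]


def split_by_weight_decay(param_names, no_decay):
    """Partition names with the same substring test used for the optimizer groups."""
    with_decay = [n for n in param_names if not any(nd in n for nd in no_decay)]
    without_decay = [n for n in param_names if any(nd in n for nd in no_decay)]
    return with_decay, without_decay


# Illustrative parameter names (assumed, not read from a real model).
names = [
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.0.attention.self.query.bias",
    "bert.encoder.layer.0.attention.output.LayerNorm.weight",
    "bert.encoder.layer.0.attention.output.LayerNorm.bias",
    "qa_outputs.weight",
    "qa_outputs.bias",
]

with_decay, without_decay = split_by_weight_decay(names, no_decay)
print("weight decay applied:", with_decay)
print("weight decay skipped:", without_decay)
```

Because the test is a substring match, `LayerNorm.bias` is caught by the `"bias"` entry even though only `LayerNorm.weight` is listed explicitly.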

In [24]:
# `accelerate` handles device placement and multi-process (distributed) setup automatically, replacing manual .cuda() or .to(device) calls; see https://pypi.org/project/accelerate/

accelerator = Accelerator()
print(accelerator.state)

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
) # in the training loop, call accelerator.backward(loss) instead of loss.backward()

Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Use FP16 precision: False



In [25]:
# Set up the learning-rate schedule from the step count; args.lr_scheduler_type can be changed to another schedule such as cosine
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) # optimizer update steps per epoch

if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type, # learning-rate schedule
    optimizer=optimizer, 
    num_warmup_steps=args.num_warmup_steps, # warmup is often set to about 1/10 of the total steps; the default here is 0
    num_training_steps=args.max_train_steps,
)
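
The step arithmetic and the shape of the `"linear"` schedule can be sketched without any library. This is an approximation of linear warmup followed by linear decay to zero, written from the usual definition rather than copied from the Transformers source; `linear_lr` is a hypothetical helper, and the numbers mirror the settings in this notebook (3000 training samples, batch size 12, 2 epochs, no warmup).

```python
import math


def linear_lr(step, base_lr, num_warmup_steps, num_training_steps):
    """Learning rate at a given optimizer step: linear warmup, then linear decay."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    decay_span = max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining / decay_span)


batches_per_epoch = math.ceil(3000 / 12)            # 250 batches
updates_per_epoch = math.ceil(batches_per_epoch / 1)  # gradient_accumulation_steps = 1
max_train_steps = 2 * updates_per_epoch               # 500 optimizer steps

for step in (0, 250, 500):
    print(step, linear_lr(step, 3e-5, 0, max_train_steps))
```

With no warmup, the rate starts at the full 3e-5, is halved midway, and reaches zero at the final step.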

In [26]:
# Training
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

print("***** Running training *****")
print(f"  Num examples = {len(train_dataset)}")
print(f"  Num Epochs = {args.num_train_epochs}")
print(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f"  Total optimization steps = {args.max_train_steps}")

# Progress bar; `disable` makes only one process show the bar on multi-GPU runs. It makes no difference in Colab.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)

completed_steps = 0
for epoch in range(args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        loss = loss / args.gradient_accumulation_steps
        accelerator.backward(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:  # update after every gradient_accumulation_steps batches, and on the last batch
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            completed_steps += 1

        if completed_steps >= args.max_train_steps:
            break

***** Running training *****
  Num examples = 3000
  Num Epochs = 2
  Instantaneous batch size per device = 12
  Total train batch size (w. parallel, distributed & accumulation) = 12
  Gradient Accumulation steps = 1
  Total optimization steps = 500
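
When `gradient_accumulation_steps` is larger than 1, the number of optimizer updates per epoch should equal `ceil(num_batches / gradient_accumulation_steps)`. One common way to get that count is to update when `(step + 1)` is a multiple of the accumulation steps, plus once on the final batch. The sketch below only counts updates and trains nothing; `count_updates` is a hypothetical helper.

```python
import math


def count_updates(num_batches, accumulation_steps):
    """Count optimizer updates in one epoch under gradient accumulation."""
    updates = 0
    for step in range(num_batches):
        is_boundary = (step + 1) % accumulation_steps == 0
        is_last_batch = step == num_batches - 1
        if is_boundary or is_last_batch:
            updates += 1
    return updates


for accumulation_steps in (1, 4):
    print(
        accumulation_steps,
        count_updates(250, accumulation_steps),
        math.ceil(250 / accumulation_steps),
    )
```

Keeping this count equal to `num_update_steps_per_epoch` matters because the learning-rate scheduler is sized from that value.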



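# Evaluating on the validation set

### Evaluation metrics

SQuAD uses exact match (EM) and the F1 score. Both are computed per question-answer pair. When a question has several reference answers, the score is the maximum over all of them. The model's EM and F1 are the averages of these per-example scores.

#### Exact Match

For each question-answer pair, EM = 1 if the predicted text exactly matches (one of) the reference answers, and EM = 0 otherwise. It is a strict all-or-nothing metric: being off by a single token scores 0.

#### F1

F1 is computed from the tokens shared between the prediction and the reference answer. Precision is the number of shared tokens divided by the number of tokens in the prediction, and recall is the number of shared tokens divided by the number of tokens in the reference answer. F1 is their harmonic mean, 2 · Precision · Recall / (Precision + Recall).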
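
A simplified sketch of both metrics for one prediction follows. It assumes a reduced normalization (lowercasing and whitespace tokenization only); the official SQuAD evaluation also strips punctuation and articles, so scores from this sketch can differ from those of `load_metric`. All function names here are hypothetical helpers.

```python
from collections import Counter


def normalize(text):
    """Simplified normalization: lowercase and split on whitespace."""
    return text.lower().split()


def exact_match(prediction, truth):
    return int(normalize(prediction) == normalize(truth))


def f1_score(prediction, truth):
    pred_tokens = normalize(prediction)
    true_tokens = normalize(truth)
    # Multiset intersection counts each shared token at most as often as it appears in both.
    num_shared = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if num_shared == 0:
        return 0.0
    precision = num_shared / len(pred_tokens)
    recall = num_shared / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_answers(metric_fn, prediction, truths):
    """Score against every reference answer and keep the maximum."""
    return max(metric_fn(prediction, truth) for truth in truths)


truths = ["Denver Broncos", "the Broncos"]
print(best_over_answers(exact_match, "Denver Broncos", truths))
print(best_over_answers(f1_score, "the Denver Broncos team", truths))
```

In the second call, the prediction shares two of its four tokens with each reference, so precision is 0.5 and recall is 1, giving F1 = 2/3 while EM would be 0.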

In [143]:
# Evaluation
from utils_qa import * #post_processing_function, create_and_fill_np_array

squad_ver = "squad_v2" if args.version_2_with_negative else "squad"

print("***** Running Evaluation *****")
print(f"  Num examples = {len(eval_dataset)}")
print(f"  Batch size = {args.per_device_eval_batch_size}")
print(f"  squad_version = {squad_ver}")

metric = load_metric(squad_ver)

progress_bar = tqdm(range(len(eval_dataloader)), disable=not accelerator.is_local_main_process)

all_start_logits = []
all_end_logits = []
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        if not args.pad_to_max_length:  # predictions must be padded before gather can be used (an accelerate requirement)
            start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100) # pad tensors across processes to the max length
            end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)

        all_start_logits.append(accelerator.gather(start_logits).cpu().numpy()) # collects the predictions from all processes in distributed runs
        all_end_logits.append(accelerator.gather(end_logits).cpu().numpy())
    progress_bar.update(1)

max_len = max([x.shape[1] for x in all_start_logits])  # maximum sequence length

# concatenate array
start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)

# no longer needed; free the memory
del all_start_logits
del all_end_logits

outputs_numpy = (start_logits_concat, end_logits_concat) # model predictions
prediction = post_processing_function(datasets['validation'], eval_dataset, outputs_numpy) # for how post-processing works, see https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
print(f"Evaluation metrics: {eval_metric}")

***** Running Evaluation *****
  Num examples = 10822
  Batch size = 12
  squad_version = squad




Evaluation metrics: {'exact_match': 53.7275307473983, 'f1': 65.71941994620163}


In [144]:
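# Compare the naive prediction (argmax of start_logits and argmax of end_logits from outputs_numpy)
# with the prediction returned by post_processing_function.
# Note: outputs_numpy is indexed by feature and prediction.predictions by example; the two line up here only while each example maps to a single feature.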

for qs_id in range(10):

  start_idx = outputs_numpy[0][qs_id].argmax()
  end_idx = outputs_numpy[1][qs_id].argmax()
  print(str(start_idx)+" "+str(end_idx))
  print(' '.join(tokenizer.convert_ids_to_tokens(eval_dataset['input_ids'][qs_id][start_idx:end_idx+1])))
  print(prediction.predictions[qs_id]['prediction_text'])
  print('\n')

46 47
Denver Broncos
Denver Broncos


57 58
Carolina Panthers
Carolina Panthers


89 81

Santa Clara


43 44
Denver Broncos
Denver Broncos


118 118
gold
gold


11 13
Super Bowl 50
Super Bowl 50


72 75
February 7 , 2016
February 7, 2016


35 37
American Football Conference
American Football Conference


163 101

an American football game


8 36
Super Bowl 50 was an American football game to determine the champion of the National Football League ( NFL ) for the 2015 season . The American Football Conference
Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference
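
Two of the rows above (`89 81` and `163 101`) show why the naive argmax can fail: the best end index comes before the best start index, so the decoded span is empty. A common remedy is to score pairs of candidate positions and keep the best valid pair. The sketch below illustrates that idea with made-up logits; `best_span` is a hypothetical helper and not the code of `post_processing_function` in `utils_qa`, which is not shown here and typically also maps tokens back to character offsets.

```python
def best_span(start_logits, end_logits, n_best_size=20, max_answer_length=30):
    """Return (score, start, end) for the highest-scoring valid span, or None."""
    def top_indices(logits):
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        return order[:n_best_size]

    best = None
    for start in top_indices(start_logits):
        for end in top_indices(end_logits):
            # Skip spans that end before they start or that are too long.
            if end < start or end - start + 1 > max_answer_length:
                continue
            score = start_logits[start] + end_logits[end]
            if best is None or score > best[0]:
                best = (score, start, end)
    return best


# Made-up logits where the argmax end (index 1) precedes the argmax start (index 2).
start_logits = [0.1, 0.2, 5.0, 0.3, 3.0]
end_logits = [0.0, 4.0, 0.1, 3.5, 0.2]
print(best_span(start_logits, end_logits))
```

The defaults mirror `n_best_size=20` and `max_answer_length=30` from the hyperparameter cell.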




In [145]:
# Inspect manually: compare the predictions with the reference answers

for i in range(10):
  print("====Question====")
  pp.pprint(datasets['validation'][i])
  print("====prediction====")
  pp.pprint(prediction.predictions[i])
  print("\n")

====Question====
{   'answers': {   'answer_start': [177, 177, 177],
                   'text': [   'Denver Broncos',
                               'Denver Broncos',
                               'Denver Broncos']},
    'context': 'Super Bowl 50 was an American football game to determine the '
               'champion of the National Football League (NFL) for the 2015 '
               'season. The American Football Conference (AFC) champion Denver '
               'Broncos defeated the National Football Conference (NFC) '
               'champion Carolina Panthers 24–10 to earn their third Super '
               "Bowl title. The game was played on February 7, 2016, at Levi's "
               'Stadium in the San Francisco Bay Area at Santa Clara, '
               'California. As this was the 50th Super Bowl, the league '
               'emphasized the "golden anniversary" with various gold-themed '
               'initiatives, as well as temporarily suspending the tradition '
        

In [146]:
# Save the fine-tuned model

if args.output_dir is not None:
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)