12 changes: 12 additions & 0 deletions workflows/chatbot/fine_tuning/README.md
@@ -189,6 +189,9 @@ python finetune_clm.py \
--no_cuda \
```

**Note:** Set `--do_lm_eval` to evaluate the model with the `truthfulqa_mc` metric. You can also set `--lm_eval_tasks` to evaluate additional tasks supported in [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).


- Use the command below to fine-tune the chatbot on the [HuggingFaceH4/oasst1_en](https://huggingface.co/datasets/HuggingFaceH4/oasst1_en) dataset.

```bash
@@ -244,6 +247,8 @@ python finetune_clm.py \
# the script also supports other models, such as MPT.
```

**Note:** Use the `rouge` metric to evaluate the model on the summarization task.

- Use the command below for code tuning with `meta-llama/Llama-2-7b-hf` on [theblackcat102/evol-codealpaca-v1](https://huggingface.co/datasets/theblackcat102/evol-codealpaca-v1).

```bash
@@ -577,3 +582,10 @@ You can also set `--peft` to switch the PEFT method (P-tuning, Prefix tuning, LoRA, etc.);
see https://github.com/huggingface/peft. Note that for MPT, only LoRA is supported.

Add the option **"--use_fast_tokenizer False"** when using the latest transformers if you hit a failure with the Llama fast tokenizer. The `tokenizer_class` in `tokenizer_config.json` should be changed from `LLaMATokenizer` to `LlamaTokenizer`.

# Evaluation

- For `task=completion`/`chat`, set `--do_lm_eval` to evaluate the model with the `truthfulqa_mc` metric; you can also set `--lm_eval_tasks` to evaluate additional tasks supported in [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).
- For `task=summarization`, the `rouge` metric is used.
- For a custom evaluation function, refer to `instruction_tuning_pipeline/eval_utils.py` and call it at the end of training, as sketched below.
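A minimal sketch of that end-of-training call is shown below. Only `compute_rouge_metric` and its signature come from `instruction_tuning_pipeline/eval_utils.py`; the `trainer`, `tokenizer`, `eval_dataset`, and `training_args` variables and the `gen_kwargs` values are assumptions standing in for whatever your fine-tuning script defines.

```python
# Hypothetical sketch: run the custom ROUGE evaluation after training finishes.
# Only compute_rouge_metric comes from instruction_tuning_pipeline/eval_utils.py;
# the surrounding objects are assumed to be defined by the fine-tuning script.
from eval_utils import compute_rouge_metric

trainer.train()
trainer.save_model()

gen_kwargs = {"max_new_tokens": 128, "num_beams": 4}  # illustrative generation settings
rouge_scores = compute_rouge_metric(
    trainer.model, tokenizer, eval_dataset, training_args, gen_kwargs
)
print("ROUGE scores:", rouge_scores)
```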

@@ -159,15 +159,15 @@ def preprocess_function(examples):

input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id]
if not finetune_args.train_on_inputs:
labels = [-100] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id]
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id]
else:
labels = prompt_ids + resp_ids + [tokenizer.eos_token_id]

# padding
input_len = len(input_ids)
pad_len = data_args.max_seq_length - input_len
input_ids = input_ids + [tokenizer.eos_token_id] * pad_len
labels = labels + [-100] * pad_len
labels = labels + [IGNORE_INDEX] * pad_len
attention_mask = [1] * input_len + [0] * pad_len

assert len(input_ids) == data_args.max_seq_length
@@ -197,27 +197,39 @@ def preprocess_function(examples):
examples["input_ids"] = []
examples["labels"] = []
examples["attention_mask"] = []
examples["decoder_input_ids"] = []
examples["decoder_attention_mask"] = []
examples["decoder_labels"] = []

for article, highlight in zip(articles, highlights):
max_input = data_args.max_source_length - len(template_ids)

article_tokens = tokenizer.tokenize(article)[:max_input]
prompt_ids = tokenizer.convert_tokens_to_ids(article_tokens) + template_ids

# for inference
decoder_input_ids = copy.deepcopy(prompt_ids)

max_resp = data_args.max_seq_length - len(prompt_ids) - 1
resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight))[:max_resp]
resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight))[:max_resp] + \
[tokenizer.eos_token_id]

input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id]
# for inference
max_decoder_labels_len = data_args.max_seq_length - data_args.max_source_length - 1
decoder_labels = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight)
)[:max_decoder_labels_len] + [tokenizer.eos_token_id]

input_ids = prompt_ids + resp_ids
if not finetune_args.train_on_inputs:
labels = [-100] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id]
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids
else:
labels = prompt_ids + resp_ids + [tokenizer.eos_token_id]
labels = prompt_ids + resp_ids

# padding
input_len = len(input_ids)
pad_len = data_args.max_seq_length - input_len
input_ids = input_ids + [tokenizer.eos_token_id] * pad_len
labels = labels + [-100] * pad_len
labels = labels + [IGNORE_INDEX] * pad_len
attention_mask = [1] * input_len + [0] * pad_len

assert len(input_ids) == data_args.max_seq_length
@@ -228,6 +240,20 @@ def preprocess_function(examples):
examples["labels"].append(labels)
examples["attention_mask"].append(attention_mask)

# left padding for inference
input_len = len(decoder_input_ids)
pad_len = data_args.max_source_length - input_len
decoder_input_ids = [tokenizer.eos_token_id] * pad_len + decoder_input_ids
decoder_attention_mask = [0] * pad_len + [1] * input_len

input_len = len(decoder_labels)
pad_len = data_args.max_seq_length - data_args.max_source_length - input_len
decoder_labels = decoder_labels + [IGNORE_INDEX] * pad_len
examples["decoder_input_ids"].append(decoder_input_ids)
examples["decoder_labels"].append(decoder_labels)
examples["decoder_attention_mask"].append(decoder_attention_mask)


return examples

return preprocess_function
@@ -238,8 +264,18 @@ def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args):
if finetune_args.task == "chat":
preprocess = ChatDataPreprocess(tokenizer.eos_token)
new_datasets = datasets.DatasetDict()
prompts = preprocess.create_data(raw_datasets["train_ift"])
new_datasets["train"] = datasets.Dataset.from_dict(prompts)
for key in raw_datasets:
prompts = preprocess.create_data(raw_datasets[key])

# normalize irregular split names to standard keys
if "train" in key:
new_key = "train"
if "val" in key:
new_key = "validation"
if "test" in key:
new_key = "test"

new_datasets[new_key] = datasets.Dataset.from_dict(prompts)

preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args)

@@ -248,7 +284,7 @@ def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args):
elif finetune_args.task == "summarization":
preprocess = SummarizationDataPreprocess()
preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args)
return raw_datasets, preprocess_fn

elif finetune_args.task == "completion":
# default use alpaca template
preprocess = CompletionDataPreprocess()
@@ -265,6 +301,7 @@ def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args):

preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args)

return raw_datasets, preprocess_fn
else:
raise NotImplementedError('finetune task data preprocessing is not supported currently.')

return raw_datasets, preprocess_fn
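For context, a small usage sketch of the returned `preprocess_fn` follows; it is not part of this diff, and the `tokenized_datasets` name and the `map` arguments are illustrative assumptions about how the fine-tuning script applies the tokenization.

```python
# Illustrative sketch (not part of this diff): apply the preprocessing function
# returned by preprocess_dataset to every split with datasets' batched map.
raw_datasets, preprocess_fn = preprocess_dataset(
    raw_datasets, tokenizer, data_args, finetune_args
)
tokenized_datasets = raw_datasets.map(
    preprocess_fn,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,  # drop the raw text columns
)
```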
@@ -0,0 +1,79 @@
import evaluate
import nltk
import numpy as np
import torch
from torch.utils.data import DataLoader

@torch.no_grad()
def compute_rouge_metric(model, tokenizer, eval_dataset, training_args, gen_kwargs):
model.eval()
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# Metric
metric = evaluate.load("rouge")

def collate_fn(batch):
input_ids = [torch.tensor(ins["decoder_input_ids"]) for ins in batch]
labels = [torch.tensor(ins["decoder_labels"]) for ins in batch]
attention_mask = [torch.tensor(ins["decoder_attention_mask"]) for ins in batch]
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)

# TODO: support batch_size >1
eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn,
batch_size=1)


def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]

# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

return preds, labels

for step, batch in enumerate(eval_dataloader):
preds = model.generate(
input_ids=batch["input_ids"].to(model.device),
attention_mask=batch["attention_mask"].to(model.device),
**gen_kwargs,
)
labels = batch["labels"]
labels = labels.cpu().numpy()

preds = preds.cpu().numpy()

# Replace -100s used for padding as we can't decode them
preds = np.where(preds != -100, preds, tokenizer.pad_token_id).tolist()
# keep only the generated continuation (strip the prompt tokens)
preds = [pred[batch["input_ids"].shape[1]:] for pred in preds]

decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

labels = np.where(labels != -100, labels, tokenizer.pad_token_id).tolist()
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

# Some simple post-processing
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

metric.add_batch(
predictions=decoded_preds,
references=decoded_labels,
)


result = metric.compute(use_stemmer=True)
result = {k: round(v * 100, 4) for k, v in result.items()}
return result

