[chatgpt] add supervised learning fine-tune code (#3183)
* [chatgpt] add supervised fine-tune code
* [chatgpt] delete unused code and modified comment code
* [chatgpt] use pytorch distributed sampler instead

Co-authored-by: zhangpengpeng <zhangpengpeng@joyy.com>
Showing 14 changed files with 428 additions and 6 deletions.
@@ -1,4 +1,5 @@
 from .reward_dataset import RmStaticDataset, HhRlhfDataset
 from .utils import is_rank_0
+from .sft_dataset import SFTDataset

-__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
+__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
@@ -0,0 +1,40 @@
from typing import Callable
import random
from torch.utils.data import Dataset
import torch.distributed as dist
from tqdm import tqdm
import torch

from .utils import is_rank_0


class SFTDataset(Dataset):
    """
    Dataset for sft model
    Args:
        dataset: dataset for supervised model
        tokenizer: tokenizer for supervised model
        max_length: max length of input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
        super().__init__()
        self.prompts = []

        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
            prompt_token = tokenizer(prompt,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")

            self.prompts.append(prompt_token)

    def __len__(self):
        length = len(self.prompts)
        return length

    def __getitem__(self, idx):
        return self.prompts[idx]
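The dataset above stores one tokenized prompt+completion string per sample. A minimal usage sketch, not part of the commit (the GPT-2 tokenizer and the toy prompt/completion pairs are illustrative assumptions): each item is a tokenizer BatchEncoding whose tensors carry an extra leading dimension from return_tensors="pt", which is why the trainer later squeezes dimension 1.

from transformers import AutoTokenizer
from chatgpt.dataset import SFTDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # assumed tokenizer; any HF tokenizer with a pad token works
tokenizer.pad_token = tokenizer.eos_token            # GPT-2 has no pad token by default

raw_data = [
    {"prompt": "Translate to French: cheese =>", "completion": " fromage"},
    {"prompt": "2 + 2 =", "completion": " 4"},
]

dataset = SFTDataset(raw_data, tokenizer, max_length=512)
sample = dataset[0]
print(sample["input_ids"].shape)       # torch.Size([1, 512])
print(sample["attention_mask"].shape)  # torch.Size([1, 512])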
@@ -1,5 +1,6 @@
 from .actor import Actor
 from .critic import Critic
 from .reward_model import RewardModel
+from .lm import LM

-__all__ = ['Actor', 'Critic', 'RewardModel']
+__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
@@ -0,0 +1,33 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from .actor import Actor


class LM(Actor):
    """
    Language model base class.
    Args:
        model (nn.Module): Language Model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)

    def forward(self,
                sequences: torch.LongTensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns output log probs
        """
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs
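LM wraps a causal language model and turns its logits into per-token log-probabilities. A quick shape check as a sketch, not part of the commit (the chatgpt.models.base import path is inferred from the package __init__ above; the GPT-2 stand-in model is an illustrative assumption):

import torch
from transformers import GPT2Config, GPT2LMHeadModel
from chatgpt.models.base import LM   # inferred path; LM is re-exported by the __init__ above

lm = LM(GPT2LMHeadModel(GPT2Config()))                    # any HF causal LM exposing 'logits' should work
input_ids = torch.randint(0, 50257, (2, 16))              # batch of 2 sequences, 16 tokens each
log_probs = lm(input_ids, attention_mask=torch.ones_like(input_ids))
print(log_probs.shape)                                    # torch.Size([2, 16, 50257])
print(torch.exp(log_probs).sum(-1)[0, 0])                 # ~1.0: each position is a distribution over the vocabulary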
@@ -1,5 +1,6 @@
 from .bloom_actor import BLOOMActor
 from .bloom_critic import BLOOMCritic
 from .bloom_rm import BLOOMRM
+from .bloom_lm import BLOOMLM

-__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
+__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
@@ -0,0 +1,36 @@
from typing import Optional

import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import LM


class BLOOMLM(LM):
    """
    BLOOM language model.
    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
@@ -1,5 +1,6 @@
 from .gpt_actor import GPTActor
 from .gpt_critic import GPTCritic
 from .gpt_rm import GPTRM
+from .gpt_lm import GPTLM

-__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
+__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
@@ -0,0 +1,36 @@
from typing import Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from ..base import LM


class GPTLM(LM):
    """
    GPT language model.
    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA layer.
        lora_train_bias (str): Bias training strategy for the LoRA layer.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2LMHeadModel(config)
        else:
            model = GPT2LMHeadModel(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
@@ -1,5 +1,6 @@
 from .opt_actor import OPTActor
 from .opt_critic import OPTCritic
 from .opt_rm import OPTRM
+from .opt_lm import OPTLM

-__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
+__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
@@ -0,0 +1,36 @@
from typing import Optional

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from ..base import LM


class OPTLM(LM):
    """
    OPT language model.
    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = OPTForCausalLM(config)
        else:
            model = OPTForCausalLM(OPTConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
@@ -1,5 +1,6 @@
 from .base import Trainer
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
+from .sft import SFTTrainer

-__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']
+__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
@@ -0,0 +1,101 @@
from abc import ABC
from typing import Optional
import loralib as lora
import torch
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist
from .strategies import Strategy
from .utils import is_rank_0
from colossalai.logging import get_dist_logger


class SFTTrainer(ABC):
    """
    Trainer for supervised fine-tuning (SFT).
    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataset (SFTDataset): the dataset to use for training
        eval_dataset (SFTDataset): the dataset to use for evaluation
        sampler (DistributedSampler, optional): sampler for the training dataloader
        batch_size (int, defaults to 1): the batch size while training
        max_epochs (int, defaults to 2): the number of epochs to train
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataset: SFTDataset,
        eval_dataset: SFTDataset,
        sampler: Optional[DistributedSampler] = None,
        batch_size: int = 1,
        max_epochs: int = 2,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.sampler = sampler

        self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
                                           sampler=sampler, batch_size=batch_size)
        self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)

        self.model = strategy.setup_model(model)
        if "DDP" in str(self.strategy):
            self.model = self.model.module
        self.loss_fn = GPTLMLoss()
        self.optimizer = strategy.setup_optimizer(optim, self.model)

    def fit(self, logger, use_lora, log_interval=10):
        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
        for epoch in range(self.epochs):
            if isinstance(self.sampler, DistributedSampler):
                self.sampler.set_epoch(epoch)
            # train
            self.model.train()
            for batch_id, batch in enumerate(self.train_dataloader):
                prompt_ids = batch["input_ids"]
                p_mask = batch["attention_mask"]
                # drop the extra dim added by return_tensors="pt" in SFTDataset
                prompt_ids = prompt_ids.squeeze(1).cuda()
                p_mask = p_mask.squeeze(1).cuda()
                prompt_logits = self.model(prompt_ids, attention_mask=p_mask)

                loss = self.loss_fn(prompt_logits, prompt_ids)
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                if batch_id % log_interval == 0:
                    logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

            # eval
            self.model.eval()
            with torch.no_grad():
                loss_sum = 0
                num_seen = 0
                for batch in self.eval_dataloader:
                    prompt_ids = batch["input_ids"]
                    p_mask = batch["attention_mask"]
                    prompt_ids = prompt_ids.squeeze(1).cuda()
                    p_mask = p_mask.squeeze(1).cuda()

                    prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
                    loss = self.loss_fn(prompt_logits, prompt_ids)
                    loss_sum += loss.item()
                    num_seen += prompt_ids.size(0)

                loss_mean = loss_sum / num_seen
                if dist.get_rank() == 0:
                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
            epoch_bar.update()
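Putting the pieces together, a minimal end-to-end sketch, not part of the commit. Assumptions: the NaiveStrategy class name under chatgpt.trainer.strategies, the chatgpt.models.gpt import path, and the toy data are all illustrative; torch.distributed must already be initialized (for example via torchrun, one process per GPU), because fit() logs per-rank losses with dist.get_rank().

import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer

from chatgpt.dataset import SFTDataset
from chatgpt.models.gpt import GPTLM              # inferred package path
from chatgpt.trainer import SFTTrainer
from chatgpt.trainer.strategies import NaiveStrategy   # assumed strategy class name
from colossalai.logging import get_dist_logger

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Toy data; in practice this would come from a file of prompt/completion pairs.
train_data = [{"prompt": "Q: 1 + 1 = ?\nA:", "completion": " 2"}] * 32
eval_data = [{"prompt": "Q: 2 + 2 = ?\nA:", "completion": " 4"}] * 8
train_dataset = SFTDataset(train_data, tokenizer, max_length=128)
eval_dataset = SFTDataset(eval_data, tokenizer, max_length=128)

model = GPTLM(pretrained="gpt2").cuda()
optim = Adam(model.parameters(), lr=5e-5)

# The commit message swaps in PyTorch's DistributedSampler; use it only when
# torch.distributed has been initialized.
sampler = DistributedSampler(train_dataset) if dist.is_initialized() else None

trainer = SFTTrainer(model=model,
                     strategy=NaiveStrategy(),
                     optim=optim,
                     train_dataset=train_dataset,
                     eval_dataset=eval_dataset,
                     sampler=sampler,
                     batch_size=4,
                     max_epochs=1)
trainer.fit(logger=get_dist_logger(), use_lora=0)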