Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chatgpt] add supervised learning fine-tune code #3183

Merged
merged 7 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
from .sft_dataset import SFTDataset

__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset']
40 changes: 40 additions & 0 deletions applications/ChatGPT/chatgpt/dataset/sft_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Callable
import random
from torch.utils.data import Dataset
import torch.distributed as dist
from tqdm import tqdm
import torch

from .utils import is_rank_0


class SFTDataset(Dataset):
"""
Dataset for sft model

Args:
dataset: dataset for supervised model
tokenizer: tokenizer for supervised model
max_length: max length of input
"""

def __init__(self, dataset, tokenizer: Callable, max_length: int=512) -> None:
super().__init__()
self.prompts = []

for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
prompt_token = tokenizer(prompt,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")

self.prompts.append(prompt_token)

def __len__(self):
length = len(self.prompts)
return length

def __getitem__(self, idx):
return self.prompts[idx]
3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
from .lm import LM

__all__ = ['Actor', 'Critic', 'RewardModel']
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
33 changes: 33 additions & 0 deletions applications/ChatGPT/chatgpt/models/base/lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from .actor import Actor


class LM(Actor):
"""
Language model base class.

Args:
model (nn.Module): Language Model.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""

def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)

def forward(self,
sequences: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns output log probs
"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
log_probs = F.log_softmax(logits, dim=-1)
return log_probs

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/bloom/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
from .bloom_lm import BLOOMLM

__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import LM


class BLOOMLM(LM):
"""
BLOOM language model.

Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""

def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = BloomForCausalLM(config)
else:
model = BloomForCausalLM(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/gpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
from .gpt_lm import GPTLM

__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from ..base import LM


class GPTLM(LM):
"""
GPT language model.

Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the LoRa layer.
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""

def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
model = GPT2LMHeadModel(config)
else:
model = GPT2LMHeadModel(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/opt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .opt_lm import OPTLM

__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/opt/opt_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from ..base import LM


class OPTLM(LM):
"""
OPT language model.

Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""

def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
model = OPTForCausalLM(config)
else:
model = OPTForCausalLM(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
101 changes: 101 additions & 0 deletions applications/ChatGPT/chatgpt/trainer/sft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from abc import ABC
from typing import Optional
import loralib as lora
import torch
from chatgpt.dataset import SFTDataset
from chatgpt.models.loss import GPTLMLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist
from .strategies import Strategy
from .utils import is_rank_0
from colossalai.logging import get_dist_logger


class SFTTrainer(ABC):
"""
Trainer to use while training reward model.

Args:
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
"""

def __init__(
self,
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: SFTDataset,
eval_dataset: SFTDataset,
sampler: Optional[DistributedSampler] = None,
batch_size: int = 1,
max_epochs: int = 2,
) -> None:
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.sampler = sampler

self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
sampler=sampler, batch_size=batch_size)
self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)

self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.loss_fn = GPTLMLoss()
self.optimizer = strategy.setup_optimizer(optim, self.model)

def fit(self, logger, use_lora, log_interval=10):
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.epochs):
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)
# train
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()
prompt_logits = self.model(prompt_ids, attention_mask=p_mask)

loss = self.loss_fn(prompt_logits, prompt_ids)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
if batch_id % log_interval == 0:
logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

# eval
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()

prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
loss = self.loss_fn(prompt_logits, prompt_ids)
loss_sum += loss.item()
num_seen += prompt_ids.size(0)

loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
epoch_bar.update()

Loading