[chatgpt] add supervised learning fine-tune code #3183

Merged
merged 7 commits on Mar 22, 2023
Changes from 2 commits
3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -1,4 +1,5 @@
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0
from .sft_dataset import SFTDataset, SFTDistributedDataset

__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SFTDistributedDataset']
123 changes: 123 additions & 0 deletions applications/ChatGPT/chatgpt/dataset/sft_dataset.py
@@ -0,0 +1,123 @@
import random
from typing import Callable

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, IterableDataset
from tqdm import tqdm

from .utils import is_rank_0


class SFTDataset(Dataset):
    """
    Dataset for supervised fine-tuning (SFT).

    Args:
        dataset: raw dataset; each sample is a dict with 'prompt' and 'completion' fields
        tokenizer: tokenizer used to encode the concatenated prompt and completion
        max_length: maximum length of the tokenized input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
        super().__init__()
        self.prompts = []

        for data in tqdm(dataset, disable=not is_rank_0()):
            # concatenate prompt and completion and mark the end of the sequence
            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
            prompt_token = tokenizer(prompt,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            # each tensor has shape (1, max_length) since the tokenizer is called on a single string
            self.prompts.append({
                "input_ids": prompt_token['input_ids'],
                "attention_mask": prompt_token['attention_mask']
            })

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]["input_ids"], self.prompts[idx]["attention_mask"]


class DistributedSampler:
    """Shards sample indices first by distributed rank, then by dataloader worker."""

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers

        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def sample(self, data):
        data = list(range(len(data)))
        if self.partition:
            if self.shuffle:
                # seed with the epoch so every rank shuffles identically
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        data = data[self.worker_id::self.num_workers]
        return data


class SFTDistributedDataset(IterableDataset):
    """Iterable SFT dataset that shards and batches samples across ranks and dataloader workers."""

    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512, batch_size: int = 16, shuffle=True, partition=True):
        self.prompts = dataset
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.sampler = DistributedSampler(shuffle, partition)
        self.batch_size = batch_size

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def batch(self):
        buf = []
        # refresh rank/worker info; worker info is only available inside the worker process
        self.sampler.update()
        indexes = self.sampler.sample(self.prompts)
        for index in indexes:
            data = self.prompts[index]
            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
            buf.append(prompt)
            if len(buf) >= self.batch_size:
                yield buf
                buf = []
        if len(buf) > 0:
            yield buf

    def __iter__(self):
        for data in self.batch():
            assert isinstance(data, list)
            prompt_token = self.tokenizer(data,
                                          max_length=self.max_length,
                                          padding="max_length",
                                          truncation=True,
                                          return_tensors="pt")
            input_ids = prompt_token['input_ids']
            attention_mask = prompt_token['attention_mask']
            yield input_ids, attention_mask
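
A minimal usage sketch of the two dataset classes, assuming a Hugging Face GPT-2 tokenizer and a toy list of prompt/completion dicts (the tokenizer name and sample data are illustrative, not part of this PR):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import SFTDataset, SFTDistributedDataset

# Toy samples in the expected {'prompt': ..., 'completion': ...} format (illustrative only).
raw_data = [
    {'prompt': 'Question: What is 2 + 2?\nAnswer:', 'completion': ' 4'},
    {'prompt': 'Question: Name a prime number.\nAnswer:', 'completion': ' 7'},
]

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# Map-style dataset: tokenizes everything up front, works with a standard DataLoader.
sft_dataset = SFTDataset(raw_data, tokenizer, max_length=128)
loader = DataLoader(sft_dataset, batch_size=2)
input_ids, attention_mask = next(iter(loader))  # shapes (2, 1, 128): each item keeps its leading dim

# Iterable dataset: shards indices by rank and dataloader worker, then batches lazily.
dist_dataset = SFTDistributedDataset(raw_data, tokenizer, max_length=128, batch_size=2, shuffle=True)
dist_dataset.set_epoch(0)
batch_input_ids, batch_attention_mask = next(iter(dist_dataset))  # shapes (2, 128)

DistributedSampler.sample takes every world_size-th index for the current rank and then every num_workers-th of those for the current dataloader worker, so each (rank, worker) pair iterates over a disjoint slice of the data.
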
3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/base/__init__.py
@@ -1,5 +1,6 @@
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
from .lm import LM

__all__ = ['Actor', 'Critic', 'RewardModel']
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
33 changes: 33 additions & 0 deletions applications/ChatGPT/chatgpt/models/base/lm.py
@@ -0,0 +1,33 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from .actor import Actor


class LM(Actor):
    """
    Language model base class.

    Args:
        model (nn.Module): the wrapped language model.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)

    def forward(self,
                sequences: torch.LongTensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Returns log probabilities over the vocabulary for every position in the sequence."""
        output = self.model(sequences, attention_mask=attention_mask)
        logits = output['logits']
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs
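
LM.forward returns token-level log-probabilities over the full vocabulary rather than a loss, so callers reduce it themselves. A sketch of the standard shift-by-one reduction to a masked next-token negative log-likelihood (this helper is illustrative and not part of the PR):

import torch

def causal_lm_nll(log_probs: torch.Tensor,
                  input_ids: torch.LongTensor,
                  attention_mask: torch.Tensor) -> torch.Tensor:
    # log_probs: (batch, seq_len, vocab) as returned by LM.forward
    pred = log_probs[:, :-1, :]        # position t predicts token t + 1
    labels = input_ids[:, 1:]          # drop the first token as a label
    mask = attention_mask[:, 1:].float()
    token_logp = pred.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len - 1)
    return -(token_logp * mask).sum() / mask.sum().clamp(min=1)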

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/bloom/__init__.py
@@ -1,5 +1,6 @@
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
from .bloom_lm import BLOOMLM

__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py
@@ -0,0 +1,36 @@
from typing import Optional

import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel

from ..base import LM


class BLOOMLM(LM):
    """
    BLOOM language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (BloomConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[BloomConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = BloomForCausalLM(config)
        else:
            model = BloomForCausalLM(BloomConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/gpt/__init__.py
@@ -1,5 +1,6 @@
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
from .gpt_lm import GPTLM

__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py
@@ -0,0 +1,36 @@
from typing import Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from ..base import LM


class GPTLM(LM):
    """
    GPT language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2LMHeadModel(config)
        else:
            model = GPT2LMHeadModel(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
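
A quick shape check for the new GPTLM wrapper, using a tiny randomly initialized GPT-2 config rather than pretrained weights (the config sizes below are arbitrary):

import torch
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from chatgpt.models.gpt import GPTLM

config = GPT2Config(n_layer=2, n_head=2, n_embd=64)  # tiny, untrained model for a smoke test
model = GPTLM(config=config, lora_rank=0)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
log_probs = model(input_ids, attention_mask=attention_mask)
print(log_probs.shape)  # expected: torch.Size([2, 16, 50257])

BLOOMLM and OPTLM follow the same pattern with BloomConfig and OPTConfig respectively.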

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/models/opt/__init__.py
@@ -1,5 +1,6 @@
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .opt_lm import OPTLM

__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
36 changes: 36 additions & 0 deletions applications/ChatGPT/chatgpt/models/opt/opt_lm.py
@@ -0,0 +1,36 @@
from typing import Optional

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from ..base import LM


class OPTLM(LM):
    """
    OPT language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = OPTForCausalLM(config)
        else:
            model = OPTForCausalLM(OPTConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)

3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/trainer/__init__.py
@@ -1,5 +1,6 @@
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
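
The trainer package now also exports SFTTrainer, whose implementation is not shown in this hunk; the sketch below instead wires the new dataset and model classes together with a plain PyTorch loop (optimizer, learning rate, and data are illustrative, and causal_lm_nll is the helper sketched after lm.py above):

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import SFTDataset
from chatgpt.models.gpt import GPTLM

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

raw_data = [{'prompt': 'Q: 2 + 2?\nA:', 'completion': ' 4'}]  # illustrative
loader = DataLoader(SFTDataset(raw_data, tokenizer, max_length=128), batch_size=1, shuffle=True)

model = GPTLM(pretrained='gpt2')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
for input_ids, attention_mask in loader:
    input_ids = input_ids.squeeze(1)            # drop the per-item leading dim: (batch, 1, L) -> (batch, L)
    attention_mask = attention_mask.squeeze(1)
    log_probs = model(input_ids, attention_mask=attention_mask)
    loss = causal_lm_nll(log_probs, input_ids, attention_mask)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()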