In [None]:
from transformers import AutoModel
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed

## Initialize the reward model

In [None]:
from reward_model import RewardModel
from transformers import AutoTokenizer
from peft import PeftModel
from torch.nn.utils import skip_init
import torch


# ChatGLM-6B tokenizer (custom code from the hub, hence trust_remote_code).
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b",
    trust_remote_code=True,
)

# Reward-model backbone: ChatGLM-6B weights in 8-bit, placed via device_map='auto'.
reward_model = RewardModel.from_pretrained(
    "THUDM/chatglm-6b",
    device_map='auto',
    load_in_8bit=True,
)

# Attach the trained PEFT adapter stored in ./output/reward_model/.
# (The separately saved score-head weights are loaded in the next cell.)
# NOTE(review): load_in_8bit is likely ignored by PeftModel.from_pretrained — confirm.
reward_model = PeftModel.from_pretrained(
    reward_model,
    './output/reward_model/',
    load_in_8bit=True,
)

In [3]:
class CastOutputToHalf(torch.nn.Sequential):
    """Sequential container whose output is converted to float16.

    The wrapped modules run exactly as in ``torch.nn.Sequential``; the final
    tensor is then cast with ``.half()``.
    """

    def forward(self, x):
        result = torch.nn.Sequential.forward(self, x)
        return result.half()


reward_model.gradient_checkpointing_disable()

# Load the separately saved score-head weights. map_location="cpu" makes the
# load independent of the device the checkpoint was saved from;
# load_state_dict then copies the values onto the head's existing device.
score_state = torch.load("output/reward_model/score.bin", map_location="cpu")
reward_model.base_model.model.score.load_state_dict(score_state)

# Sanity check: print shape, first/last five values of row 0, and the means of
# its first/last twenty entries.
for k, p in reward_model.base_model.model.score.named_parameters():
    print(k, p.shape, p[0, :5], p[0, -5:], p[0][:20].mean(), p[0][-20:].mean())

# Optional (disabled): reward_model.score = CastOutputToHalf(reward_model.score)
# would force float16 score outputs.

weight torch.Size([1, 4096]) tensor([0.0771, 0.0723, 0.1037, 0.1068, 0.0667], device='cuda:0',
       dtype=torch.float16) tensor([0.0782, 0.0766, 0.0637, 0.0989, 0.1059], device='cuda:0',
       dtype=torch.float16) tensor(0.0939, device='cuda:0', dtype=torch.float16) tensor(0.0825, device='cuda:0', dtype=torch.float16)


In [4]:
# Freeze the whole reward model: it is only used to score responses.
for param in reward_model.parameters():
    param.requires_grad = False

# Re-print the score-head statistics (through the PEFT wrapper's attribute
# forwarding) to confirm the loaded weights are the ones in use.
for name, weight in reward_model.score.named_parameters():
    row = weight[0]
    print(name, weight.shape, weight[0, :5], weight[0, -5:], row[:20].mean(), row[-20:].mean())


weight torch.Size([1, 4096]) tensor([0.0771, 0.0723, 0.1037, 0.1068, 0.0667], device='cuda:0',
       dtype=torch.float16) tensor([0.0782, 0.0766, 0.0637, 0.0989, 0.1059], device='cuda:0',
       dtype=torch.float16) tensor(0.0939, device='cuda:0', dtype=torch.float16) tensor(0.0825, device='cuda:0', dtype=torch.float16)


## Initialize the actor (policy) model

In [5]:
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training

model_name = "THUDM/chatglm-6b"

# Policy backbone: 8-bit weights, automatic device placement, custom hub code.
pretrained_model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    load_in_8bit=True,
    device_map='auto',
)

# Prepare the frozen 8-bit backbone for LoRA training under PPO.
pretrained_model.gradient_checkpointing_enable()
# Lets gradients flow from the input embeddings through checkpointed blocks
# even though the base weights are frozen.
pretrained_model.enable_input_require_grads()
pretrained_model.is_parallelizable = True
pretrained_model.model_parallel = True
# KV cache off while training (silences warnings); the PPO loop flips it back
# on around generation.
pretrained_model.config.use_cache = False
# NOTE(review): this helper may already enable checkpointing/input grads,
# making the two calls above redundant — confirm for the installed peft version.
pretrained_model = prepare_model_for_int8_training(pretrained_model)

# LoRA adapter: rank 8, alpha 32, dropout 0.1, causal-LM task.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
pretrained_model = get_peft_model(pretrained_model, peft_config)

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [6]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Walks ``model.named_parameters()`` once, summing ``numel()`` over all
    parameters and, separately, over those with ``requires_grad=True``, then
    prints both counts and the trainable percentage.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        A ``(trainable_params, all_param)`` tuple so callers can reuse the
        counts; callers that ignore the return value are unaffected.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    # A parameter-less model reports 0% instead of raising ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
    return trainable_params, all_param

In [7]:
# Wrap the LoRA policy with TRL's value-head model used by the PPO trainer.
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)

# Forward the inner model's gradient-checkpointing toggles onto the wrapper,
# because the training loop calls them on `model` directly.
for _toggle_name in ("gradient_checkpointing_disable", "gradient_checkpointing_enable"):
    setattr(model, _toggle_name, getattr(model.pretrained_model, _toggle_name))
del _toggle_name

print_trainable_parameters(model)

trainable params: 3674113 || all params: 6176960513 || trainable%: 0.059480920952424424


## Prepare the query dataset

In [8]:
from transformers import AutoTokenizer

# Load the tokenizer from the same checkpoint as the actor (`model_name`) so
# the two cannot silently drift apart if the checkpoint is changed.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


In [None]:
from datasets import load_dataset, Dataset


# Number of instruction prompts used as PPO queries in this run.
NUM_QUERIES = 1000

dataset = load_dataset("BelleGroup/train_0.5M_CN", split='train')
# select() keeps the Arrow-backed table and its features rather than
# round-tripping the slice through Python lists; min() guards a shorter split.
dataset = dataset.select(range(min(NUM_QUERIES, len(dataset))))
dataset = dataset.rename_columns({'instruction': 'query'})

In [10]:
def encode_data(sample, max_length=512):
    """Tokenize ``sample["query"]`` into ``sample["input_ids"]``.

    Uses the module-level ``tokenizer`` with truncation enabled. The row dict
    is updated in place and returned, which is what ``Dataset.map`` expects.

    Args:
        sample: A dataset row containing a ``"query"`` string.
        max_length: Truncation length passed to the tokenizer (default 512,
            matching the previous hard-coded value).

    Returns:
        The same ``sample`` dict with an ``"input_ids"`` entry added.
    """
    sample['input_ids'] = tokenizer.encode(
        sample["query"], max_length=max_length, truncation=True
    )
    return sample

# Tokenize every query row, then have indexing return torch tensors
# (the first positional argument of set_format is the format type).
dataset = dataset.map(encode_data)
dataset.set_format("torch")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

## Initialize the PPO trainer

In [11]:
config = PPOConfig(
    model_name=model_name,
    learning_rate=1e-5,
    # One PPO step uses a batch of 2 samples, optimized as a single
    # mini-batch with no gradient accumulation.
    batch_size=2,
    mini_batch_size=2,
    gradient_accumulation_steps=1,
    # "all" reports to every tracker available in the environment
    # (e.g. wandb, tensorboard); the tracker logging directory is output/ppo/.
    log_with="all",
    accelerator_kwargs={"logging_dir": "output/ppo/"},
)

In [12]:
def collator(data):
    """Turn a list of per-example dicts into one dict of per-key lists.

    Keys are taken from the first example, so ``data`` must be non-empty.
    Values are kept as-is (no stacking), so variable-length tensors survive.
    """
    return {key: [example[key] for example in data] for key in data[0]}


In [None]:
import torch
from utils.trainer import ChatGLMPPOTrainer

# Optimize only parameters that still require grad (presumably the LoRA
# adapters and the value head).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=config.learning_rate)

# NOTE(review): ref_model=None leaves the reference policy to the trainer —
# confirm how utils.trainer.ChatGLMPPOTrainer handles it.
ppo_trainer = ChatGLMPPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
    optimizer=optimizer,
)

In [14]:
from tqdm import tqdm 
from trl.core import LengthSampler


generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    # -1 never matches a real token, presumably so sampling always runs to
    # max_new_tokens (the response slice below relies on that length).
    "eos_token_id": -1,
}
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)


for step, batch in tqdm(enumerate(ppo_trainer.dataloader), total=len(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Generation phase: no gradient checkpointing, KV cache on.
    model.gradient_checkpointing_disable()
    model.pretrained_model.config.use_cache = True

    # Sample one response per query, each with a randomly drawn length.
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        # Generate from *this* query. Previously query_tensors[0] was used,
        # so every response came from the first prompt but was scored and
        # trained against a different one.
        response = ppo_trainer.generate(query, **generation_kwargs)
        # Keep only the newly generated tokens.
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(ids) for ids in response_tensors]

    # Score each query+response pair with the frozen reward model. No autograd
    # graph is needed for rewards, so skip building one.
    contents = [torch.cat([qids, rids]) for qids, rids in zip(query_tensors, response_tensors)]
    with torch.no_grad():
        rewards = [reward_model(c.reshape(1, -1))[0].sum() for c in contents]

    # Optimization phase: gradient checkpointing on, KV cache off.
    model.gradient_checkpointing_enable()
    model.pretrained_model.config.use_cache = False

    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

0it [00:00, ?it/s]The dtype of attention mask (torch.int64) is not bool
  tensor = as_tensor(value)
1it [00:31, 31.72s/it]

## Save the model after PPO training

In [None]:
import os

PPO_OUTPUT_DIR = "output/ppo/"
# Create the target directory up front. The value-head wrapper may write files
# into it before the underlying save_pretrained creates it
# (NOTE(review): depends on the installed trl version).
os.makedirs(PPO_OUTPUT_DIR, exist_ok=True)
# Keep the directory positional: the wrapper's save_pretrained may read it from args[0].
model.save_pretrained(PPO_OUTPUT_DIR)