*Загрузим необходимые пакеты*

In [1]:
%pip install -U -q lightning transformers datasets trl wandb transformers

[0mNote: you may need to restart the kernel to use updated packages.


# *Level 1*

*Установим модель и датасет*

In [2]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

torch.set_float32_matmul_precision('medium')

In [3]:
checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
	num_labels=1,
).to(device)

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-135M-Instruct and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
dataset_name = "juyoungml/HelpSteer2-binarized"

ds = load_dataset(dataset_name)

train_ds = ds["train"]
val_ds = ds["validation"]
ds

DatasetDict({
    train: Dataset({
        features: ['prompt', 'chosen', 'rejected', 'chosen_score', 'rejected_score', 'chosen_rationale', 'rejected_rationale', 'score_diff', 'difficulty'],
        num_rows: 7224
    })
    validation: Dataset({
        features: ['prompt', 'chosen', 'rejected', 'chosen_score', 'rejected_score', 'chosen_rationale', 'rejected_rationale', 'score_diff', 'difficulty'],
        num_rows: 373
    })
})

In [5]:
def format_ds(example):
	user_msg = example['prompt']
	chosen = example['chosen']
	rejected = example['rejected']

	prompt = f"<|im_start|>user\n{user_msg}<|im_end|>\n\n<|im_start|>assistant\n"

	return {
		"chosen": f"{prompt}{chosen}<|im_end|>",
		"rejected": f"{prompt}{rejected}<|im_end|>",
	}

In [6]:
train_ds = train_ds.map(format_ds, remove_columns=train_ds.column_names)
val_ds = val_ds.map(format_ds, remove_columns=train_ds.column_names)

*Перейдем к обучению reward модели*

In [7]:
from trl import RewardTrainer, RewardConfig

In [8]:
reward_conf = RewardConfig(
    output_dir="reward_models/SmolLM2_135_classifier",
    num_train_epochs=1,
    learning_rate=5e-5,
    fp16=True,
    max_length=768,
    eval_strategy="steps",
    eval_steps=96,
)

trainer = RewardTrainer(
    model=model,
    args=reward_conf,
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

In [9]:
# trainer.train()
# trainer.save_model(reward_conf.output_dir)

*Перейдем к реализации reinforce \w baseline и дообучению sft с помощью него*

In [10]:
from transformers import AutoModelForCausalLM
from torch.optim import AdamW
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from lightning import LightningModule, LightningDataModule, Trainer

*Воспользуемся lightning*

In [11]:
class ReinforceDataModule(LightningDataModule):
	def __init__(
		self,
		dataset,
		tokenizer,
		batch_size,
		max_prompt_length,
		num_workers,
	):
		super().__init__()
		self.ds = dataset
		self.tokenizer = tokenizer
		self.batch_size = batch_size
		self.max_prompt_length = max_prompt_length
		self.num_workers = num_workers

	def setup(self, stage=None):
		self.train_dataset = self.ds["train"]
		self.val_dataset = self.ds["validation"]

	def _collate_fn(self, batch):
		prompts = [example["prompt"] for example in batch]
		encoding = self.tokenizer(
			prompts,
			return_tensors="pt",
			padding=True,
			padding_side="left",
			truncation=True,
			max_length=self.max_prompt_length,
		)
		return {
			"input_ids": encoding["input_ids"],
			"attention_mask": encoding["attention_mask"],
		}

	def train_dataloader(self):
		return DataLoader(
			self.train_dataset,
			batch_size=self.batch_size,
			shuffle=True,
			num_workers=self.num_workers,
			persistent_workers=True,
			collate_fn=self._collate_fn,
		)

	def val_dataloader(self):
		return DataLoader(
			self.val_dataset,
			batch_size=self.batch_size,
			shuffle=False,
			num_workers=self.num_workers,
			persistent_workers=True,
			collate_fn=self._collate_fn,
		)

In [12]:
class ReinforceModel(LightningModule):
	def __init__(
		self, policy_model, reward_model, learning_rate, max_new_tokens, alpha, top_k, top_p, eos_token_id
	):
		super().__init__()
		self.save_hyperparameters(ignore=['policy_model', 'reward_model'])

		self.policy_model = policy_model
		self.reward_model = reward_model
		self.reward_model.eval()

		# Инициализация бейзлайна в виде скользящего среднего
		self.register_buffer("moving_avg_reward", torch.tensor(0.0))

	def configure_optimizers(self):
		optimizer = AdamW(self.policy_model.parameters(), lr=self.hparams.learning_rate)
		return optimizer

	@torch.no_grad()
	def generate(self, input_ids, attention_mask=None):
		# генерация ответов policy модели
		return self.policy_model.generate(
			input_ids=input_ids,
			attention_mask=attention_mask,
			max_new_tokens=self.hparams.max_new_tokens,
			do_sample=True,
			top_k=self.hparams.top_k,
			top_p=self.hparams.top_p,
			pad_token_id=self.hparams.eos_token_id,
		)

	def compute_reward(self, input_ids, attention_mask=None):
		outputs = self.reward_model(input_ids=input_ids, attention_mask=attention_mask)
		# возьмем за награду вероятность положительного класса
		rewards = torch.sigmoid(outputs.logits)
		return rewards

	def training_step(self, batch, batch_idx):
		input_ids = batch["input_ids"]
		attention_mask = batch["attention_mask"]

		full_input_ids = self.generate(input_ids, attention_mask)  # prompt + response
		responses = full_input_ids[:, input_ids.shape[-1]:]  # response only
		response_attention_mask = torch.where(
			responses == self.hparams.eos_token_id,
			torch.tensor(1), torch.tensor(0)
		)
		full_attention_mask = torch.cat([attention_mask, response_attention_mask], dim=1)

		# считаем награды
		rewards = self.compute_reward(full_input_ids, full_attention_mask).detach()
		self.moving_avg_reward = self.hparams.alpha * self.moving_avg_reward + (1 - self.hparams.alpha) * rewards.mean()

		advantages = rewards - self.moving_avg_reward
		
		# Получаем логиты (задача классификатора) от Policy Model
		logits = self.policy_model(full_input_ids, attention_mask=full_attention_mask).logits
		response_logits = logits[:, -self.hparams.max_new_tokens-1:-1, :]
		log_probs = response_logits.log_softmax(dim=-1)
		target_ids = responses.clone()
		selected_log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1)

		loss = - (selected_log_probs * advantages).mean()

		self.log_dict({
            "train_loss": loss,
            "avg_reward": rewards.mean(),
            "moving_avg": self.moving_avg_reward,
        })

		return loss
	
	def validation_step(self, batch, batch_idx):
		pass


In [13]:
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
policy_model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM2-135M-Instruct"
    )
reward_model = AutoModelForSequenceClassification.from_pretrained(
        "reward_models/SmolLM2_135_classifier/checkpoint-673"
	)

dm = ReinforceDataModule(
    dataset=ds,
    tokenizer=tokenizer,
    batch_size=12,
    max_prompt_length=192,
    num_workers=4,
)

reinforce_model = ReinforceModel(
	policy_model = policy_model,
	reward_model = reward_model,
	learning_rate=5e-5,
	max_new_tokens=512,
	alpha=0.99,
	top_k=30,
	top_p=0.95,
    eos_token_id=tokenizer.eos_token_id,
)

logger = WandbLogger(
    name="smollm2-reinforce",
    save_dir="reinforce_models/",
)

trainer = Trainer(
    max_epochs=3,
    log_every_n_steps=10,
    precision="16-mixed",
    logger=logger,
)

Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(reinforce_model, datamodule=dm)

[34m[1mwandb[0m: Currently logged in as: [33m671342ihxrxmx[0m ([33m671342ihxrxmx-iu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                           | Params | Mode
-----------------------------------------------------------------------
0 | policy_model | LlamaForCausalLM               | 134 M  | eval
1 | reward_model | LlamaForSequenceClassification | 134 M  | eval
-----------------------------------------------------------------------
269 M     Trainable params
0         Non-trainable params
269 M     Total params
1,076.122 Total estimated model params size (MB)
0         Modules in train mode
794       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]