In [None]:
import torch
from datasets import load_dataset
from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup

# Config

In [None]:
import os
from huggingface_hub import login

# Read the Hugging Face token from the environment rather than hard-coding it.
# When HF_TOKEN is unset, pass None so huggingface_hub falls back to its
# interactive login instead of trying (and failing) to authenticate with "".
login(token=os.environ.get("HF_TOKEN"), add_to_git_credential=True)

In [None]:
# Prefer Apple-silicon MPS (what this notebook targets), then CUDA, then CPU, so
# the notebook still runs on machines that have no MPS backend.
device: str = (
	"mps" if torch.backends.mps.is_available()
	else "cuda" if torch.cuda.is_available()
	else "cpu"
)
model_name: str = "meta-llama/Llama-2-7b-chat-hf"

In [None]:
# Soft-prompt configuration for a causal LM: `num_virtual_tokens` trainable
# virtual tokens, seeded from `prompt_tuning_init_text` (PEFT's TEXT init mode).
# NOTE(review): the init text may tokenize to more than 8 tokens; confirm how the
# installed peft version reconciles the two lengths.
model_config = PromptTuningConfig(
	tokenizer_name_or_path=model_name,
	task_type=TaskType.CAUSAL_LM,
	num_virtual_tokens=8,
	prompt_tuning_init=PromptTuningInit.TEXT,
	prompt_tuning_init_text="Classify the emotion in the following sentence:",
)

In [None]:
dataset_name: str = "daily_dialog"
# Llama-2's context window is 4096 tokens. Prompt tuning prepends
# `num_virtual_tokens` soft tokens to every row, so the tokenized text must leave
# room for them; otherwise the model sees positions beyond its context window.
model_context_length = 4096
max_sequence_length = model_context_length - model_config.num_virtual_tokens
learning_rate = 3e-2
num_epochs = 50
batch_size = 8  # rows per optimizer step, used by both DataLoaders

# Dataset

In [None]:
# Load every split of the dataset as a DatasetDict (train/test are used below).
dataset = load_dataset(path=dataset_name)

In [None]:
# Emotion label names, indexed by the integer class ids stored in each example's
# "emotion" sequence.
emotions: list = [
	label_name
	for label_name in dataset["train"].features["emotion"].feature.names
]
# Rename class 0 so it reads naturally inside the text prompt.
emotions[0] = "neutral"

In [None]:
# Emotion names for each reply: map every integer emotion id to its label name,
# then drop the first turn's entry (the opening utterance is never a reply).
dataset = dataset.map(
	lambda samples: {
		"respond_emotions": [
			list(map(emotions.__getitem__, emotion_ids))[1:]
			for emotion_ids in samples["emotion"]
		]
	},
	num_proc=8,
	batched=True,
)

In [None]:
# Context side of each (utterance, reply) pair: every turn except the last one,
# which has no reply.
dataset = dataset.map(
	lambda samples: {"current_dialog": [turns[:-1] for turns in samples["dialog"]]},
	num_proc=8,
	batched=True,
)

In [None]:
# Reply side of each pair: every turn except the first, so position i here is the
# answer to position i of current_dialog.
dataset = dataset.map(
	lambda samples: {"respond_dialog": [turns[1:] for turns in samples["dialog"]]},
	num_proc=8,
	batched=True,
)

In [None]:
# Inspect one training example to check that the derived current_dialog,
# respond_dialog and respond_emotions columns line up turn-by-turn.
dataset["train"][0]

In [None]:
# Tokenizer matching the base model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Reuse EOS as the padding token when the tokenizer defines none. Padded positions
# are masked out in preprocess (attention_mask 0, label -100).
if tokenizer.pad_token_id is None:
	tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Longest tokenized emotion label, special tokens included (tokenizer defaults).
# NOTE(review): not referenced elsewhere in the visible notebook.
emotion_label_max_length: int = max(
	len(label_ids) for label_ids in tokenizer(emotions)["input_ids"]
)
emotion_label_max_length

In [None]:
def _pair_turns(samples):
	"""Flatten a batch of dialogs into parallel lists of prompt strings and reply strings."""
	prompts, replies = [], []
	for current_dialog, respond_emotions, respond_dialog in zip(
		samples["current_dialog"], samples["respond_emotions"], samples["respond_dialog"]
	):
		for i, utterance in enumerate(current_dialog):
			# Index the parallel per-dialog lists (rather than zip them) so a length
			# mismatch raises instead of silently misaligning prompts and replies.
			prompts.append(f"dialog: {utterance}, respond_emotion: {respond_emotions[i]} => respond: ")
			replies.append(str(respond_dialog[i]))
	return prompts, replies


def _left_pad_and_truncate(values, pad_value, length):
	"""Left-pad `values` with `pad_value` to `length`; if longer, keep the last `length` items."""
	padded = [pad_value] * max(length - len(values), 0) + list(values)
	return padded[len(padded) - length:]


def preprocess(samples, text_tokenizer=None, max_length=None):
	"""Turn a batch of dialogs into fixed-length causal-LM training rows.

	Each dialog contributes one row per (utterance, reply) pair. A row's input is
	"dialog: <utterance>, respond_emotion: <emotion> => respond: " followed by the
	reply tokens and EOS. Its labels are -100 over the prompt (only the reply is
	supervised) followed by the reply tokens and EOS. Rows are left-padded to
	`max_length`. Over-long rows keep their *last* `max_length` tokens, so the
	supervised reply and EOS survive truncation.

	Args:
		samples: a batched `datasets` slice with list-of-list columns
			"current_dialog", "respond_emotions" and "respond_dialog", index-aligned
			within each dialog.
		text_tokenizer: tokenizer to use; defaults to the notebook-level `tokenizer`.
			Must accept a list of strings (and `add_special_tokens=`) and expose
			`pad_token_id` and `eos_token_id`.
		max_length: target row length; defaults to `max_sequence_length`.

	Returns:
		dict with "input_ids", "attention_mask" and "labels", each a list holding one
		`max_length`-long int list per (utterance, reply) pair in the batch.
	"""
	text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
	max_length = max_sequence_length if max_length is None else max_length

	prompts, replies = _pair_turns(samples)
	model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
	if not prompts:
		# e.g. a batch consisting only of single-utterance dialogs yields no pairs.
		return model_inputs

	prompt_ids = text_tokenizer(prompts)["input_ids"]
	# No special tokens on the reply. Otherwise a BOS would be spliced into the
	# middle of every row and trained as the first token to emit after "respond: ".
	reply_ids = text_tokenizer(replies, add_special_tokens=False)["input_ids"]

	# Iterate over every pair. The previous `len(model_inputs)` counted the
	# encoding's keys, not its rows, and so processed only two rows per batch.
	for prompt, reply in zip(prompt_ids, reply_ids):
		target = list(reply) + [text_tokenizer.eos_token_id]
		sequence = list(prompt) + target
		labels = [-100] * len(prompt) + target
		model_inputs["input_ids"].append(
			_left_pad_and_truncate(sequence, text_tokenizer.pad_token_id, max_length))
		model_inputs["attention_mask"].append(
			_left_pad_and_truncate([1] * len(sequence), 0, max_length))
		model_inputs["labels"].append(_left_pad_and_truncate(labels, -100, max_length))

	return model_inputs

In [None]:
# Flatten every dialog into one tokenized row per (utterance, reply) pair. All
# original columns are dropped because the row count changes, and the cache is
# bypassed so edits to `preprocess` always take effect.
processed_datasets = dataset.map(
	function=preprocess,
	batched=True,
	remove_columns=dataset["train"].column_names,
	num_proc=8,
	desc="Running tokenizer on dataset",
	load_from_cache_file=False,
)

In [None]:
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["test"]

# Pinned host memory only speeds up host->CUDA copies; PyTorch does not support it
# on MPS (it warns and ignores the flag), so enable it only for CUDA devices.
pin_memory = str(device).startswith("cuda")

# Pass the configured batch_size explicitly; without it DataLoader defaults to 1.
train_dataloader = DataLoader(
	train_dataset,
	batch_size=batch_size,
	shuffle=True,
	collate_fn=default_data_collator,
	pin_memory=pin_memory,
	num_workers=2,
)
eval_dataloader = DataLoader(
	eval_dataset,
	batch_size=batch_size,
	collate_fn=default_data_collator,
	pin_memory=pin_memory,
	num_workers=2,
)

In [None]:
# Load the base causal LM and wrap it with the prompt-tuning adapter.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = get_peft_model(model, model_config)
# print_trainable_parameters() prints its own summary and returns None, so it is
# called directly rather than wrapped in print().
model.print_trainable_parameters()

In [None]:
# get_peft_model freezes the base model, so only the prompt-tuning parameters have
# requires_grad=True. Hand AdamW just those. Frozen weights never receive a .grad
# (AdamW would skip them anyway); this keeps the optimizer's param_groups limited
# to what is trained, and fails loudly if nothing is trainable.
optimizer = torch.optim.AdamW(
	[param for param in model.parameters() if param.requires_grad],
	lr=learning_rate,
)
# Linear schedule with no warmup over every training step; it is stepped once per
# batch in the training loop.
lr_scheduler = get_linear_schedule_with_warmup(
	optimizer=optimizer,
	num_warmup_steps=0,
	num_training_steps=(len(train_dataloader) * num_epochs),
)

In [None]:
# `!export` runs in a throwaway subshell and never reaches this kernel's
# environment, so set the variable on the Python process itself. It is presumably
# read when the MPS allocator initialises, so it must be set before the first MPS
# allocation (`model.to(device)` below). 0.0 disables the allocator's upper memory
# limit.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

In [None]:
# Move the PEFT-wrapped model to the selected device, then alternate one training
# pass and one evaluation pass per epoch, printing mean losses and perplexities.
model = model.to(device)

for epoch in range(num_epochs):
	# --- training pass ---
	model.train()
	total_loss = 0  # running sum of per-batch training losses for this epoch
	for step, batch in enumerate(tqdm(train_dataloader)):
		batch = {k: v.to(device) for k, v in batch.items()}
		# The batch includes `labels`, so the model returns the loss in outputs.loss.
		outputs = model(**batch)
		loss = outputs.loss
		# Detach so the running total does not keep each step's autograd graph alive.
		total_loss += loss.detach().float()
		loss.backward()
		optimizer.step()
		lr_scheduler.step()  # scheduler is stepped per batch, not per epoch
		optimizer.zero_grad()

	# --- evaluation pass (no gradients) ---
	model.eval()
	eval_loss = 0
	eval_preds = []
	for step, batch in enumerate(tqdm(eval_dataloader)):
		batch = {k: v.to(device) for k, v in batch.items()}
		with torch.no_grad():
			outputs = model(**batch)
		loss = outputs.loss
		eval_loss += loss.detach().float()
		# Decode the argmax token at every position of every eval row.
		# NOTE(review): eval_preds is not used later in the visible notebook, and the
		# logits presumably also cover the virtual-token positions; decoding every
		# full-length row each epoch is costly, so confirm it is still needed.
		eval_preds.extend(
			tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
		)

	# Mean loss per batch, and perplexity as exp(mean loss), for both splits.
	eval_epoch_loss = eval_loss / len(eval_dataloader)
	eval_ppl = torch.exp(eval_epoch_loss)
	train_epoch_loss = total_loss / len(train_dataloader)
	train_ppl = torch.exp(train_epoch_loss)
	print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")