# Persona Generation Training

This notebook fine-tunes models to generate persona descriptions from dialogues.

## Features
- Model-agnostic: Supports T5, RuT5, BART, Qwen, Llama, and more
- Auto-detects model architecture (seq2seq vs causal LM)
- Train/validation/test split functionality
- Evaluation with multiple metrics
- Target speaker selection (1 or 2)

In [None]:
# Install dependencies
# The leading "!" runs the line as a shell command from the notebook cell;
# "-q" asks pip for quiet output.
!pip install -q torch transformers datasets accelerate sacrebleu rouge-score tqdm sentence-transformers

## Configuration

Set up model and training parameters:

In [None]:
# Model configuration - change to any HuggingFace model identifier.
# Seq2Seq models: cointegrated/rut5-base, google/flan-t5-base, facebook/bart-base
# Causal LMs: Qwen/Qwen2.5-0.5B-Instruct, meta-llama/Llama-3.2-1B-Instruct
MODEL_NAME = "cointegrated/rut5-base"

# Paths
# DATA_DIR is searched for "dialogues_<split>.json" files by load_dialogues;
# OUTPUT_DIR is used as the trainer output directory and the final save location.
DATA_DIR = "./data"
OUTPUT_DIR = "./persona_model"

# Training configuration
# MAX_LENGTH is the max_length handed to tokenize_function (truncation + padding).
# NOTE(review): MAX_TARGET_LENGTH is not referenced anywhere else in the visible
# notebook - confirm whether targets were meant to be capped by it.
MAX_LENGTH = 1024
MAX_TARGET_LENGTH = 200
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
WARMUP_STEPS = 100

# Target speaker: 1 or 2 (selects which persona_<n> field is the training target)
TARGET_SPEAKER = 1

# Data split names; each one is substituted into "dialogues_<split>.json"
TRAIN_SPLIT = "train"
VAL_SPLIT = "val"
TEST_SPLIT = "test"

# Echo the key settings so the run is self-describing in the cell output.
print(f"Model: {MODEL_NAME}")
print(f"Target Speaker: {TARGET_SPEAKER}")
print(f"Batch size: {BATCH_SIZE}")
# Effective batch size = per-device batch size x gradient accumulation steps.
print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

## Data Loading Functions

In [None]:
import json
import os
from typing import List, Dict

def load_dialogues(data_dir: str, split: str = "train") -> List[Dict]:
    """Read the dialogue examples stored for one dataset split.

    The split is expected at ``<data_dir>/dialogues_<split>.json``.

    Args:
        data_dir: Directory that holds the split files.
        split: Split name substituted into the file name.

    Returns:
        The parsed JSON content of the split file.

    Raises:
        FileNotFoundError: If the split file does not exist.
    """
    json_name = f"dialogues_{split}.json"
    json_path = os.path.join(data_dir, json_name)

    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as handle:
            examples = json.load(handle)
        print(f"Loaded {len(examples)} examples from {json_name}")
        return examples

    raise FileNotFoundError(f"Data file not found: {json_path}")

In [None]:
def build_dialogue_text(dialogue: List[Dict], target_speaker: int) -> str:
    """Render a dialogue as newline-separated ``"<speaker label>: <text>"`` lines.

    Args:
        dialogue: Messages, each a mapping with ``speaker`` and ``text`` keys.
        target_speaker: Not used by this function; the rendering is the same
            for either target.

    Returns:
        One line per message joined with ``"\\n"``; an empty string for an
        empty dialogue.
    """
    def _label(speaker) -> str:
        # Anything other than speaker 1 is labelled as the second user.
        return "Пользователь 1" if speaker == 1 else "Пользователь 2"

    return "\n".join(
        f"{_label(turn['speaker'])}: {turn['text']}" for turn in dialogue
    )

## Data Preparation Functions

In [None]:
def prepare_training_data(data: List[Dict], target_speaker: int, is_seq2seq: bool = False) -> List[Dict]:
    """Convert raw dataset records into fine-tuning examples.

    Args:
        data: Records with a ``dialogue`` list and ``persona_1`` /
            ``persona_2`` reference descriptions.
        target_speaker: ``1`` selects ``persona_1``; any other value selects
            ``persona_2``.
        is_seq2seq: When true, emit ``{"input", "target"}`` pairs; otherwise
            emit ``{"text", "target"}`` where ``text`` is the dialogue wrapped
            in an instruction prompt.

    Returns:
        The prepared examples. Records without a dialogue or without the
        selected persona are dropped.
    """
    persona_key = 'persona_1' if target_speaker == 1 else 'persona_2'
    examples = []

    for record in data:
        turns = record.get("dialogue", [])
        persona = record.get(persona_key, '')

        # Both the dialogue and the reference persona are required.
        if not (turns and persona):
            continue

        rendered = build_dialogue_text(turns, target_speaker)

        if not is_seq2seq:
            # Causal LM: dialogue followed by the instruction prompt.
            prompt = f"Диалог:\n{rendered}\n\nОпиши личность Пользователя {target_speaker}:"
            examples.append({"text": prompt, "target": persona})
            continue

        examples.append({"input": rendered, "target": persona})

    print(f"Prepared {len(examples)} training examples")
    return examples

In [None]:
from transformers import AutoTokenizer

def tokenize_function(examples, tokenizer, max_length, is_seq2seq: bool = False,
                      max_target_length=None):
    """Tokenize a batch of prepared examples for training.

    Args:
        examples: Batch mapping. Seq2seq batches need ``input`` and ``target``
            columns; causal-LM batches need ``text`` and, when available,
            ``target``.
        tokenizer: Callable tokenizer exposing ``pad_token_id`` and
            ``eos_token``.
        max_length: Truncation/padding length for model inputs.
        is_seq2seq: Selects the encoder-decoder or the decoder-only layout.
        max_target_length: Truncation/padding length for seq2seq targets.
            Defaults to ``max_length`` when omitted.

    Returns:
        The tokenizer output. For seq2seq it also carries ``labels`` in which
        padding positions are replaced by ``-100`` so they are ignored by the
        loss.
    """
    if is_seq2seq:
        if max_target_length is None:
            max_target_length = max_length

        model_inputs = tokenizer(
            examples["input"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors=None,
        )
        targets = tokenizer(
            examples["target"],
            truncation=True,
            max_length=max_target_length,
            padding="max_length",
            return_tensors=None,
        )
        pad_id = tokenizer.pad_token_id
        # Mask padding in the labels so it does not contribute to the loss.
        model_inputs["labels"] = [
            [(token if token != pad_id else -100) for token in label_seq]
            for label_seq in targets["input_ids"]
        ]
        return model_inputs

    # Causal LM: the model only learns to produce the persona if the persona
    # is part of the training sequence, so append it (plus EOS) to the prompt.
    # Batches without a "target" column are tokenized as plain text.
    if "target" in examples:
        eos = tokenizer.eos_token or ""
        texts = [
            f"{prompt} {target}{eos}"
            for prompt, target in zip(examples["text"], examples["target"])
        ]
    else:
        texts = examples["text"]

    # NOTE: truncation cuts from the end, so prompts longer than max_length
    # lose (part of) the appended target.
    return tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors=None,
    )

## Load and Prepare Data

In [None]:
# Read the raw train and validation splits from DATA_DIR.
print("Loading training data...")
train_data = load_dialogues(DATA_DIR, TRAIN_SPLIT)
print("Loading validation data...")
val_data = load_dialogues(DATA_DIR, VAL_SPLIT)

print(f"Train size: {len(train_data)}, Val size: {len(val_data)}")

# Turn both splits into training examples (train first, then validation).
# NOTE: is_seq2seq is not passed here, so prepare_training_data uses its
# default (False) and emits the causal-LM "text"/"target" layout.
train_prepared, val_prepared = (
    prepare_training_data(split_rows, TARGET_SPEAKER)
    for split_rows in (train_data, val_data)
)

print(f"Prepared train: {len(train_prepared)}, val: {len(val_prepared)}")

## Model Loading Functions

In [None]:
import torch

def load_model_and_tokenizer(model_name: str):
    """Load a model and its tokenizer, detecting seq2seq vs causal LM.

    The encoder-decoder loader is tried first; if it rejects the checkpoint
    (OSError, ValueError or KeyError) the decoder-only loader is used instead.

    Args:
        model_name: Hub identifier or local directory of the checkpoint.

    Returns:
        A ``(model, tokenizer, is_seq2seq)`` tuple, where ``is_seq2seq`` is
        True when the encoder-decoder loader succeeded.

    Raises:
        RuntimeError: If the causal-LM fallback fails as well; the original
            exception is attached as ``__cause__``.
    """
    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Same loading options for both architectures.
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}

    # Try seq2seq first, fall back to causal LM.
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **load_kwargs)
    except (OSError, ValueError, KeyError):
        try:
            model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        except Exception as e:
            # Chain explicitly so the root cause stays visible in the traceback.
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e
        is_seq2seq = False
        print("Detected: Causal LM (decoder-only)")
    else:
        is_seq2seq = True
        print("Detected: Seq2Seq model (encoder-decoder)")

    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, is_seq2seq

In [None]:
# Load model and auto-detect type
# is_seq2seq is reused by later cells to pick the tokenization layout,
# the data collator and the trainer class.
model, tokenizer, is_seq2seq = load_model_and_tokenizer(MODEL_NAME)

# Report where the model was placed and how large it is.
print(f"Model loaded on: {model.device}")
print(f"Model parameters: {model.num_parameters():,}")

## Dataset Preparation

In [None]:
from datasets import Dataset

# Rebuild the examples now that the model type is known. The column layout
# produced by prepare_training_data ("input"/"target" vs "text"/"target")
# must match is_seq2seq, otherwise tokenize_function and remove_columns below
# look for columns that do not exist.
train_prepared = prepare_training_data(train_data, TARGET_SPEAKER, is_seq2seq)
val_prepared = prepare_training_data(val_data, TARGET_SPEAKER, is_seq2seq)

# Fail early with a clear message instead of an opaque column error in map().
if not train_prepared or not val_prepared:
    raise ValueError(
        "No usable examples after preparation "
        f"(train: {len(train_prepared)}, val: {len(val_prepared)}); "
        "check the dialogue/persona fields and TARGET_SPEAKER."
    )

# Create datasets
train_dataset = Dataset.from_list(train_prepared)
val_dataset = Dataset.from_list(val_prepared)

# Determine columns to remove (the raw text columns are dropped after tokenization)
cols_to_remove = ["input", "target"] if is_seq2seq else ["text", "target"]

# Tokenize
train_dataset = train_dataset.map(
    lambda x: tokenize_function(x, tokenizer, MAX_LENGTH, is_seq2seq),
    batched=True,
    remove_columns=cols_to_remove,
)
val_dataset = val_dataset.map(
    lambda x: tokenize_function(x, tokenizer, MAX_LENGTH, is_seq2seq),
    batched=True,
    remove_columns=cols_to_remove,
)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print("Tokenization complete!")

## Setup Training

In [None]:
from transformers import (
    Trainer, Seq2SeqTrainer,
    TrainingArguments, Seq2SeqTrainingArguments,
    DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
)

# Choose the collator, trainer class and arguments class for the model family.
if not is_seq2seq:
    # Decoder-only: plain language-modeling collator (no masked-LM objective).
    TrainerClass, TrainingArgsClass = Trainer, TrainingArguments
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
else:
    # Encoder-decoder: seq2seq collator and the seq2seq trainer variants.
    TrainerClass, TrainingArgsClass = Seq2SeqTrainer, Seq2SeqTrainingArguments
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

print(f"Using {TrainerClass.__name__}")

In [None]:
# Training arguments shared by both trainer variants.
training_kwargs = dict(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    logging_steps=10,
    save_steps=100,
    eval_steps=100,
    save_total_limit=3,
    fp16=False,
    # Only query bf16 support when a CUDA device is actually present.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    eval_strategy="steps",
    # Keep the checkpoint with the lowest validation loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    remove_unused_columns=False,
)

# predict_with_generate exists only on Seq2SeqTrainingArguments; passing it to
# plain TrainingArguments raises a TypeError, so add it for seq2seq runs only.
if is_seq2seq:
    training_kwargs["predict_with_generate"] = True

training_args = TrainingArgsClass(**training_kwargs)

# Create trainer
trainer = TrainerClass(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    processing_class=tokenizer,
)

print("Trainer ready!")

## Train the Model

This will take some time depending on your GPU.

In [None]:
# Run the fine-tuning loop configured on the trainer above.
print("Starting training...")
trainer.train()
print("\nTraining completed!")

## Save the Model

In [None]:
# Write the model and the tokenizer to OUTPUT_DIR so that the directory can be
# loaded again by path (the evaluation cell reloads from OUTPUT_DIR).
print(f"Saving model to {OUTPUT_DIR}")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Model saved successfully!")

## Generation Functions

Helper function that generates a persona description from a dialogue:

In [None]:
def generate_persona(model, tokenizer, dialogue: List[Dict], target_speaker: int,
                      is_seq2seq: bool, max_new_tokens: int = 150, temperature: float = 0.7):
    """Generate a persona description for one speaker of a dialogue.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        dialogue: Messages with ``speaker`` and ``text`` keys.
        target_speaker: Speaker number inserted into the causal-LM instruction.
        is_seq2seq: True for encoder-decoder models (the prompt is the bare
            dialogue); False for causal LMs (dialogue plus instruction).
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values that are zero, negative or
            None switch to greedy decoding instead of sampling.

    Returns:
        The generated text up to the first blank line and before any repeated
        ``"Диалог:"`` marker, stripped of surrounding whitespace.
    """
    dialogue_text = build_dialogue_text(dialogue, target_speaker)

    if is_seq2seq:
        prompt = dialogue_text
    else:
        prompt = f"Диалог:\n{dialogue_text}\n\nОпиши личность Пользователя {target_speaker}:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature is not None and temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        # Sampling requires a strictly positive temperature; decode greedily.
        gen_kwargs["do_sample"] = False

    outputs = model.generate(**inputs, **gen_kwargs)

    if is_seq2seq:
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        # A causal LM echoes the prompt; keep only the newly generated tokens
        # rather than searching the decoded text for the instruction string.
        prompt_len = inputs["input_ids"].shape[-1]
        result = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # Clean up: keep the first paragraph and drop any restarted dialogue.
    result = result.split("\n\n")[0].strip()
    result = result.split("Диалог:")[0].strip()

    return result

## Quick Evaluation

In [None]:
# Reload the saved checkpoint from OUTPUT_DIR for a quick sanity check.
eval_model, eval_tokenizer, eval_is_seq2seq = load_model_and_tokenizer(OUTPUT_DIR)

print("\nSample predictions:")
print("=" * 80)

# Compare reference and generated personas on the first five validation items.
for i, item in enumerate(val_data[:5]):
    dialogue = item.get("dialogue", [])
    reference = item.get(f"persona_{TARGET_SPEAKER}", "")

    predicted = generate_persona(
        eval_model,
        eval_tokenizer,
        dialogue,
        TARGET_SPEAKER,
        eval_is_seq2seq,
        max_new_tokens=150,
        temperature=0.7,
    )

    print(f"\n--- Sample {i+1} ---")
    print(f"Reference Persona {TARGET_SPEAKER}: {reference}")
    print(f"Predicted: {predicted}")
    print("-" * 80)