In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
import concurrent.futures
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
class PRMConfig:
    """Configuration class for PRM hyperparameters and settings"""
    # MC config
    # "meta-llama/Meta-Llama-3-8B" "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" "Qwen/Qwen2.5-Math-7B-Instruct" "Qwen/Qwen2.5-Math-7B"
    max_new_tokens: int = 512
    num_rollouts: int = 5
    reward_threshold: float = 0.2
    samples_per_question: int = 1
    # PRM config
    batch_size: int = 32
    learning_rate: float = 5e-4
    hidden_size: int = 256
    num_workers: int = 4
    # Misc config
    use_wandb: bool = False
    checkpoint_dir: str = "checkpoints"

# PRMDataset

In [3]:
################################################################################
#                        UTILITY: ANSWER NORMALISATION                         #
################################################################################
import sympy as sp

def _strip_markup(ans: str) -> str:
    """Remove common LaTeX/markup & variable tags."""
    # Remove LaTeX inline math wrappers \( … \) or \[ … \]
    ans = re.sub(r"\\[\[(](.*?)[\\\])]", r"\1", ans)
    # Remove \boxed{…}
    ans = re.sub(r"\\boxed\{([^}]*)\}", r"\1", ans)
    # Remove variable assignments like "y =" or "x=" at start
    ans = re.sub(r"^[a-zA-Z]\s*=\s*", "", ans)
    # Trim outer $ … $ if present
    ans = ans.strip()
    if ans.startswith("$") and ans.endswith("$"):
        ans = ans[1:-1]
    return ans.strip()

def _sanitize(text: str) -> str:
    """Normalise a candidate answer string for comparison."""
    text = _strip_markup(text)
    text = text.strip()
    text = re.sub(r"[\s\.;:,]+$", "", text)     # trailing punctuation
    text = re.sub(r"\s+", " ", text)              # collapse spaces
    return text

def _to_float(expr: str) -> Optional[float]:
    try:
        return float(eval(expr.replace("^", "**")))
    except Exception:
        return None

def _numeric_equiv(a: str, b: str) -> bool:
    """Return True if `a` and `b` are numerically equivalent or exact match."""
    a_clean, b_clean = map(_sanitize, (a, b))
    if a_clean == b_clean:
        return True

    # Attempt simple numeric evaluation
    a_val, b_val = _to_float(a_clean), _to_float(b_clean)
    if a_val is not None and b_val is not None:
        return math.isclose(a_val, b_val, rel_tol=1e-6)

    if sp is not None:
        try:
            a_expr = sp.sympify(a_clean.replace("^", "**"))
            b_expr = sp.sympify(b_clean.replace("^", "**"))
            return sp.simplify(a_expr - b_expr) == 0
        except Exception:
            pass
    return False

def system_prompt(type):
    prompt = ""
    if type == "sample":
        prompt = """You are a math-problem expert. Your task is to complete the step-by-step solution for the problem provided. Write each reasoning step on its own line in the exact form \"Step k: [your reasoning step]\n\", numbering start from Step 1. When the final answer is obtained, write exactly one final line, \"Answer: [Final answer]\". Do NOT add explanations, extra steps, or any text after the "Answer:" line.
            
Format Guide with Examples:
## Example 1 ##
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

## Example 2 ##
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Follow this structure exactly. Once you have written the "Answer: " line, stop generating."""
    if type == "rollout":
        prompt = """You are a math problem-solving expert. Continue solving the given problem step by step, strictly following the required format. Each new step must begin with \"Step k+1: ...\", \"Step k+2:...\", and so on, continuing from the last given step number. When the final answer is reached, write only one final line starting with: \"Answer: [Final Answer]\". Do not add any explanations, extra commentary, or additional text after the "Answer:" line. Your output must follow this exact step-by-step format with no deviations.

Format Guide with Examples:
## Example 1 ##
Current solution steps:
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.

Continue and finish the solution:
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

## Example 2 ##
Current solution steps:
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.

Continue and finish the solution:
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Keep the reasoning steps precise and factual. Your job is to complete the solution cleanly using this format structure."""
    return prompt

In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class MCReward:
    STEP_PATTERN = re.compile(r"Step\s+\d+:")
    ANSWER_PATTERN = re.compile(r"Answer\s*:\s*(.+?)\s*(?:$|\n)")

    def __init__(self, config: "PRMConfig", model, tokenizer):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Function to generate one or more step-by-step solutions for a given question.
    def generate_solutions(self, question: str, sys_prompt: str, num_solutions: int):
        prompt = f"{sys_prompt}\n\n{question}\n"  # Prompt the model to start the step-by-step solution
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
        # Generate multiple solutions via sampling
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.config.max_new_tokens,
            do_sample=True,
            num_return_sequences=num_solutions,
            temperature=0.7,         # sampling temperature for diversity (adjust as needed)
            top_p=0.8,               # top-p sampling for diversity
            pad_token_id=self.tokenizer.eos_token_id  # pad token ID to avoid warning for some models
        )
        solutions = []
        prompt_len = input_ids.shape[-1]
        for i in range(num_solutions):
            # Each output is the concatenation of the prompt and the generated completion.
            generated_ids = outputs[i]
            # Extract only the newly generated tokens (skip the prompt tokens).
            gen_ids = generated_ids[prompt_len:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            solutions.append(text)
        return solutions
    
    # Function to parse a solution text into steps and final answer.
    def _extract_answer(self, text: str) -> Optional[str]:
        """Try multiple heuristics / regexes to pull out an answer string."""
        # Primary regex (robust to Answer:, Answer ‑, etc.)
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize(match.group(1))
        
        # Fallback 1: last non‑empty line if it looks simple / numeric
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if lines:
            candidate = lines[-1]
            if re.search(r"\d", candidate):  # contains digit
                return _sanitize(candidate)

        # Fallback 2: look for last line that starts with 'Answer'
        for line in reversed(text.splitlines()):
            if line.strip().lower().startswith("answer"):
                return _sanitize(line.split("Answer", 1)[-1])
        
        return None

    def parse_solution(self, solution_text: str):
        """
        Given the model's generated solution text, split it into a list of steps and the final answer string.
        Expects each step to start with 'Step X:' and the answer to start with 'Answer:'.
        """
        steps = []
        # answer = None
        # Split by lines to identify steps and answer
        for line in solution_text.splitlines():
            line = line.strip()
            if not line:
                continue
            if line.startswith("Step"):
                steps.append(line)
            # if line.startswith("Answer:"):
            #     # Extract everything after "Answer:" as the answer
            #     answer = line.split("Answer:", 1)[1].strip()
            #     # Stop if answer is found (anything after answer line is not needed)
            #     break
            answer = self._extract_answer(solution_text)
        return steps, answer
    
    # Function to estimate intermediate rewards for each step via rollouts.
    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        """
        For each prefix ending at a given step in 'steps', generate rollouts and compute the reward 
        (fraction of rollouts ending in the correct answer). Returns a list of reward values corresponding to each step.
        """
        rewards = []
        total_steps = len(steps)

        # Pre‑encode static prefix (sys_prompt + question) once for efficiency
        base_prompt = f"{sys_prompt}\n\n{question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            # prefix_steps = steps[:i+1]  
            # prefix_text = f"{sys_prompt}\n\n{question}\n" + "\n".join(prefix_steps) + "\n"
            prefix_tokens = self.tokenizer.encode("\n".join(steps[: i + 1]) + "\n", return_tensors="pt").to(self.device) # steps up to current step i (0-indexed)
            # Decide how to prompt the next part:
            # if i < total_steps - 1:
            #     next_step_num = i + 2  # (because i is 0-indexed and Step numbering is 1-indexed)
            #     prefix_text += f"Step {next_step_num}:"
            # else:
            #     # If this is the last step of the solution, prompt the final answer.
            #     prefix_text += "Answer:"
            if i < total_steps - 1:
                next_label = f"Step {i + 2}:"
            else:
                next_label = "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)

            # Build full prefix ids (avoid Python concat inefficiency by cat)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)

            # prefix_ids = self.tokenizer.encode(prefix_text, return_tensors='pt').to(self.model.device)
            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.7,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id
            )
            new_token_start = prefix_ids.shape[-1] 
            # Check each rollout's final answer against the gold answer
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
                reward = correct_count / float(self.config.num_rollouts)
                print(f"[Rollout {idx}]:", pred_answer, "vs", gold_answer, "Reward:", reward)
            rewards.append(reward)
            # for j in range(self.config.num_rollouts):
            #     gen_ids = rollout_outputs[j][prefix_ids.shape[-1]:]  # only new tokens
            #     completion_text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            #     print(f"[Rollout {j}]:", completion_text, "\n")
            #     # We find the last occurrence of "Answer:" in the completion (in case it's used earlier)
            #     answer_index = completion_text.rfind("Answer:")
            #     if answer_index != -1:
            #         predicted_answer = completion_text[answer_index + len("Answer:"):].strip()
            #     else:
            #         # If the model didn't explicitly output "Answer:", treat the whole completion as answer
            #         predicted_answer = completion_text.strip()
            #     # Exact string match (after stripping) to the gold answer
            #     if predicted_answer == str(gold_answer).strip():
            #         correct_count += 1
            #     print(f"[Rollout {j}]:",predicted_answer, gold_answer)
            # reward = correct_count / float(self.config.num_rollouts)
            # rewards.append(reward)
        return rewards
    
    # Build datasets based on input datas
    def build_datasets(self, problems: List):
        dataset = []  # will hold the output list of dicts
        for problem in problems:
            question = problem["question"]
            # gold_answer = problem["gold_answer"]
            gold_answer = _sanitize(problem["gold_answer"])
            # Generate one or more solutions for this question
            sample_prompt = system_prompt("sample")
            rollout_prompt = system_prompt("rollout")
            solutions = self.generate_solutions(question, sys_prompt=sample_prompt, num_solutions=self.config.samples_per_question)
            for sol_text in solutions:
                steps, answer = self.parse_solution(sol_text)
                print("Parsed solution:", steps, answer)
                if answer is None: # If no answer was found in the solution (edge case), skip this solution
                    continue
                # Compute intermediate rewards for each step in this solution
                rewards = self.compute_step_rewards(question, sys_prompt=rollout_prompt, steps=steps, gold_answer=gold_answer)
                # Prepare the output entry
                entry = {
                    "question": question,
                    "completion": steps,      # list of "Step i: ..." strings
                    "rewards": rewards,       # list of reward values for each step
                    "answer": answer,         # model's final answer from this solution
                    "gold_answer": gold_answer
                }
                dataset.append(entry)
        return dataset

In [6]:
model_name =  "Qwen/Qwen2.5-Math-7B" # "Qwen/Qwen2.5-Math-7B-Instruct"  #"Qwen/Qwen2.5-Math-7B"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

problems = [
    {"question": "Solve for y: 2y - 7 = 3(y - 4).", 
     "gold_answer": "5"},
    # Add more problems as needed...
]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s]


In [7]:
mcr = MCReward(config=PRMConfig , model=model, tokenizer=tokenizer)
dataset = mcr.build_datasets(problems)

# Print or inspect the dataset
for entry in dataset:
    print(entry)
    print("-" * 80)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Parsed solution: ['Step 1: Distribute the 3 on the right side: 2y - 7 = 3y - 12.', 'Step 2: Move all terms involving y to one side and constant terms to the other: 2y - 3y = -12 + 7.', 'Step 3: Simplify both sides: -y = -5.', 'Step 4: Divide both sides by -1 to solve for y: y = 5.'] 5
[Rollout 0]: 5 vs 5 Reward: 0.2
[Rollout 1]: 5 vs 5 Reward: 0.4
[Rollout 2]: 5 vs 5 Reward: 0.6
[Rollout 3]: 5 vs 5 Reward: 0.8
[Rollout 4]: 5 vs 5 Reward: 1.0
[Rollout 0]: 5 vs 5 Reward: 0.2
[Rollout 1]: 5 vs 5 Reward: 0.4
[Rollout 2]: 5 vs 5 Reward: 0.6
[Rollout 3]: 5 vs 5 Reward: 0.8
[Rollout 4]: 5 vs 5 Reward: 1.0
[Rollout 0]: 5 vs 5 Reward: 0.2
[Rollout 1]: 5 vs 5 Reward: 0.4
[Rollout 2]: 5 vs 5 Reward: 0.6
[Rollout 3]: 5 vs 5 Reward: 0.8
[Rollout 4]: 5 vs 5 Reward: 1.0
[Rollout 0]: boxed{5}\) vs 5 Reward: 0.0
[Rollout 1]: 5 vs 5 Reward: 0.2
[Rollout 2]: 5 vs 5 Reward: 0.4
[Rollout 3]: 5 vs 5 Reward: 0.6
[Rollout 4]: 5 vs 5 Reward: 0.8
{'question': 'Solve for y: 2y - 7 = 3(y - 4).', 'completion': ['S

In [None]:
if __name__ == "__main__":
    model_name = "Qwen/Qwen2.5-Math-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    cfg = PRMConfig(num_rollouts=3, samples_per_question=1)
    mcr = MCReward(cfg, model, tokenizer)

    problems = [
        {"question": "Solve for y: 2y - 7 = 3(y - 4).", "gold_answer": "5"},
    ]

    ds = mcr.build_dataset(problems)
    for row in ds:
        print(row)
        print("-" * 80)

# PRM Model

In [3]:
import torch
import torch.nn as nn
from typing import Optional

class ProcessRewardModel(nn.Module):
    """Enhanced Process Reward Model with dropout and layer normalization"""
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.1,
        num_layers: Optional[int] = None
    ):
        """ 
        Args:
            input_size (int): Size of input features
            hidden_size (int): Size of hidden layers
            output_size (int): Size of output
            dropout (float): Dropout rate
            num_layers (Optional[int]): Number of hidden layers
        """
        super(ProcessRewardModel, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_rate = dropout
        self.num_layers = num_layers or 2
        
        # Input layer
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        
        # Hidden layers
        hidden_layers = []
        for i in range(self.num_layers - 1):
            in_features = hidden_size if i == 0 else hidden_size // (2 ** i)
            out_features = hidden_size // (2 ** (i + 1))
            hidden_layers.extend([
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
        self.hidden_layers = nn.Sequential(*hidden_layers)
        
        # Output layer
        last_hidden_size = hidden_size // (2 ** (self.num_layers - 1))
        self.fc_out = nn.Linear(last_hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor
            
        Returns:
            torch.Tensor: Output predictions
        """
        # Input layer
        x = self.dropout(torch.relu(self.ln1(self.fc1(x))))
        
        # Hidden layers
        x = self.hidden_layers(x)
        
        # Output layer
        x = torch.sigmoid(self.fc_out(x))
        return x
    
    def get_complexity(self) -> int:
        """
        Returns:
            int: Total number of parameters
        """
        return sum(p.numel() for p in self.parameters())


# PRMTrainer

In [None]:
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('omega_prm.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class PRMTrainer:
    def __init__(self, mc: MCReward, config: PRMConfig):
        self.mc   = mc
        self.config    = config
        self.device = mc.device
        self.tok    = mc.tokenizer

        # PRM 자체 초기화
        feat_dim = mc.model.config.hidden_size        # LLM hidden size
        self.prm = ProcessRewardModel(feat_dim, self.config.hidden_size, output_size=1).to(self.device)
        self.opt = optim.AdamW(self.prm.parameters(), lr=self.config.learning_rate, weight_decay=0.01)
        # self.crit = nn.BCELoss()
        self.crit = nn.MSELoss()

        # Initialize wandb if enabled
        if self.config.use_wandb:
            wandb.init(project="mc-prm", name="prm-train", config=vars(self.config))
        # Create checkpoint directory
        Path(self.config.checkpoint_dir).mkdir(exist_ok=True)

    # ------------------------------------------------------------------ utils
    @torch.no_grad()
    def _encode_features(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        LLM 의 hidden state → [CLS-pooling] 식 임베딩.
        input_ids : [B, T]
        return    : [B, feat_dim]
        """
        # emb_layer = self.model.get_input_embeddings()          # shared embedding matrix
        # emb = (emb_layer(token_ids)).mean(dim=1) 
        out = self.mcts.model(input_ids=input_ids,
                              output_hidden_states=True,
                              return_dict=True)
        # 마지막 hidden-state의 0-번 토큰(CLS) 임베딩 사용
        return out.hidden_states[-1][:, 0, :]

    # ------------------------------------------------------ data preparation

        ds = PRMDataset(texts, lbls,
                        tokenizer=self.tok,
                        max_length=self.cfg.max_length)
        
        print("PRMDataset(size={}, avg_len={:.1f})".format(len(ds), sum(len(t.split()) for t in texts)/len(texts)))
        return ds, structured

    # ---------------------------------------------------------- train / valid
    def _run_epoch(self, loader: DataLoader, train: bool) -> float:
        self.prm.train() if train else self.prm.eval()
        tot = 0.0
        for step, (ids, r) in enumerate(loader):
            ids, r = ids.to(self.device), r.to(self.device)
            with torch.set_grad_enabled(train):
                feats = self._encode_features(ids)
                out   = self.prm(feats)
                loss  = self.crit(out, r)
                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.prm.parameters(), 1.0)
                    self.opt.step()
            tot += loss.item()

            if self.cfg.use_wandb and train:
                wandb.log({
                    "batch_loss": loss.item(),
                    # "epoch"     : self.cur_epoch,   # train_prm 에서 설정
                    "step"      : step
                })
        return tot / len(loader)

    def train_prm(
        self,
        train_questions: List[str],
        val_questions  : List[str],
        num_epochs: int = 5,
    ) -> Dict[str, List[float]]:
        # 1) 데이터 수집
        train_ds, _ = self.build_dataset(train_questions)
        val_ds,   _ = self.build_dataset(val_questions)
        print("train ds:", train_ds)

        train_loader = DataLoader(train_ds, batch_size=self.cfg.batch_size, shuffle=True, num_workers=self.cfg.num_workers)
        val_loader = DataLoader(val_ds, batch_size=self.cfg.batch_size,shuffle=False, num_workers=self.cfg.num_workers)

        # 2) 학습 loop
        hist = {"train": [], "val": []}
        best = float("inf")
        for ep in range(num_epochs):
            tr = self._run_epoch(train_loader, train=True)
            vl = self._run_epoch(val_loader,   train=False)
            hist["train"].append(tr)
            hist["val"].append(vl)
            print(f"[EP {ep}] train {tr:.4f} | val {vl:.4f}")

            if self.cfg.use_wandb:
                wandb.log({
                    "epoch"     : ep,
                    "train_loss": tr,
                    "val_loss"  : vl,
                })
            if vl < best:
                best = vl
                self.save_checkpoint(ep, vl)
                # torch.save(self.prm.state_dict(), Path(self.cfg.checkpoint_dir) / "best_prm.pt")
        return hist

    # -------------------------------------------------------------- metrics
    def get_metrics(self) -> Dict[str, float]:
        return {
            "params": sum(p.numel() for p in self.prm.parameters()),
        }
    
    # -------------------------------------------------------------- save checkpoints
    def save_checkpoint(self, epoch: int, validation_loss: float):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.prm.state_dict(),
            'optimizer_state_dict': self.opt.state_dict(),
            'validation_loss': validation_loss,
            'config': self.cfg.__dict__
        }
        path = Path(self.cfg.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, path)
        logger.info(f"Saved checkpoint to {path}")

    def load_checkpoint(self, path: str):
        """Load model checkpoint"""
        try:
            checkpoint = torch.load(path, map_location=self.device)
            self.prm.load_state_dict(checkpoint['model_state_dict'])
            self.opt.load_state_dict(checkpoint['optimizer_state_dict'])
            logger.info(f"Loaded checkpoint from {path}")
            return checkpoint['epoch'], checkpoint['validation_loss']
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {str(e)}")
            raise
