# Utility functions

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *
# Run nbdev's export step, which writes `#| export` cells out to library
# modules (this notebook targets `utils`, per `#| default_exp` above).
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import yaml
from copy import deepcopy
from dataclasses import dataclass
from typing import Union, List

import torch

In [None]:
#| export
def load_yaml(config_path):
    """Read a YAML file and return the parsed document.

    Parsing uses ``yaml.safe_load``, which only builds plain YAML types and
    never constructs arbitrary Python objects from tags.

    Args:
        config_path: Path (str or path-like) of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    # Decode explicitly as UTF-8: without `encoding=` Python falls back to the
    # locale's preferred encoding (e.g. cp1252 on Windows), which would
    # mis-decode non-ASCII config values.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

In [None]:
#| export
@dataclass
class RLHFConfig:
    """Hyperparameters for PPO-based RLHF training; every field has a default."""

    # ---- PPO ----
    # NOTE(review): presumably the PPO clipping epsilon; confirm at the use site.
    epsilon: float = 1e-1
    # Weight applied to the entropy term.
    ent_coef: float = 1e-2
    # NOTE(review): presumably the value-function loss weight; confirm.
    vf_coef: float = 1e-1

In [None]:
#| export
class ReplayBuffer:
    """Store per-step rollout data and sample single transitions from it.

    Each ``append`` call records one step across six parallel lists
    (``states``, ``actions``, ``log_probs``, ``values``, ``rewards``,
    ``dones``); position ``i`` in every list refers to the same step.
    """

    def __init__(self) -> None:
        # The state type is whatever callers pass to append(); not constrained.
        self.states: list = []
        self.actions: List[int] = []
        self.log_probs: List[Union[int, float]] = []
        # `List` takes exactly one type argument, so the element type is a Union.
        self.values: List[Union[int, float]] = []
        self.rewards: List[Union[int, float]] = []
        self.dones: List[bool] = []

    def append(
        self, state, action: int, log_prob: Union[int, float],
        value: Union[int, float], reward: Union[int, float], done: bool
    ):
        """Record one step; each of the six lists grows by exactly one item."""
        self.states.append(state)
        self.actions.append(action)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.rewards.append(reward)
        self.dones.append(done)

    def sample(self):
        """Return one stored step chosen uniformly at random.

        The index is drawn with ``torch.randint``, so sampling follows torch's
        global RNG state (e.g. it is reproducible under ``torch.manual_seed``).

        Returns:
            Tuple ``(state, action, log_prob, value, reward, done)``.

        Raises:
            RuntimeError: If the buffer holds no steps.
        """
        n_samples = len(self.states)
        if n_samples == 0:
            # torch.randint(high=0) would raise a low-level torch error instead.
            raise RuntimeError("cannot sample from an empty ReplayBuffer")
        idx = torch.randint(low=0, high=n_samples, size=(1,)).item()

        return (
            self.states[idx], self.actions[idx], self.log_probs[idx],
            self.values[idx], self.rewards[idx], self.dones[idx],
        )

### Reference Model

In [None]:
#| export
def create_reference_model(model, *, freeze: bool = True):
    """Return an eval-mode deep copy of ``model`` for use as a fixed reference.

    ``model`` itself is left untouched: its training mode and its parameters'
    ``requires_grad`` flags are not changed.

    Args:
        model: A ``torch.nn.Module`` (anything supporting ``deepcopy``,
            ``eval()`` and, when ``freeze`` is true, ``requires_grad_()``).
        freeze: If true (default), turn off gradients on every parameter of
            the copy. This avoids building autograd graphs for its parameters
            and ensures it never receives gradients. Pass ``False`` to keep
            the copy's parameters trainable (the previous behaviour).

    Returns:
        The copied module, in eval mode.
    """
    ref_model = deepcopy(model).eval()
    if freeze:
        # A reference policy must stay fixed while the policy model trains.
        ref_model.requires_grad_(False)
    return ref_model

In [None]:
# Scratch cell (not exported): shows what the torch.randint call pattern used
# in ReplayBuffer.sample returns -- a one-element tensor, which is why sample()
# calls `.item()` on it.
import torch
torch.randint(low=0, high=10, size=(1,))

tensor([5])