# In [2]:  (Jupyter notebook cell prompt left over from export)
import torch 
from torch import nn
class RewardModel(nn.Module):
    """Wrap a base encoder and add a linear classification head on top.

    The wrapped model is called as-is; its ``last_hidden_state`` at sequence
    position 0 is fed through an ``nn.Linear`` layer to produce the logits.

    This class is a regular ``nn.Module``: both the wrapped model (when it is
    itself an ``nn.Module``) and the head are registered as sub-modules, so
    ``parameters()``, ``to()``, ``state_dict()`` etc. cover them.

    Args:
        model_instance: The base model. It must expose ``config.hidden_size``
            and, when called, return an object with a ``last_hidden_state``
            attribute indexable as ``[:, 0, :]``.
        num_classes: Output size of the classification head.
    """

    def __init__(self, model_instance, num_classes: int) -> None:
        # nn.Module.__init__ must run before any sub-module is assigned;
        # otherwise attribute assignment of a module raises AttributeError.
        # (The instance must also keep its own class: re-typing it to the
        # wrapped model's class would discard this class's ``forward``.)
        super().__init__()

        # Keep the base model as a sub-module
        self.model = model_instance

        # Width of the base model's hidden states, read from its config
        hidden_size = self.model.config.hidden_size

        # Linear classification head: hidden_size -> num_classes
        self.classification_head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        """Run the base model and classify from its first-position hidden state.

        Args:
            input_ids: Passed positionally to the base model.
            attention_mask: Forwarded to the base model as a keyword argument.
            token_type_ids: Forwarded to the base model only when not ``None``,
                so base models whose ``forward`` has no such parameter still
                work.
            **kwargs: Any extra keyword arguments for the base model.

        Returns:
            The output of the classification head applied to
            ``last_hidden_state[:, 0, :]``.
        """
        # Only pass token_type_ids through when the caller actually supplied
        # it; omitting it is equivalent to passing None for models that
        # default the parameter to None.
        if token_type_ids is not None:
            kwargs["token_type_ids"] = token_type_ids

        # Call the base model
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)

        # Hidden states of the base model's final layer
        last_hidden_state = outputs.last_hidden_state

        # Take sequence position 0 (the [CLS] token for BERT-style encoders —
        # assumption carried over from the original code; confirm for the
        # model actually wrapped) and apply the classification head.
        cls_token = last_hidden_state[:, 0, :]
        logits = self.classification_head(cls_token)

        return logits