In [1]:
from typing import Dict, List, Tuple, Union, Callable
from pathlib import Path
import logging

import torch
import transformers
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    GenerationConfig
)
from tensordict.tensordict import TensorDict

# Configure logging using logging library
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from abc import ABC, abstractmethod


class ReasoningEnv(ABC):
    """It has the responsibility to simulate the environment that the agent is interacting with.
    Environment processes one query at a time.
    Core functionalities:
    - step():
    - rollout():
    - reset():
    - close():
    """

    def __init__(self, 
                 cfg: dict, 
                 tokenizer: AutoTokenizer,
                 dataset_path: Union[str, Path] = None,
                 dataset: Dataset = None,
                 data_generator: Callable = None):
        """

        Args:
            cfg (dict): _description_
            tokenizer (AutoTokenizer): tokenizer includes both the tokenizer and vocab size
            dataset_path (Union[str, Path]): path to the dataset
            dataset (Dataset): dataset
        """
        self._cfg = cfg
        self._tokenizer = tokenizer
        self._dataset = self._initialize_dataset(dataset=dataset, dataset_path=dataset_path, data_generator=data_generator)
        self._current_state = None
        self._current_query_idx = 0

    def _initialize_dataset(self, dataset=None, dataset_path=None, data_generator=None):
        if (dataset and dataset_path) or (dataset and data_generator) or (dataset_path and data_generator):
            raise ValueError("Provide only one of dataset, dataset_path, or data_generator")
        
        if dataset_path is not None:
            return load_dataset(dataset_path)
        elif dataset is not None:
            return dataset
        elif data_generator is not None:
            raise NotImplementedError("data_generator is not implemented")
        else:
            raise ValueError("Provide either dataset_path or dataset")

    @abstractmethod
    def step(self, action: Union[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, bool, bool, dict]:
        """
        Args:
            action (Union[str, torch.Tensor]): _description_

        Returns:
            Observation: _description_
            Reward: _description_
            Terminated: _description_
            Truncated: _description_
            Info: _description_
        """
        pass

    @abstractmethod
    def rollout(self, max_steps: int, 
                policy: AutoModelForCausalLM = None, 
                generation_config: GenerationConfig = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        pass
        
    def reset(self) -> Tuple[torch.Tensor, Union[TensorDict, dict]]:
        self._current_state = self._dataset[self._current_query_idx]['input_ids']
        self._current_query_idx += 1

        info_ = {}
        return self._current_state, info_
    
    @abstractmethod
    def close(self):
        pass

class TokenLevelReasoningEnv(ReasoningEnv):
    
    def step(self, action: Union[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, bool, bool, dict]:
        pass

    def rollout(self, max_steps: int, 
                policy: AutoModelForCausalLM = None, 
                generation_config: GenerationConfig = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        
        
        
        if policy is None:
            # use a uniform policy to randomly sample actions from the tokenizer
            class UniformPolicy(AutoModelForCausalLM):
                def __init__(self, tokenizer: AutoTokenizer):
                    self._tokenizer = tokenizer

                def __call__(self, x: torch.Tensor) -> torch.Tensor:
                    return torch.randint(0, self._tokenizer.vocab_size)
                
                def generate(self, input_ids: torch.Tensor, generation_config: GenerationConfig) -> torch.Tensor:
                    return torch.randint(0, self._tokenizer.vocab_size, size=input_ids.shape)
                
            policy = UniformPolicy(self._tokenizer)

        if generation_config is None:
            generation_config = GenerationConfig(
                max_new_tokens=max_steps,
                temperature=1.0,
                top_k=0.0,
                top_p=1.0,
                do_sample=True,
            )

        response = policy.generate(input_ids=self._current_state, generation_config=generation_config)

        self._current_state = response.input_ids

        return self._current_state, self._current_state

# implement this vectorized version of the environment using vmap in python. Make sure the implementation is optimal.
# You need to make sure that the dataset is loaded beforehand and each distinct part goes to a seperate environment.
class VecReasoningEnv:

    def __init__(self, cfg: dict, tokenizer: AutoTokenizer, dataset: Dataset):
        self._cfg = cfg
        self._tokenizer = tokenizer
        self._dataset = dataset
        self._envs = [TokenLevelReasoningEnv(cfg, tokenizer, dataset) for _ in range(len(dataset))]

    def step(self, action: Union[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, bool, bool, dict]:
        pass

In [3]:
SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' }}{% endif %}"
USER_PREFIX_PROMPT = "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with ' '} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:"

In [6]:
# Model paths
base_model_path = "deepseek-ai/deepseek-math-7b-instruct"

# Initialize tokenizer with padding token
# TODO: understanding right padding and left padding and how it influences the performance
tokenizer = AutoTokenizer.from_pretrained(
    base_model_path, padding_side="left", trust_remote_code=True, return_tensors="pt"
)
if tokenizer.pad_token is None:
    # TODO: learn what it means EXACTLY TO SET EOS == PAD TOKEN
    logger.info(f"Padding token is None, setting it to eos_token: {tokenizer.eos_token}")
    tokenizer.pad_token = tokenizer.eos_token

if tokenizer.chat_template is None:
    logger.info(f"Chat template is None, setting it to SIMPLE_CHAT_TEMPLATE: {SIMPLE_CHAT_TEMPLATE}")
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

# Load all required models
# value_model = AutoModelForSequenceClassification.from_pretrained(
#     base_model_path,
#     trust_remote_code=True,
#     num_labels=1,
#     device_map="auto",
#     torch_dtype=torch.float16,
# )
# logger.debug(f"Value Model Architecture: {value_model}")

# reward_model = AutoModelForSequenceClassification.from_pretrained(
#     base_model_path,
#     trust_remote_code=True,
#     num_labels=1,
#     device_map="auto",
#     torch_dtype=torch.float16,
# )
# logger.debug(f"Reward Model Architecture: {reward_model}")
policy = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    offload_folder="offload/yo_v1",
)
logger.debug(f"Policy Model Architecture: {policy}")

# ref_policy = AutoModelForCausalLM.from_pretrained(
#     base_model_path,
#     trust_remote_code=True,
#     device_map="auto",
#     torch_dtype=torch.float16,
# )
# logger.debug(f"Ref Policy Model Architecture: {ref_policy}")

# Load the dataset
dataset = load_dataset(
    "openai/gsm8k",
    'main',
    split="train",                                          
)
logger.info(f"Loaded dataset: {dataset}")
logger.info(f"Dataset Characteristics: {dataset.features}")
logger.info(f"Dataset Length: {len(dataset)}")
logger.info(f"Dataset Example: {dataset.shape}")

INFO:__main__:Padding token is None, setting it to eos_token: <｜end▁of▁sentence｜>
Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.19s/it]
INFO:__main__:Loaded dataset: Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})
INFO:__main__:Dataset Characteristics: {'question': Value(dtype='string', id=None), 'answer': Value(dtype='string', id=None)}
INFO:__main__:Dataset Length: 7473
INFO:__main__:Dataset Example: (7473, 2)


In [20]:
# TODO: this CoT is specialized for deepseek-math-7b-instruct/rl 
# https://huggingface.co/deepseek-ai/deepseek-math-7b-instruct
# make a proper function that is model agnostic in later refactoring
def apply_cot_prompt(question: str) -> str:
    return question + "\nPlease reason step by step, and put your final answer within \\boxed{}."

def preprocess_answer(answer: str) -> str:
    # Extract the number after '#### '
    return answer.split('#### ')[-1]

def tokenize(row: dict) -> dict:
    # Create messages format for the chat template
    row['question'] = apply_cot_prompt(row['question'])
    messages = [
        {'role': 'user', 'content': row['question']},
        # {'role': 'assistant', 'content': preprocess_answer(row['answer'])}
    ]
    
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True, 
        add_generation_prompt=True, 
        padding=False # padding is set because of memory optimization
        ) 
    return {'input_ids': input_ids, 'answer': preprocess_answer(row['answer'])}

# Update the dataset mapping
tokenized_dataset = dataset.select(range(10)).map(tokenize, remove_columns=dataset.column_names)

logging.info(tokenized_dataset[0])

Map: 100%|██████████| 10/10 [00:00<00:00, 1807.34 examples/s]
INFO:root:{'answer': '72', 'input_ids': [100000, 5726, 25, 39203, 480, 5151, 34406, 276, 207, 19, 23, 280, 711, 3997, 279, 6511, 11, 285, 937, 838, 5151, 3222, 372, 1313, 34406, 279, 3638, 13, 1724, 1313, 34406, 1216, 39203, 480, 6926, 16369, 279, 6511, 285, 3638, 30, 185, 7900, 2806, 3458, 457, 3458, 11, 285, 1957, 520, 2328, 3510, 2383, 357, 63962, 90, 1424, 185, 185, 77398, 25]}


In [43]:
# Implementation of GRPO algorithm
iter_num = 1
step_num = 1
epsilon = 0.1
beta = 0.1
mu = 0.1
batch_size = 3
generation_config = GenerationConfig(
    max_new_tokens=1000,
    temperature=1.0,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
)
# reward_func = lambda x: x['answer'] == x['answer']
for i in range(iter_num):
    # ref_model.load_state_dict(policy.state_dict())
    # ref_model.eval()
    # ref_model.to(policy.device)
    for step in range(step_num):
        batch_data = tokenized_dataset.select(range(batch_size))
        batch_data = tokenizer.pad({'input_ids': batch_data['input_ids']}, padding=True, return_tensors='pt')
        print(batch_data['input_ids'].shape)
        # old_policy.load_state_dict(policy.state_dict())
        policy_output = policy.generate(input_ids=batch_data['input_ids'], generation_config=generation_config)
        policy_output = tokenizer.decode(policy_output[0], skip_special_tokens=True)
        logging.info(policy_output)
        break


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


torch.Size([3, 84])


INFO:root:User: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Please reason step by step, and put your final answer within \boxed{}.

Assistant: Natalia sold 48 clips in April.
In May, she sold half as many clips as she did in April, so she sold 48/2 = 24 clips in May.
To find the total number of clips she sold in April and May, we need to add the number of clips sold in April to the number of clips sold in May, which is 48 + 24 = 72 clips.
So the answer is $\boxed{72}$.


In [41]:
tokenizer.decode(policy_output[2], skip_special_tokens=False)

"<｜begin▁of▁sentence｜>User: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nPlease reason step by step, and put your final answer within \\boxed{}.\n\nAssistant: First calculate how much money Betty's parents gave"

In [37]:
policy_output

tensor([[100001, 100001, 100001, 100001, 100001, 100001, 100001, 100001, 100001,
         100001, 100001, 100001, 100001, 100001, 100001, 100001, 100001, 100001,
         100001, 100001, 100001, 100001, 100000,   5726,     25,  39203,    480,
           5151,  34406,    276,    207,     19,     23,    280,    711,   3997,
            279,   6511,     11,    285,    937,    838,   5151,   3222,    372,
           1313,  34406,    279,   3638,     13,   1724,   1313,  34406,   1216,
          39203,    480,   6926,  16369,    279,   6511,    285,   3638,     30,
            185,   7900,   2806,   3458,    457,   3458,     11,    285,   1957,
            520,   2328,   3510,   2383,    357,  63962,     90,   1424,    185,
            185,  77398,     25,  39203,    480,   5151,    207,     19,     23,
          34406,    279,   6511,     13],
        [100001, 100001, 100001, 100001, 100001, 100001, 100001, 100001, 100001,
         100001, 100001, 100001, 100001, 100001, 100001, 100001, 10

In [5]:
# def tokenize(row: dict) -> dict:
#     row['messages'] = row['messages'][:1]
#     input_ids = tokenizer.apply_chat_template(row['messages'][:1], tokenize=True, add_generation_prompt=True, padding=False)
#     return {'input_ids': input_ids}

# tokenized_dataset = dataset.map(tokenize, remove_columns=dataset.column_names) # call it dataset later for GC reasons

# logging.info(tokenized_dataset)

INFO:root:Dataset({
    features: ['input_ids'],
    num_rows: 6447
})


In [10]:



def generate_responses(batch: Dict[str, torch.Tensor], generation_config: GenerationConfig) -> Dict[str, List[str]]:
    # Generate from policy model
    policy_output = policy.generate(
        input_ids=batch['input_ids'].to(policy.device),
        generation_config=generation_config,
        output_scores=True,
        return_dict_in_generate=True  # Return a dictionary to access score
    )
    # decode the output
    policy_output = tokenizer.decode(policy_output.sequences[0], skip_special_tokens=True)
    return policy_output

responses = []
for item in tokenized_dataset.select(range(2)):
    item['input_ids'] = torch.tensor(item['input_ids']).reshape(1, -1)
    responses.append(generate_responses(item, generation_config))

print(responses)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


['User: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n\nAssistant: ~~Natalia sold 2,769 clips to her peers in April (101 clips sold in May), and ~~Natalia sold 935 clips to her friends in April (73 clips sold in May), ~~Natalia sold 100 clips to 21 of her friends in May, and ~~Natalia sold 49 of her 33 friends in May. How many clips did Natalia sell to her friends in April and May?\n\nAssistant: Natalia sold all of the girls’ prints, numbers, and legend prints for $70.25 in April. What was the value of these in April and May', "User: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\n\nAssistant: \nTrainer: \nRube: Now her appearance regular, hers can attract paying customers. So, Rube got $10 an hour to look over and teach some little trouble faces.\n\nHow much did she win? And how much compensation would he send

In [48]:
responses[0].scores[0].shape
responses[0].scores[1].shape

torch.Size([1, 50304])

In [None]:
# RL Objective = Ref Policy and Policy Logits for the last
