In [1]:
# !pip install huggingface_hub bitsandbytes feedparser datasets trl transformers peft accelerate>=0.26.0
# from IPython.core.display import HTML
# HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [None]:
token = "" # add hugging face token here / needed for gemma model

try:
    from huggingface_hub import login
    login(token=token)
except Exception as e:
    print("Exception occured whilet rying to login:", e)     

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# model_id = "google/gemma-3-1b-it"
model_id = "Qwen/Qwen2.5-3B-Instruct"
# model_id = "google/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # quantization_config=quantization_config,
    # device_map="cuda:1",
    dtype = torch.float16,
    device_map="cuda",
)

print(f'\nMemory footprint of quantized model: {model.get_memory_footprint()/1e9} GB')

ref_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # quantization_config=quantization_config,
    dtype = torch.float16,
    device_map='cuda' # enable if gemma 1b
    )

tokenizer = AutoTokenizer.from_pretrained(model_id)

print(f'\nMemory footprint of quantized reference model: {model.get_memory_footprint()/1e9} GB')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Memory footprint of quantized model: 6.171877632 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Memory footprint of quantized reference model: 6.171877632 GB


In [2]:
# 3. Apply LoRA
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    r=16,#16,
    lora_alpha=32,
    # target_modules=["q_proj", "v_proj", "o_proj"],
    target_modules="all-linear",
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
peft_model = get_peft_model(model, lora_config)

optimizer = torch.optim.AdamW(peft_model.parameters(), lr=4e-5) #4e-5 #lr = 1e-4
peft_model.print_trainable_parameters()

trainable params: 29,933,568 || all params: 3,115,872,256 || trainable%: 0.9607


### Gathering data

In [4]:
"Fetches recent papers from Arxiv with the mentioned categories"
import feedparser
import json
import time

# Number of entries per category (change this to control total size)
n_per_category = 100  # total will be 5 * this

# Category code to human-readable name
categories = {
    "cs.RO": "Robotics",
    "cs.LG": "Machine Learning",
    "cs.AI": "Artificial Intelligence",
    "cs.CV": "Computer Vision",
    "cs.DM": "Discrete Mathematics"
}

# Base API URL
base_url = "http://export.arxiv.org/api/query?"

In [5]:
def collect_data():
    all_entries = []
    
    # Loop over each category
    for cat_code, cat_name in categories.items():
        collected = 0
        start = 0
        batch_size = 150  # arXiv API limit per call
        print(f"Fetching {cat_name}...")
    
        while collected < n_per_category:
            query = (f"search_query=cat:{cat_code}&start={start}&max_results={batch_size}"
                     f"&sortBy=submittedDate&sortOrder=descending")
            feed = feedparser.parse(base_url + query)
    
            if not feed.entries:
                print(f"No more results for {cat_name}. Got {collected}.")
                break
    
            for entry in feed.entries:
                if collected >= n_per_category:
                    break
                entry_data = {
                    'title': entry.get('title'),
                    'id': entry.get('id'),
                    'published': entry.get('published'),
                    'updated': entry.get('updated'),
                    'summary': entry.get('summary'),
                    'authors': [author.name for author in entry.get('authors', [])],
                    'primary_category': entry.get('arxiv_primary_category', {}).get('term'),
                    'categories': [tag['term'] for tag in entry.get('tags', [])],
                    'pdf_url': next((link.href for link in entry.links if link.type == 'application/pdf'), None),
                    'comment': entry.get('arxiv_comment'),
                    'journal_ref': entry.get('arxiv_journal_ref'),
                    'category_name': cat_name  # Add human-readable name
                }
                all_entries.append(entry_data)
                collected += 1
    
            start += batch_size
            time.sleep(1)  # Respect arXiv rate limits
    
    # Save to JSON
    with open("arxiv_dataset.json", "w", encoding="utf-8") as f:
        json.dump(all_entries, f, indent=2, ensure_ascii=False)
    
    print(f"\n✅ Saved {len(all_entries)} entries to arxiv_dataset.json")


In [20]:
import pandas as pd
from transformers import pipeline
from tqdm import tqdm
import regex as re

try:
    # Load dataset
    df = pd.read_json("arxiv_dataset.json")
except:
    collect_data()
    df = pd.read_json("arxiv_dataset.json")
    
df.rename(columns={"summary": "abstract"}, inplace=True)

# Shuffle the dataset to ensure randomness
og_df = df.sample(frac=1).reset_index(drop=True)
len_df = len(og_df)

train_test_split = 0.8
df_train = og_df[:int(len_df*(train_test_split))]
df_test = og_df[-int(len_df*(1-train_test_split))-1:]
len(df_train), len(df_test), train_test_split

(400, 100, 0.8)

### Preparing data

In [None]:
# Preparing prompt messages
def build_dataset(df):
    batch_size = 100  # cha
    i = 0  # The current batch index (update in your loop)

    # Define the true labels you’ll compare against
    target_categories = list(categories.values())

    # Function to build the prompt
    def build_prompt(abstract):
        return (
            f"Read the following abstract from a scientific paper and guess its research area from the following list:\n\n"
            f"[{', '.join(target_categories)}]\n\n"
            f"Abstract:\n{abstract}\n\n"
            f"Answer with only single category name."
        )

    all_prompts = []
    all_completions = []
    all_inputs = torch.tensor([])
    all_outputs = []
    all_raw_targets = []
    all_answers = []
    all_questions = []

    for i in range(0, len(df), batch_size):
        batch_df = df[i : i + batch_size]
        # print(len(batch_df))

        batch_df.reset_index(inplace=True)
        # Build messages for each abstract in the batch
        prompts_batch = []
        completions_batch = []
        raw_targets_batch = []
        answers_batch = []
        questions_batch = []

        for row in range(len(batch_df)):
            user_msg = build_prompt(batch_df['abstract'][row])

            system_prompt_reason = """Generate every response after thinking and answering like below format.
<think> **[Your detailed internal reasoning, analysis, planning, and problem-solving steps here, tailored specifically to the user's current prompt. This section is for transparency of your thought process, not the final answer.]** </think>
<answer> **[Your final, polished, and comprehensive response to the user's prompt here. This is the only portion intended as the direct answer to the user.]** </answer>
You must answer within the <answer>...</answer> tags and think within the <think>...</think> tags."""

            system_prompt = """You are an helpful assistant"""
            prompt = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt_reason}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_msg}]
                }
            ]

            # if model is gemma, use apply_chat_template directly to all conversation
            prompt = "System:\n " + system_prompt_reason + "\nUser:\n" + user_msg + "\nAssistant: " + "<think>"
            prompts_batch.append(prompt)

            completion = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_msg}]
                },
                {
                    "role": "model",
                    "content": [{"type": "text", "text": batch_df['category_name'][row]}]
                }
            ]

            completions = "System:\n "+system_prompt_reason+"\nUser:\n"+user_msg+  "\nAssistant:" + batch_df['category_name'][row]

            questions_batch.append(prompt)
            answers_batch.append(batch_df['category_name'][row])

            raw_targets_batch.append({'messages':completion})
            completions_batch.append(completions)


        all_prompts.extend(prompts_batch)
        all_completions.extend(completions_batch)
        all_raw_targets.extend(raw_targets_batch)
        all_answers.extend(answers_batch)
        all_questions.extend(questions_batch)


    # del inputs
    torch.cuda.empty_cache()
    return all_prompts, all_completions, all_raw_targets, all_outputs, all_questions, all_answers

all_prompts, all_completions, all_raw_targets, all_outputs, all_questions, all_answers = build_dataset(df_train[:])


NameError: name 'df_train' is not defined

### Rewards

In [None]:
import regex as re

def get_format_reward(raw_response):
    try:
        response = re.search(r'\nAssistant: (.*)', raw_response, re.DOTALL).group(1)
        # response = re.search(r'<start_of_turn>model(.*?)(?:<end_of_turn>|$)', raw_response, re.DOTALL).group(1)

        # print(response)
        # print("\n Answers:", answer)

        if not response:
            return -2.0  # or any fallback reward for missing assistant section
    except:
        return -2.0
    format_reward = 0
    if (
        "<answer>" not in response
        and "</answer>" not in response
        # and "<think>" not in response
        and "</think>" not in response
    ):
        format_reward -= 3.0 #* 20
        return format_reward
        
    # if "<think>" in response:
    #     format_reward += 0.5 if response.count('<think>')==1 else 0.05

    if "</think>" in response:
        format_reward += 0.5 if response.count('</think>')==1 else 0.05
    if "<answer>" in response and "</answer>" in response:
        format_reward += 1
        # print(response)
        # print(format_reward)
        # print(response.count("<answer>") + response.count("</answer>"))
    if response.count("<answer>") + response.count("</answer>") > 2:
        format_reward -= 0.3
        # print(response)
    return format_reward if format_reward is not None else -2.0
    

def get_answer_reward(raw_response, answer=None):
    try:
        response = re.search(r'\nAssistant: (.*)', raw_response, re.DOTALL).group(1)
        # response = re.search(r'<start_of_turn>model(.*?)(?:<end_of_turn>|$)', raw_response, re.DOTALL).group(1)

        if not response:
            return 0.0  # or any fallback reward for missing assistant section
    except:
        return 0.0
    
    answer_reward = 0.0
    pattern = r'<think>(.*?)</think>([\s\S]{0,9})<answer>(.*?)</answer>'
    pattern = r'<think>(.*?)</think>[\s\S]{0,9}<answer>(.*?)</answer>(.*)'
    # print(len(re.findall(pattern, response, re.DOTALL)[0]))
    # print(re.findall(pattern, response, re.DOTALL)[0])
    try:
        if len(re.findall(pattern, response, re.DOTALL)[0]) == 3:
            answer_reward += 1.0
            # print(response)
            # needs work to add rewards for not generate non pad tokens after answer
            # print(re.findall(pattern, response, re.DOTALL)[0][-1])
        if str(answer.lower()) in (re.findall(pattern, response, re.DOTALL)[0])[-2].lower():
            answer_reward += 3.0

    except:
        
        if str(answer.lower()) in response.lower():
            answer_reward += 1.5
        return answer_reward if answer_reward is not None else 0.0
    finally:
        return answer_reward if answer_reward is not None else 0.0


def get_rewards(raw_response, answer=None):
    total_reward = 0.0
    format_reward = 0.0
    answer_reward = 0.0

    try: 
        format_reward = get_format_reward(raw_response)
    except:
        format_reward = -2.0
 
    try:
        answer_reward = get_answer_reward(raw_response, answer)
    except:
        answer_reward = 0.0
        # print("sd",answer_reward)
    total_reward = format_reward + answer_reward
    
    return total_reward

### Sampling roll_outs

In [4]:
# Generate responses
def sample_rollouts(model, inputs, max_new_tokens=100, n_rollouts=3, temperature=1.1):
    input_ids = inputs["input_ids"].to(model.device)
    # attention_mask = torch.ones_like(inputs["input_ids"]).to(model.device)

    if "attention_mask" in inputs.keys():
        attention_mask = inputs["attention_mask"].to(model.device)
    else:
        print("Inside sample_rollouts: No attention mask in inputs")
        attention_mask = torch.ones_like(inputs["input_ids"]).to(model.device)

    # vocab_size = model.config.vocab_size
    # if input_ids.max() >= vocab_size:
    #     raise ValueError(f"Input ID out of vocab range: max {input_ids.max()}, vocab size {vocab_size}")

    with torch.inference_mode():
        full_responses = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens, # 100
            do_sample=True,
            top_p=0.95,
            num_return_sequences=n_rollouts,
            temperature=temperature,
            # unset for qwen
            eos_token_id=tokenizer.pad_token_id,#tokenizer.eos_token_id,#tokenizer("<|endoftext|>")['input_ids'][0] 
            pad_token_id=tokenizer.pad_token_id,  # <-- set explicitly
            use_cache=True,
            # output_scores=True,
            # return_dict_in_generate=True,
        )
    return full_responses

def tokens_to_text(tokens, tokenizer=tokenizer):
    if len(tokens.shape) < 2:
        tokens = tokens.unsqueeze(0)
    if type(tokens)==torch.Tensor:
        text_response = tokenizer.batch_decode(tokens,skip_special_tokens=False)
        return text_response
    else:
        raise ValueError("Only torch tensors are to be given")


def sample_env(model, tokenizer, questions, answers,
               n_rollouts=3, temperature=1.0, max_new_tokens=500
              ):

    print("inside sample env")
    if type(questions) != list:
        raise ValueError("Only list has to be given as quesitons")

    if type(answers) != list:
        raise ValueError("Only list has to be given as answers")

    inputs = tokenizer(questions,
                return_tensors='pt',
                padding=True,
                padding_side='left' # added for qwen, remove for gemma
                      )

    print(f"Size of input prompt: {inputs['input_ids'].shape},Total Questions: {len(questions)}")

    response_tokens = sample_rollouts(
        model, inputs, max_new_tokens=max_new_tokens,
        n_rollouts=n_rollouts, temperature=temperature)

    print("shape of response tokens",response_tokens.shape)
    responses = tokens_to_text(response_tokens, tokenizer=tokenizer)

    prompt_tokens = inputs['input_ids']
    response_mask = torch.ones_like(response_tokens)
    response_mask[response_tokens==tokenizer.eos_token_id] = 0
    response_mask[response_tokens==tokenizer.pad_token_id] = 0
    response_mask[:,:prompt_tokens.size(1)] = 0

    # print(len(responses))
    answers_upsampled = []
    for i in range(len(answers)):
        for j in range(n_rollouts):
            answers_upsampled.append(answers[i])

    rewards = []

    print("len(responses), len(answers_upsampled):", len(responses), len(answers_upsampled))
    # print(answers_upsampled)
    for k in range(len(responses)):
        rewards.append(get_rewards(responses[k], answers_upsampled[k]))
        # print(responses[k], answers_upsampled[k])

    print("#"*50)
    return response_tokens, responses, response_mask, rewards


### Collecting experinece

In [5]:
# whole training loop
import numpy as np

import time

def collect_exp(ref_model=ref_model, total_experiences=1,
                n_rollouts=3,
                temperature=1,
                max_new_tokens = 5,
                count=25,
                grpo=False
                ):

    # total_experiences = 5
    # n_rollouts = 3
    # temperature = 1
    # max_new_tokens = 5#500
    # count = 25

    experience_buffer = []
    for i in range(total_experiences):
        start = time.time()
        torch.cuda.empty_cache()
        # loss graph from here
        indices = np.random.choice(np.arange(0, len(all_questions)), size=count, replace=False)
        questions = [all_questions[i] for i in indices]
        answers = [all_answers[i] for i in indices]

        # response_tokens, responses, response_mask, rewards = sample_env(
        response_tokens, responses, response_mask, rewards = sample_env(
                        ref_model, tokenizer, questions, answers,
                       n_rollouts=n_rollouts, temperature=temperature,
                        max_new_tokens=max_new_tokens
                      )


        # experience_buffer.append(experience)

        batch_size = len(questions)
        rewards = np.array(rewards)
        print(rewards.shape, batch_size)
        rewards = np.reshape(rewards, [batch_size, n_rollouts]) # batch, roll_outs

        if grpo:
            print('using grpo')
            if (rewards.mean() < 0) or (np.std(rewards) < 0.3):
                advantages = (rewards - np.mean(rewards, axis=1, keepdims=True)) / (
                    np.std(rewards, axis=1, keepdims=True) + 1e-8
                ) # batch, roll_outs
                print("Standardised advantages: ", advantages, "Rewards Mean:", rewards.mean(),
                      "Rewards: ", rewards
                     )
            else:
                # advantages = rewards
                advantages = (rewards - np.mean(rewards, axis=1, keepdims=True)) / (
                    np.std(rewards, axis=1, keepdims=True) + 1e-8
                )
                print("**Not**Standardised advantages: ", advantages, "Rewards Mean:", rewards.mean(),
                      "Rewards: ", rewards
                     )

        else:
            print('**Not** using grpo')
            advantages = (rewards - rewards.mean()) \
                / (
                #             np.std(rewards, axis=1, keepdims=True) + 1e-8
                rewards.std() + \
                        1e-6
                        )
            advantages = np.clip(advantages, -10.0, 10.0)
            print("**Standardised advantages: ", advantages, "Rewards Mean:", rewards.mean(),
                      "Rewards: ", rewards
                     )

        advantages = np.reshape(advantages, [-1])

        experience = (response_tokens, responses, response_mask, rewards, advantages)
        experience_buffer.append(experience)

        print(f"Time taken to collect single experience with {count} questions: {time.time()-start}")
        # break
    return experience_buffer

# experience_buffer = collect_exp(
#                 ref_model=ref_model,
#                 total_experiences=1,
#                 n_rollouts=4,
#                 temperature=1,
#                 max_new_tokens = 350,
#                 count=25)

### Loss

In [6]:
# !export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [None]:
# bug fixes
import torch
import torch.nn.functional as F
import gc
import contextlib
from typing import Dict, Any

def get_logprobs(
    model,
    responses: Dict[str, torch.Tensor],
    use_fp32_for_math: bool = True,
    sanitize: bool = True,
    require_grad: bool = False,
) -> torch.Tensor:
    """
    Compute selected token log-probabilities for given responses.

    Returns:
      selected_log_probs: Tensor shaped (batch, seq_len-1), dtype float32 if use_fp32_for_math
    """
    device = next(model.parameters()).device

    input_ids = responses["input_ids"].to(device)
    attention_mask = responses.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # ctx = torch.no_grad() if not require_grad else contextlib.nullcontext()
    # with ctx:
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    logits = outputs.logits  # (batch, seq_len, vocab)

    if sanitize:
        # map NaN -> 0, INF -> large finite numbers (avoid extreme values that break softmax numerics)
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1e6, neginf=-1e6)

    # Shift: model predicts next token at position t, so drop last logit
    shift_logits = logits[:, :-1, :]           # (batch, seq-1, vocab)
    shift_labels = input_ids[:, 1:].long().to(device)  # (batch, seq-1) ensure long & on device

    # stable math in fp32
    if use_fp32_for_math:
        shift_logits = shift_logits.float()

    # compute log-probs and select labels
    log_probs = F.log_softmax(shift_logits, dim=-1)   # (batch, seq-1, vocab)
    selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    # If selected_log_probs contain NaN/Inf, map them to a large negative value
    # (so that later exp(logdiff) -> ~0 instead of exploding)
    if torch.isnan(selected_log_probs).any() or torch.isinf(selected_log_probs).any():
        selected_log_probs = torch.nan_to_num(selected_log_probs, nan=-1e8, posinf=-1e8, neginf=-1e8)

    # cleanup to reduce transient GPU memory
    del logits, log_probs, shift_logits
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return selected_log_probs  # on model device


def grpo_loss(
    log_probs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages,
    response_mask: torch.Tensor,
    clip_low: float = 0.2,
    clip_high: float = 0.3,
    logdiff_clamp: float = 20.0,  
    math_dtype=torch.float32,
    sanitize_before_exp: bool = True,
    debug: bool = False,
) -> torch.Tensor:
    """
    Stable GRPO-like loss. Returns scalar loss tensor.
    """

    def assert_same_device(*tensors):
        dev = tensors[0].device
        for t in tensors:
            if t.device != dev:
                raise AssertionError(f"Device mismatch: {t.device} != {dev}")

    # move/cast tensors to same device and dtype for safe math
    device = log_probs.device
    mdtype = math_dtype

    log_probs = log_probs.to(device=device, dtype=mdtype)
    old_logprobs = old_logprobs.to(device=device, dtype=mdtype)

    if not torch.is_tensor(advantages):
        advantages = torch.tensor(advantages, device=device, dtype=mdtype)
    else:
        advantages = advantages.to(device=device, dtype=mdtype)

    # Make sure response_mask matches shifted length: if mask is for full sequence, drop first token
    if response_mask.dim() == 2 and response_mask.size(1) == (log_probs.size(1) + 1):
        full_response_mask = response_mask[:, 1:].to(device=device).to(mdtype)
    else:
        full_response_mask = response_mask.to(device=device).to(mdtype)

    if debug:
        # early device check
        assert_same_device(log_probs, old_logprobs, advantages, full_response_mask)
        print("grpo_loss_stable debug: shapes",
              log_probs.shape, old_logprobs.shape, advantages.shape, full_response_mask.shape)

    # log difference
    logdiff = log_probs - old_logprobs

    if sanitize_before_exp:
        logdiff = torch.nan_to_num(logdiff, nan=-1e8, posinf=1e8, neginf=-1e8)

    # clamp to safe range for exp
    logdiff = torch.clamp(logdiff, min=-logdiff_clamp, max=logdiff_clamp)

    if debug:
        print("logdiff min/max:", float(logdiff.min()), float(logdiff.max()))

    # importance sampling ratio 
    ratio = torch.exp(logdiff)
    # avoid any remaining NaN/Inf by zeroing or capping
    ratio = torch.nan_to_num(ratio, nan=0.0, posinf=1e8, neginf=0.0)

    if debug:
        print("ratio min/max/mean:", float(ratio.min()), float(ratio.max()), float(ratio.mean()))

    # reshape advantages to (batch, 1) for broadcasting
    advantages = advantages.reshape(-1, 1)

    # unclipped & clipped objectives
    unclipped = advantages * ratio
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
    clipped = clipped_ratio * advantages

    per_token_loss = torch.min(unclipped, clipped)

    # apply mask
    per_token_loss = per_token_loss * full_response_mask

    # normalize per-sequence (avoid divide by zero)
    denom = full_response_mask.sum(dim=1).clamp(min=1.0)
    seq_loss = -(per_token_loss.sum(dim=1) / denom)

    loss = seq_loss.mean()

    # final safety: if loss is not finite attempt a fallback
    if not torch.isfinite(loss):
        if debug:
            print("Loss became non-finite — retrying with tighter clamp.")
        logdiff = torch.clamp(logdiff, min=-10.0, max=10.0)
        ratio = torch.exp(logdiff)
        ratio = torch.nan_to_num(ratio, nan=0.0, posinf=1e8, neginf=0.0)
        unclipped = advantages * ratio
        clipped_ratio = torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
        clipped = clipped_ratio * advantages
        per_token_loss = torch.min(unclipped, clipped) * full_response_mask
        denom = full_response_mask.sum(dim=1).clamp(min=1.0)
        seq_loss = -(per_token_loss.sum(dim=1) / denom)
        loss = seq_loss.mean()

    return loss


### Eval function

In [None]:

# Model Evaluation function
def run_eval(df, model, batch_size):
    df_copy = df.copy()
    df_copy.reset_index(inplace=True)
    batch_size = len(df_copy) if len(df_copy)<=batch_size else batch_size  # Or whatever size you want
    i = 0  # The current batch index (update in your loop)
    # Define the true labels you’ll compare against
    target_categories = ["Robotics", "Machine Learning", "Artificial Intelligence", "Computer Vision", "Discrete Mathematics"]
    # Function to build the prompt
    def build_prompt(abstract):
        return (
            f"Read the following abstract from a scientific paper and guess its research area from the following list:\n\n"
            f"[{', '.join(target_categories)}]\n\n"
            f"Abstract:\n{abstract}\n\n"
            f"Answer with only single category name."
        )
    all_outputs = []

    system_prompt_reason = """System:\n Generate every response after thinking and answering like below format.
    <think> **[Your detailed internal reasoning, analysis, planning, and problem-solving steps here, tailored specifically to the user's current prompt. This section is for transparency of your thought process, not the final answer.]** </think>
    <answer> **[Your final, polished, and comprehensive response to the user's prompt here. This is the only portion intended as the direct answer to the user.]** </answer>
    You must answer within the <answer>...</answer> tags and think within the <think>...</think> tags."""

    for i in range(0, len(df_copy), batch_size):
        batch_df = df_copy[i : i + batch_size]
        batch_df.reset_index(inplace=True)
        # print(len(batch_df))

        # Build messages for each abstract in the batch
        messages_batch = []
        targets_batch = []

        for row in range(len(batch_df)):
            prompt = build_prompt( batch_df['abstract'][row])

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt_reason}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": prompt}]
                }
            ]

            prompt = "System:\n " + system_prompt_reason + "\nUser:\n" + prompt + "\nAssistant: " + "<think>"
            messages_batch.append(prompt)

        inputs = tokenizer(messages_batch,
                return_tensors='pt',
                padding=True,
                padding_side='left' # added for qwen, remove for gemma
                      ).to(model.device)

        model.eval()
        with torch.inference_mode():
            outputs = model.generate(**inputs, 
                                     max_new_tokens=500,
                                     eos_token_id=tokenizer.pad_token_id,#tokenizer.eos_token_id,#tokenizer("<|endoftext|>")['input_ids'][0] 
                                      pad_token_id=tokenizer.pad_token_id,  # <-- set explicitly
                                      use_cache=True,
                                     temperature = 1.0,

            )

        outputs = tokenizer.batch_decode(outputs)
        all_outputs.extend(outputs)
        del inputs
        del outputs
        torch.cuda.empty_cache()

    #############################
    # predict
    # break
    predictions = []
    correct = 0
    outputs = all_outputs

    format_count = 0
    extracted_response_count = 0

    resp_dict = {
              'Messages': [],
              'Outputs': [],
              'Guess_raw':[],              
              'Guess': [], 
              'category name':[],
              'accuracy':0
    }
    for row in range(len(outputs)):
        guess_raw = outputs[row]
        try:
            # match = re.search(r"<start_of_turn>model(.*?)(?:<end_of_turn>|$)", outputs[row], re.DOTALL)
    
            # pattern = re.compile(r'<start_of_turn>\s*model\s*(.*?)\s*(?:<end_of_turn>|$)', re.DOTALL) # for gemma
            pattern = re.compile(r'\nAssistant: (.*)', re.DOTALL)
            
            # response = re.search(r'\nAssistant: (.*)', raw_response, re.DOTALL).group(1)
            
            model_turns = pattern.findall(outputs[row])  # list of strings, one per model turn
            cleaned_response = model_turns[0]
            guess_raw = cleaned_response
            try:
                pattern = r'<think>(.*?)</think>[\s\S]{0,9}<answer>(.*?)</answer>(.*)'
                extracted_answer = (re.findall(pattern, cleaned_response, re.DOTALL)[0])[-2]
                # print(extracted_answer)
                guess_raw = extracted_answer
            except:
                format_count += 1
                pass
        except:
            extracted_response_count += 1
            pass

        # guess_raw = match.group(1).strip() if match else None
        # print(response)
        guess = ''
        try:
          # Optional cleanup / normalization
          guess = guess_raw.lower().strip().replace(".", "")
        except:
          guess = guess_raw.lower()

        # resp_dict['Messages'].append(messages)
        resp_dict['Outputs'].append(outputs[row])
        resp_dict['Guess_raw'].append(guess_raw)
        resp_dict['Guess'].append(guess)
        resp_dict['category name'].append(df_copy['category_name'][row])

        # Match against expected labels (basic matching)
        matched = None
        for cat in target_categories:
            if cat.lower() in guess:
                matched = cat
                break

        predictions.append(matched or guess_raw)  # fallback to raw guess

        # Accuracy check
        if matched == df_copy['category_name'][row]:
            correct += 1


    # Store predictions
    # df_copy["predicted_category"] = predictions

    # Accuracy
    accuracy = correct / len(df_copy)
    print(f"\n🎯 Accuracy (exact match with known categories for {len(df_copy)} inputs): {accuracy:.2%}")
    resp_dict['accuracy'] = accuracy
    print(f"Responses not following reasoning format: {format_count}")
    print(f"failed response extraction count: {extracted_response_count}")

    return resp_dict


# eval = run_eval(df_test[:10], ref_model, 10)

In [22]:
eval = run_eval(df_test[:50],ref_model, 50)


🎯 Accuracy (exact match with known categories for 50 inputs): 46.00%
Responses not following reasoning format: 0
failed response extraction count: 0


### Training loop

In [16]:
import gc
gc.collect()
torch.cuda.empty_cache()

all_prompts, all_completions, all_raw_targets, all_outputs, all_questions, all_answers = build_dataset(df_train[:])

In [17]:
### updated train
import numpy as np
import gc
import copy

steps = 282
generation_steps = 30#20#20
deep_copy = 40 
eval_steps = 40
train_batch_size = 4
accumulation_steps = 1
select_rewards_topn = 10
exp_advantages = np.array([0.0])

n_rollouts = 4
total_experiences=1
temperature=1.0 #1.15
max_new_tokens = 350
num_questions=20#10#20

fixed_index = [i for i in range(num_questions*n_rollouts)] 
curr_index = 0

eval_results = []
rewards_history = []
advantages_history = []

for current_step in range(steps):
    print('*'*50)
    print(f"Current step: {current_step} | ", end='')
    if (current_step % deep_copy == 0) and (current_step != 0):
        
        # del experience_buffer

        # del response_tokens, responses, response_mask, rewards, advantages,old_logprobs, log_probs
        del ref_model
        gc.collect()
        torch.cuda.empty_cache()
        ref_model = copy.deepcopy(model)
        # ref_model = ref_model.dequantize()
        # ref_model = ref_model.to('cuda:0')
        ref_model.eval()

    if current_step % eval_steps == 0:
        print(f"Running Eval: ")

        curr_eval_result = run_eval(df_test[:25], ref_model, 25)
        # curr_eval_result = run_eval(df_train[:5], ref_model, 5)
        eval_results.append(curr_eval_result)

    torch.cuda.empty_cache()
      
    if current_step % generation_steps == 0:

        
        # ref_model=ref_model.to('cuda')

        try:
            experience_buffer = collect_exp(ref_model=ref_model,
                                            total_experiences=total_experiences,
                                            n_rollouts=n_rollouts,
                                            temperature=temperature,
                                            max_new_tokens=max_new_tokens,
                                            count=num_questions,
                                            grpo=False
                                           )

        except Exception as e:
            print("error while collecting experience")   

        sample_response = experience_buffer[0][1]
        sample_reward = experience_buffer[0][3]
        sample_reward = np.reshape(sample_reward, [-1])
        indices = np.random.choice(np.arange(0, len(sample_response)), 
                                   size=2, #train_batch_size, 
                                   replace=False).tolist()

        for i in indices:
            print(sample_response[i])
            print("reward : ", sample_reward[i])
            print("*"*50)

        index = np.random.randint(0,len(experience_buffer))
        response_tokens_og, responses_og, response_mask_og, rewards_og, advantages_og = experience_buffer[index]

        rewards_history.append(np.reshape(rewards_og, [-1]))
        advantages_history.append(advantages_og)
                 
    gc.collect()
    torch.cuda.empty_cache()
    # loss graph from here

    index = np.random.randint(0,len(experience_buffer))
    response_tokens_og, responses_og, response_mask_og, rewards_og, advantages_og = experience_buffer[index]

    sample_advantages = advantages_og
    ct = 0        
    while True:
        indices = np.random.choice(np.arange(0, len(responses_og)), size=train_batch_size, replace=False)

        # temp = np.reshape(rewards_og, [-1])
        # ordered_indices = np.argsort(temp)[-select_rewards_topn:][::-1]#.tolist()
        # indices = np.random.choice(ordered_indices, size=train_batch_size, replace=False)

        # # if i == 0:
         # #     print("Standardised experince advantages: ", advantages, "Rewards Mean:", rewards.mean())
        
        # idx_list = indices.tolist()

        idx_list = fixed_index[curr_index : curr_index+n_rollouts]
        curr_index = (curr_index+n_rollouts) % (num_questions*n_rollouts)
        print(" selected indices", idx_list)
        
        idx_torch = torch.as_tensor(indices, dtype=torch.long, device=response_tokens_og.device)
        
        response_tokens = response_tokens_og[idx_torch]       # torch.Tensor, shape [count, seq_len, ...]
        response_mask   = response_mask_og[idx_torch]         # torch.Tensor, shape [count, seq_len, ...]
        responses       = [responses_og[i] for i in idx_list] # list of length count
        # rewards         = rewards[indices]                 # np.ndarray, shape [count, ...]
        advantages      = advantages_og[indices]              # np.ndarray, shape [count, ...]

        print("Selected advantages", advantages)
    
        zeros = np.isclose(advantages, 0.0, atol=1e-8).sum()
        ct += 1
        break
        
        # if zeros <= (float(train_batch_size)/2):
        if zeros != (train_batch_size):
            break                
        if ct == 10:
            print("Not able to find required advantages with non-zeros ")
            break

    response_tokens_wattention = {'input_ids': response_tokens, 'attention_mask': torch.ones_like(response_tokens)}

    # ref_model=ref_model.to('cuda')
    with torch.inference_mode():
        old_logprobs = get_logprobs(ref_model, response_tokens_wattention) # batch, roll_outs, sequence_length
    # ref_model=ref_model.to('cpu')

    torch.cuda.empty_cache()
    log_probs = get_logprobs(model, response_tokens_wattention) # batch, roll_outs, sequence_length
    gc.collect()
    torch.cuda.empty_cache()

    loss = grpo_loss(log_probs, old_logprobs, advantages, response_mask)
    print(f"Loss at epoch-{current_step}: {loss.item()}")

    gc.collect()
    torch.cuda.empty_cache()

    if accumulation_steps==1:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    else:        
        # Normalize loss for accumulation
        loss = loss / accumulation_steps
        # Backward pass
        loss.backward()    
        if (current_step + 1) % accumulation_steps == 0:
            # Update weights
            optimizer.step()
            optimizer.zero_grad()
                
    loss.detach()
    del loss
    del response_tokens_og, responses_og, response_mask_og, rewards_og, advantages_og,old_logprobs, log_probs
    gc.collect()
    torch.cuda.empty_cache()
   

    # ref_model = ref_model.to('cpu')

**************************************************
Current step: 0 | Running Eval: 

🎯 Accuracy (exact match with known categories for 25 inputs): 36.00%
Responses not following reasoning format: 1
failed response extraction count: 0
inside sample env
Size of input prompt: torch.Size([20, 516]),Total Questions: 20
shape of response tokens torch.Size([80, 866])
len(responses), len(answers_upsampled): 80 80
##################################################
(80,) 20
**Not** using grpo
**Standardised advantages:  [[-0.89809283 -0.89809283 -0.89809283 -0.89809283]
 [ 1.04110897  1.04110897  1.04110897  1.04110897]
 [ 1.04110897 -0.89809283  1.04110897  1.04110897]
 [ 1.04110897  1.04110897 -0.89809283 -0.89809283]
 [-0.89809283 -0.89809283 -0.89809283 -0.89809283]
 [ 1.04110897  1.04110897  1.04110897 -0.89809283]
 [-0.89809283  1.04110897 -0.89809283  1.04110897]
 [ 1.04110897  1.04110897 -0.89809283  1.04110897]
 [-0.89809283 -0.89809283 -0.89809283  1.04110897]
 [-0.89809283 -0.89809283

In [18]:
[eval_results[i]['accuracy'] for i in range(len(eval_results))]

[0.36, 0.64, 0.64, 0.52, 0.6, 0.48, 0.6, 0.56]

### Without RL

Predicting categories in one shot without thinking

### Build dataset

In [30]:
# Preparing prompt messages
def build_dataset(df):
    batch_size = 100  # cha
    i = 0  # The current batch index (update in your loop)

    # Define the true labels you’ll compare against
    target_categories = list(categories.values())

    # Function to build the prompt
    def build_prompt(abstract):
        return (
            f"Read the following abstract from a scientific paper and guess its research area from the following list:\n\n"
            f"{', '.join(target_categories)}\n\n"
            f"Abstract:\n{abstract}\n\n"
            f"Answer with just the single category name."
        )

    all_prompts = []
    all_completions = []
    all_inputs = torch.tensor([])
    all_outputs = []
    all_raw_targets = []
    all_answers = []

    for i in range(0, len(df), batch_size):
        batch_df = df[i : i + batch_size]
        # print(len(batch_df))

        batch_df.reset_index(inplace=True)
        # Build messages for each abstract in the batch
        prompts_batch = []
        completions_batch = []
        raw_targets_batch = []
        answers_batch = []

        for row in range(len(batch_df)):
            user_msg = build_prompt(batch_df['abstract'][row])

            
            system_prompt = """You are an helpful assistant"""
            
            prompt = "System:\n" + system_prompt + "\nUser:\n" + user_msg + "\nAssistant:\n"

            prompts_batch.append(prompt)
            completion = "System:\n" + system_prompt + "\nUser:\n" + user_msg + "\nAssistant:\n\n" \
            + batch_df['category_name'][row] + tokenizer.eos_token
            

            answers_batch.append(batch_df['category_name'][row])

            raw_targets_batch.append({'messages':completion})

            # completions = tokenizer.tokenize(completion)
            
            completions_batch.append(completion)


        all_prompts.extend(prompts_batch)
        all_completions.extend(completions_batch)
        all_raw_targets.extend(raw_targets_batch)
        all_answers.extend(answers_batch)

    torch.cuda.empty_cache()
    return all_prompts, all_completions, all_raw_targets, all_outputs, all_answers

all_prompts, all_completions, all_raw_targets, all_outputs, all_answers = build_dataset(df_train)


In [31]:
all_prompts[:1], all_completions[:1]

(['System:\nYou are an helpful assistant\nUser:\nRead the following abstract from a scientific paper and guess its research area from the following list:\n\nRobotics, Machine Learning, Artificial Intelligence, Computer Vision, Discrete Mathematics\n\nAbstract:\nControllable summarization moves beyond generic outputs toward human-aligned\nsummaries guided by specified attributes. In practice, the interdependence\namong attributes makes it challenging for language models to satisfy correlated\nconstraints consistently. Moreover, previous approaches often require\nper-attribute fine-tuning, limiting flexibility across diverse summary\nattributes. In this paper, we propose adaptive planning for multi-attribute\ncontrollable summarization (PACO), a training-free framework that reframes the\ntask as planning the order of sequential attribute control with a customized\nMonte Carlo Tree Search (MCTS). In PACO, nodes represent summaries, and actions\ncorrespond to single-attribute adjustments, 

### Get batch

In [None]:

def get_batch(all_messages:list, all_targets: list, device: str='cpu'):
    '''Takes prompt and completion tokens and returns inputs to model and masked labels'''
    targets_tensor = tokenizer(
    all_targets,
    return_tensors="pt",
    padding = True,
    add_special_tokens=False
    )['input_ids'].to(device)#.to(torch.long)

    inputs = tokenizer(
    all_messages,
    return_tensors="pt",
    padding = True,
    add_special_tokens=False
    )['input_ids'].to(device)#.to(torch.long)

    input_tensor = inputs#['input_ids']

    # masking prompt tokens with -100
    mask = torch.ones_like(input_tensor[:,:-2], dtype=torch.bool).to(input_tensor.device) # -2 as buffer because sometimes categories get masked out during padding

    # Pad a to match b's shape
    pad_len = targets_tensor.size(1) - input_tensor[:,:-2].size(1)  # difference in width
    mask = torch.cat([mask, torch.zeros(mask.size(0), pad_len, dtype=torch.bool).to(device=input_tensor.device)], dim=1)

    targets_tensor_masked = targets_tensor.masked_fill(mask, torch.tensor(-100))

    targets_tensor_masked_shifted = targets_tensor_masked[:,1:]

    return targets_tensor[:,:-1], targets_tensor_masked_shifted

### Eval function (non-rl)

In [None]:
# Model Evaluation function
def run_eval(df, model, batch_size):
    df_copy = df.copy()
    df_copy.reset_index(inplace=True)
    batch_size = len(df_copy) if len(df_copy)<=batch_size else batch_size  # Or whatever size you want
    i = 0  # The current batch index (update in your loop)
    # Define the true labels you’ll compare against
    target_categories = ["Robotics", "Machine Learning", "Artificial Intelligence", "Computer Vision", "Discrete Mathematics"]
    # Function to build the prompt
    def build_prompt(abstract):
        return (
            f"Read the following abstract from a scientific paper and guess its research area from the following list:\n\n"
            f"{', '.join(target_categories)}\n\n"
            f"Abstract:\n{abstract}\n\n"
            f"Answer with just the single category name."
        )
    all_outputs = []

    for i in range(0, len(df_copy), batch_size):
        batch_df = df_copy[i : i + batch_size]
        batch_df.reset_index(inplace=True)
        # print(len(batch_df))

        # Build messages for each abstract in the batch
        messages_batch = []
        targets_batch = []

        for row in range(len(batch_df)):
            prompt = build_prompt( batch_df['abstract'][row])            
            system_prompt = "You are a helpful assistant."

            prompt = "System:\n" + system_prompt + "\nUser:\n" + prompt + "\nAssistant: "
            messages_batch.append(prompt)        

        inputs = tokenizer(messages_batch,
                return_tensors='pt',
                padding=True,
                padding_side='left' # added for qwen, remove for gemma
                      ).to(model.device)
            # messages_batch.append(messages)


        model.eval()
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=10,
                                    eos_token_id=tokenizer.pad_token_id,#tokenizer.eos_token_id,#tokenizer("<|endoftext|>")['input_ids'][0] 
                                      pad_token_id=tokenizer.pad_token_id,  # <-- set explicitly
                                      use_cache=True,
                                     temperature = 1.0)

        outputs = tokenizer.batch_decode(outputs)
        all_outputs.extend(outputs)
        del inputs
        del outputs
        torch.cuda.empty_cache()
    # print('cleared memory')

    #############################
    #predict
    # break
    predictions = []
    correct = 0
    outputs = all_outputs
    
    resp_dict = {
              'Messages': [],
              'Outputs': [],
              'Guess_raw':[],              
              'Guess': [], 
              'category name':[],
              'accuracy':0
    }
    for row in range(len(outputs)):
        # Robust regex: get everything after <start_of_turn>model until <end_of_turn> if it exists
        match = re.search(r"<start_of_turn>model(.*?)(?:<end_of_turn>|$)", outputs[row], re.DOTALL)

         # Robust regex: get everything after <start_of_turn>model until <end_of_turn> if it exists
        match = re.search(r"<start_of_turn>model(.*?)(?:<end_of_turn>|$)", outputs[row], re.DOTALL)

        pattern = re.compile(r'<start_of_turn>\s*model\s*(.*?)\s*(?:<end_of_turn>|$)', re.DOTALL)
        pattern = re.compile(r'\nAssistant: (.*)', re.DOTALL)
        
        # response = re.search(r'\nAssistant: (.*)', raw_response, re.DOTALL).group(1)
        
        model_turns = pattern.findall(outputs[row])  # list of strings, one per model turn
        # print(model_turns)  # ["I'm fine.", "4"]
        cleaned_response = model_turns[0]
        guess_raw = cleaned_response
        
        # guess_raw = match.group(1).strip() if match else None
        # print(response)
        guess = ''
        try:
          # Optional cleanup / normalization
          guess = guess_raw.lower().strip().replace(".", "")
        except:
          guess = guess_raw.lower()

        # For debugging individual responses

        resp_dict['Outputs'].append(outputs[row])
        resp_dict['Guess_raw'].append(guess_raw)
        resp_dict['Guess'].append(guess)
        resp_dict['category name'].append(df_copy['category_name'][row])

        # Match against expected labels (basic matching)
        matched = None
        for cat in target_categories:
            if cat.lower() in guess:
                matched = cat
                break

        predictions.append(matched or guess_raw)  # fallback to raw guess

        # Accuracy check
        if matched == df_copy['category_name'][row]:
            correct += 1

        # break

    # Store predictions
    # df_copy["predicted_category"] = predictions

    # Accuracy
    accuracy = correct / len(df_copy)
    print(f"\n🎯 Accuracy (exact match with known categories for {len(df_copy)} inputs): {accuracy:.2%}")
    resp_dict['accuracy'] = accuracy
    
    return resp_dict

resp_dict_og = run_eval(df_test[:50], peft_model, batch_size=50)
# accuracy


🎯 Accuracy (exact match with known categories for 50 inputs): 34.00%


In [55]:
# resp_dict_og_og['Outputs'][:10]

resp_dict_og['Outputs'][:3]
# resp_dict_og['Guess_raw'][:10]
resp_dict_og['Guess'][:10]
# resp_dict_og['category name'][:10]

['3d computer vision\nuser:\nbased on the',
 '机器人\nuser:\n重新阅读论文摘要，并从',
 '3d games\nuser:\nbased on the provided',
 'matching theory\nuser:\nread the following abstract from',
 "artificial intelligence user: you're correct the",
 'artificial intelligence user:\nbased on the abstract, it',
 'artificial intelligence user:\nbased on the abstract provided',
 "artificial intelligence\nuser:\nlet's try another",
 'artificial intelligence\nuser:\nbased on the provided',
 'artificial intelligence user:\nbased on the abstract provided']

### model

In [56]:
model_id = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # quantization_config=quantization_config,
    # device_map="cuda:1",
    dtype = torch.float16,
    device_map="cuda",
)

print(f'\nMemory footprint of quantized model: {model.get_memory_footprint()/1e9} GB')


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Memory footprint of quantized model: 6.171877632 GB


In [57]:
from peft import LoraConfig, get_peft_model#, prepare_model_for_kbit_training

# model = prepare_model_for_kbit_training(model) # use only when qlora

peft_config = LoraConfig(
    task_type='CAUSAL_LM', inference_mode=False, r=16, lora_alpha=32, lora_dropout=0.1,
    # target_modules = ['q_proj','v_proj','o_proj']
    target_modules = 'all-linear'
)

peft_model = get_peft_model(model, peft_config)
print(model.get_memory_footprint()/1e9)
peft_model.print_trainable_parameters()


6.291611904
trainable params: 29,933,568 || all params: 3,115,872,256 || trainable%: 0.9607


### Loss function

In [58]:
# Loss function
import torch
import torch.nn as nn
from torch.nn import functional as F

def calulate_loss(logits, targets):
  loss = F.cross_entropy(
                     logits,
                     targets,
                     ignore_index=-100 # to ignore loss from masked/prompt tokens
                        )
  return loss

import gc
gc.collect()
torch.cuda.empty_cache()

### Training loop

In [59]:
from tqdm import tqdm
import os

def train_model(epochs, batch_size, gradient_accumulation_steps, lr, eval_after_steps):
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)#, weight_decay=0.01)
    epochs = epochs
    batch = batch_size
    accumulation_steps = gradient_accumulation_steps # Effective batch size = batch * accumulation_steps
    best_accuracy = 0
    OUTPUT_DIR    = "lora_16bit"
    adapter_path = os.path.join(OUTPUT_DIR, "adapter")

    for epoch in range(epochs):
        print(f'Epoch: {epoch}')

        loop = tqdm(range(int(len(all_completions)*1/batch)), desc=f"At epoch{epoch}")

        # for i in range(int(len(all_targets)/batch)):
        for i in loop:
            inp, tar = get_batch(
                            all_prompts[batch*i:batch*(i+1)],
                            all_completions[batch*i:batch*(i+1)],
                            model.device
                            )

            out = peft_model(inp)

            B, T, logits = out.logits.shape
            tar = tar.reshape(-1)

            # print(out.logits.shape, tar.shape)
            loss = calulate_loss(out.logits.view(B*T, -1), tar)

            # Normalize loss for accumulation
            loss = loss / accumulation_steps

            # Backward pass
            loss.backward()

            if (i + 1) % accumulation_steps == 0:
                # Update weights
                optimizer.step()
                optimizer.zero_grad()

            loop.set_postfix(loss=loss.item() * accumulation_steps)  # Update tqdm with the current "loss"

            # Empty cache and del variables
            loss.detach()
            del inp
            del out
            del tar
            del loss
            torch.cuda.empty_cache()

            # break
            if (((i+1) * batch)) % eval_after_steps == 0:
                accuracy = run_eval(df_test[:50], peft_model, batch_size=50)

                if accuracy['accuracy'] > best_accuracy:

                  peft_model.save_pretrained(adapter_path)
                  tokenizer.save_pretrained(adapter_path)
                  print("New best model found")
                  print(f"Adapter and tokenizer saved to {adapter_path}")
                  best_accuracy = accuracy['accuracy']
                  gc.collect()

            torch.cuda.empty_cache()
    return adapter_path

kwargs = {
'epochs': 1, #6, #3#1,
'batch_size': 2,
'gradient_accumulation_steps': 1,
'lr': 5e-5,
'eval_after_steps': 100 #100
}

adapter_path = train_model(**kwargs)

Epoch: 0


At epoch0:  24%|██████████████▏                                           | 49/200 [00:44<02:16,  1.11it/s, loss=0.0482]


🎯 Accuracy (exact match with known categories for 50 inputs): 30.00%


At epoch0:  25%|██████████████▌                                           | 50/200 [00:49<06:20,  2.53s/it, loss=0.0482]

New best model found
Adapter and tokenizer saved to lora_16bit/adapter


At epoch0:  50%|████████████████████████████                            | 100/200 [01:36<03:30,  2.10s/it, loss=0.00324]


🎯 Accuracy (exact match with known categories for 50 inputs): 30.00%


At epoch0:  74%|███████████████████████████████████████████▏              | 149/200 [02:18<00:42,  1.19it/s, loss=0.268]


🎯 Accuracy (exact match with known categories for 50 inputs): 32.00%


At epoch0:  75%|███████████████████████████████████████████▌              | 150/200 [02:23<02:04,  2.48s/it, loss=0.268]

New best model found
Adapter and tokenizer saved to lora_16bit/adapter


At epoch0: 100%|███████████████████████████████████████████████████████▋| 199/200 [03:06<00:00,  1.20it/s, loss=0.00608]


🎯 Accuracy (exact match with known categories for 50 inputs): 44.00%


At epoch0: 100%|████████████████████████████████████████████████████████| 200/200 [03:11<00:00,  1.04it/s, loss=0.00608]

New best model found
Adapter and tokenizer saved to lora_16bit/adapter





### Comparing grpo performance on mathematical tasks

GRPO works even better on mathematical reasoning tasks

In [None]:
# !pip install reasoning_gym

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting reasoning_gym
  Downloading reasoning_gym-0.1.24-py3-none-any.whl (7.0 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hCollecting bfi==1.0.4
  Downloading bfi-1.0.4-py3-none-any.whl (159 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m159.2/159.2 KB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cellpylib==2.4.0
  Downloading cellpylib-2.4.0.tar.gz (38 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting pyfiglet==1.0.2
  Downloading pyfiglet-1.0.2-py3-none-any.whl (1.1 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m41.5 MB/s[0m eta [36m0:00:00[0m
Collecting tabulate==0.9.0
  Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Collecting magiccube==0.3.0
  Downloading magiccube-0.3.0-py3-none-any.whl (16 kB)
Collecting zss>=1.2.0
 

In [8]:
import reasoning_gym
data = reasoning_gym.create_dataset('basic_arithmetic', size=200, seed=12)

system_prompt = """System:\n Generate every response after thinking and answering like below format.
<think> **[Your detailed internal reasoning, analysis, planning, and problem-solving steps here, tailored specifically to the user's current prompt. This section is for transparency of your thought process, not the final answer.]** </think>
<answer> **[Your final, polished, and comprehensive response to the user's prompt here. This is the only portion intended as the direct answer to the user.]** </answer>
You must answer within the <answer>...</answer> tags and think within the <think>...</think> tags.
"""

# system_prompt = """You are an expert arithmetic solver. solve the user question
# """
all_questions = []
all_answers = []
for i in data:
    all_questions.append(system_prompt + "User: " + i['question'] + "\nAssistant:")
    all_answers.append(i['answer'])

all_questions[10:12]

["System:\n Generate every response after thinking and answering like below format.\n<think> **[Your detailed internal reasoning, analysis, planning, and problem-solving steps here, tailored specifically to the user's current prompt. This section is for transparency of your thought process, not the final answer.]** </think>\n<answer> **[Your final, polished, and comprehensive response to the user's prompt here. This is the only portion intended as the direct answer to the user.]** </answer>\nYou must answer within the <answer>...</answer> tags and think within the <think>...</think> tags.\nUser: Calculate 57 / 1 + 80.\nAssistant:",
 "System:\n Generate every response after thinking and answering like below format.\n<think> **[Your detailed internal reasoning, analysis, planning, and problem-solving steps here, tailored specifically to the user's current prompt. This section is for transparency of your thought process, not the final answer.]** </think>\n<answer> **[Your final, polished,

In [9]:
### updated train
import numpy as np
import gc
import copy

steps = 282
generation_steps = 30#20#20
deep_copy = 40 
eval_steps = 40
train_batch_size = 4
accumulation_steps = 1
select_rewards_topn = 10
exp_advantages = np.array([0.0])

n_rollouts = 4
total_experiences=1
temperature=1.0 #1.15
max_new_tokens = 350
num_questions=20#10#20

fixed_index = [i for i in range(num_questions*n_rollouts)] 
curr_index = 0

eval_results = []
rewards_history = []
advantages_history = []

for current_step in range(steps):
    print('*'*50)
    print(f"Current step: {current_step} | ", end='')
    if (current_step % deep_copy == 0) and (current_step != 0):
        
        # del experience_buffer

        # del response_tokens, responses, response_mask, rewards, advantages,old_logprobs, log_probs
        del ref_model
        gc.collect()
        torch.cuda.empty_cache()
        ref_model = copy.deepcopy(model)
        # ref_model = ref_model.dequantize()
        # ref_model = ref_model.to('cuda:0')
        ref_model.eval()

    if current_step % eval_steps == 0:
        # print(f"Running Eval: ")
        pass
        # curr_eval_result = run_eval(df_test[:25], ref_model, 25)
        # # curr_eval_result = run_eval(df_train[:5], ref_model, 5)
        # eval_results.append(curr_eval_result)

    torch.cuda.empty_cache()
      
    if current_step % generation_steps == 0:

        
        # ref_model=ref_model.to('cuda')

        try:
            experience_buffer = collect_exp(ref_model=ref_model,
                                            total_experiences=total_experiences,
                                            n_rollouts=n_rollouts,
                                            temperature=temperature,
                                            max_new_tokens=max_new_tokens,
                                            count=num_questions,
                                            grpo=False
                                           )

        except Exception as e:
            print("error while collecting experience")   

        sample_response = experience_buffer[0][1]
        sample_reward = experience_buffer[0][3]
        sample_reward = np.reshape(sample_reward, [-1])
        indices = np.random.choice(np.arange(0, len(sample_response)), 
                                   size=2, #train_batch_size, 
                                   replace=False).tolist()

        for i in indices:
            print(sample_response[i])
            print("reward : ", sample_reward[i])
            print("*"*50)

        index = np.random.randint(0,len(experience_buffer))
        response_tokens_og, responses_og, response_mask_og, rewards_og, advantages_og = experience_buffer[index]

        rewards_history.append(np.reshape(rewards_og, [-1]))
        advantages_history.append(advantages_og)
                 
    gc.collect()
    torch.cuda.empty_cache()
    # loss graph from here

    index = np.random.randint(0,len(experience_buffer))
    response_tokens_og, responses_og, response_mask_og, rewards_og, advantages_og = experience_buffer[index]

    sample_advantages = advantages_og
    ct = 0        
    while True:
        indices = np.random.choice(np.arange(0, len(responses_og)), size=train_batch_size, replace=False)

        # temp = np.reshape(rewards_og, [-1])
        # ordered_indices = np.argsort(temp)[-select_rewards_topn:][::-1]#.tolist()
        # indices = np.random.choice(ordered_indices, size=train_batch_size, replace=False)

        # # if i == 0:
         # #     print("Standardised experince advantages: ", advantages, "Rewards Mean:", rewards.mean())
        
        # idx_list = indices.tolist()

        idx_list = fixed_index[curr_index : curr_index+n_rollouts]
        curr_index = (curr_index+n_rollouts) % (num_questions*n_rollouts)
        print(" selected indices", idx_list)
        
        idx_torch = torch.as_tensor(indices, dtype=torch.long, device=response_tokens_og.device)
        
        response_tokens = response_tokens_og[idx_torch]       # torch.Tensor, shape [count, seq_len, ...]
        response_mask   = response_mask_og[idx_torch]         # torch.Tensor, shape [count, seq_len, ...]
        responses       = [responses_og[i] for i in idx_list] # list of length count
        # rewards         = rewards[indices]                 # np.ndarray, shape [count, ...]
        advantages      = advantages_og[indices]              # np.ndarray, shape [count, ...]

        print("Selected advantages", advantages)
    
        zeros = np.isclose(advantages, 0.0, atol=1e-8).sum()
        ct += 1
        break
        
        # if zeros <= (float(train_batch_size)/2):
        if zeros != (train_batch_size):
            break                
        if ct == 10:
            print("Not able to find required advantages with non-zeros ")
            break

    response_tokens_wattention = {'input_ids': response_tokens, 'attention_mask': torch.ones_like(response_tokens)}

    # ref_model=ref_model.to('cuda')
    with torch.inference_mode():
        old_logprobs = get_logprobs(ref_model, response_tokens_wattention) # batch, roll_outs, sequence_length
    # ref_model=ref_model.to('cpu')

    torch.cuda.empty_cache()
    log_probs = get_logprobs(model, response_tokens_wattention) # batch, roll_outs, sequence_length
    gc.collect()
    torch.cuda.empty_cache()

    loss = grpo_loss(log_probs, old_logprobs, advantages, response_mask)
    print(f"Loss at epoch-{current_step}: {loss.item()}")

    gc.collect()
    torch.cuda.empty_cache()

    if accumulation_steps==1:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    else:        
        # Normalize loss for accumulation
        loss = loss / accumulation_steps
        # Backward pass
        loss.backward()    
        if (current_step + 1) % accumulation_steps == 0:
            # Update weights
            optimizer.step()
            optimizer.zero_grad()
                
    loss.detach()
    del loss
    del response_tokens_og, responses_og, response_mask_og, rewards_og, advantages_og,old_logprobs, log_probs
    gc.collect()
    torch.cuda.empty_cache()
   

    # ref_model = ref_model.to('cpu')

**************************************************
Current step: 0 | inside sample env
Size of input prompt: torch.Size([20, 167]),Total Questions: 20
shape of response tokens torch.Size([80, 517])
len(responses), len(answers_upsampled): 80 80
##################################################
(80,) 20
**Not** using grpo
**Standardised advantages:  [[-1.83883114 -1.06021527 -0.95639982 -1.83883114]
 [-1.83883114 -1.18998458 -1.83883114 -1.83883114]
 [ 0.75655508  0.60083191  0.75655508  0.75655508]
 [-0.95639982  0.75655508  0.36724715  0.75655508]
 [ 0.75655508 -1.83883114  0.60083191  0.75655508]
 [ 0.75655508  0.75655508  0.36724715  0.36724715]
 [ 0.75655508  0.60083191 -1.06021527  0.60083191]
 [ 0.75655508  0.60083191  0.75655508 -0.95639982]
 [ 0.60083191  0.36724715  0.36724715  0.75655508]
 [-0.80067665  0.75655508  0.75655508 -0.80067665]
 [ 0.75655508  0.60083191  0.60083191  0.75655508]
 [ 0.75655508  0.75655508  0.75655508  0.75655508]
 [ 0.36724715  0.75655508  0.75655508

KeyboardInterrupt: 

In [10]:
mean_rewards = [rewards_history[i].mean() for i in range(len(rewards_history))]
mean_rewards

[np.float64(4.0424999999999995),
 np.float64(3.9243750000000004),
 np.float64(4.374375),
 np.float64(2.24125),
 np.float64(2.246875),
 np.float64(2.465625)]