In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append("../../")
from src.utils.trainer_utils import test_model
from src.model.components.control_token_wrappers import PauseClassifierWrapper
from src.utils.instantiators import instantiate_generation_params
from typing import List
from src.utils.trainer_utils import inference_formatting_function, reward_conditioning_inference_formatting_function,save_json
from functools import partial
from src.utils.constants import CORRECT_ANSWER_FEEDBACK
import pandas as pd
from copy import deepcopy
import editdistance
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from torch.cuda import empty_cache
import json
from notebooks.model_playgrounds.utils import (
    load_model_and_tokenizer,
    load_generation_config,
    load_test_metrics,
    preprocess_data_fn,
    load_dataset_from_config,
    make_df,
    rollout_models
)
from functools import partial
from pytorch_lightning import seed_everything
from src.reward import GSM8KFinalAnswerLogLikelihoodReward
from tqdm import tqdm 
seed_everything(42, workers=True)

MODEL_NAMES= [
    "sft_peft",
    "baseline-(model-w-out-pause-peft)-2-epoch",
    "offline_star_exp-no_pause_peft_temp_1.0_part2",
    "offline_star_exp-pause_temp_1.0_part2"
]

MODEL_PATHS = [
    "/dlabscratch1/baldwin/pause2/PauseToken/logs/sft/runs/2024-10-18_17-38-00/final",
    "/dlabscratch1/baldwin/pause2/PauseToken/logs/train/runs/2024-10-28_11-26-28/final",
    "/dlabscratch1/baldwin/pause2/PauseToken/logs/train/runs/2024-11-04_16-16-47/last_ckpt",
    "/dlabscratch1/baldwin/pause2/PauseToken/logs/train/runs/2024-11-04_16-38-59/last_ckpt"
]

NAME_TO_PATH = {name: path for name, path in zip(MODEL_NAMES, MODEL_PATHS)}


load_model_and_tokenizer = partial(load_model_and_tokenizer, name_to_path_dict = NAME_TO_PATH)
rollout_models = partial(rollout_models, name_to_path_dict = NAME_TO_PATH)


def find_last_common_ids(ids1, ids2):
    for i in range(len(ids1)):
        if ids1[i] != ids2[i]:
            return i-1
    return len(ids1) - 1

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 42


In [2]:
# ~~~~~~~ Load Generation Config & dataset~~~~~~~
cfg = load_generation_config(
    pad_token_id = 0,
    eos_token_id=2,
    bos_token_id=1,
    max_length=600,
    overrides=[f"generation_config.temperature={1.0}", "generation_config.do_sample=False"]
)
gen_cfg = instantiate_generation_params(cfg)

dataset = load_dataset_from_config("gsm8k")
dataset["train"][0]



{'output': 'A kilogram of chicken costs $6 - $2 = $<<6-2=4>>4.\nThree kilograms of chicken cost $4 x 3 = $<<4*3=12>>12.\nSo, a 3-kilogram of chicken and a kilogram of pork cost $12 + $6 = $18.\n#### 18',
 'input': 'A kilogram of pork costs $6 while a kilogram of chicken costs $2 less. How much will a 3-kilogram of chicken and a kilogram of pork cost?'}

# Let's do a couple rollouts

In [3]:
# ~~~~~ Running inference with each model on n questions of the GSM8k's train set and for each questions sample k generations on gsm8k for each temperature~~~~~
n_samples = 20
k_generations = 1
force_overwrite = True

select_idx = []
for samp_idx in range(n_samples):
    for gen_idx in range(k_generations):
        select_idx.append(samp_idx)
        
train_samples = dataset["train"].select(select_idx)
test_samples = dataset["test"].select(select_idx)
gsm8k_metrics =load_test_metrics("gsm8k")

res_per_temp = {}

train_exp_name = f"gsm8k_train_{n_samples}_samples_{k_generations}_temperature_{gen_cfg['generation_config'].temperature}"

results_train = rollout_models(
    data_samples = train_samples,
    model_names = MODEL_NAMES,
    generation_config= gen_cfg,
    exp_name = train_exp_name,
    force_overwrite=force_overwrite,
    batch_size=25,
    test_metrics=gsm8k_metrics
)

test_exp_name = f"gsm8k_test_{n_samples}_samples_{k_generations}_temperature_{gen_cfg['generation_config'].temperature}"

results_test = rollout_models(
    data_samples = test_samples,
    model_names = MODEL_NAMES,
    generation_config= gen_cfg,
    exp_name = test_exp_name,
    force_overwrite=force_overwrite,
    batch_size=25,
    test_metrics=gsm8k_metrics
)


Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.23s/it]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.94s/it]


device cuda:0


Map: 100%|██████████| 20/20 [00:00<00:00, 1101.79 examples/s]
Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:08<00:00,  8.56s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:15<00:00,  5.17s/it]


device cuda:0


Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:33<00:00, 33.22s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  3.00s/it]


device cuda:0


Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:15<00:00, 15.35s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.48it/s]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.91it/s]


device cuda:0


Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:07<00:00,  7.02s/it]
Performing rollouts for models: 100%|██████████| 4/4 [02:10<00:00, 32.72s/it, current_model=offline_star_exp-pause_temp_1.0_part2]
Loading results from files: 100%|██████████| 4/4 [00:00<00:00, 144.28it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.57it/s]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.08it/s]


device cuda:0


Map: 100%|██████████| 20/20 [00:00<00:00, 2120.85 examples/s]
Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:16<00:00, 16.80s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.95s/it]


device cuda:0


Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:32<00:00, 32.80s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.12s/it]


device cuda:0


Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:32<00:00, 32.79s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.11s/it]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.16s/it]


device cuda:0


Rollout Step:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Rollout Step: 100%|██████████| 1/1 [00:12<00:00, 12.80s/it]
Performing rollouts for models: 100%|██████████| 4/4 [02:23<00:00, 35.78s/it, current_model=offline_star_exp-pause_temp_1.0_part2]
Loading results from files: 100%|██████████| 4/4 [00:00<00:00, 517.67it/s]


# Computing NLL of ground truth final answer evolves throughout the CoT

In [4]:

def get_full_text(examples):
    text = []
    for i in range(len(examples["input"])):
        prompt = examples["input"][i]
        output = examples["output"][i]
        text.append(f'{prompt}{output}')
    return {"text": text}

def compute_reward_after_each_token(data_samples, gts ,model_name, exp_name, force_overwrite = False):
    output_dir = output_dir = os.path.join(".", "data", exp_name, "rewards")
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    path_to_file = os.path.join(output_dir, model_name)
    
    if os.path.exists(path_to_file) and not force_overwrite:
        print(f"File {path_to_file} already exists, skipping generation. Will load from file.")
        with open(path_to_file, "r") as f:
            results = json.load(f)
            return results
        
        
    # ~~~~ Load Model ~~~~
    model, tokenizer = load_model_and_tokenizer(model_name)
    # ~~~~ Instantiate Rewards ~~~~
    reward_fn = GSM8KFinalAnswerLogLikelihoodReward(
        tokenizer=tokenizer,
        model=model,
        delimiter="####"
    )

    process_fn = preprocess_data_fn(model_name, tokenizer.eos_token)
    gts = gts.map(process_fn, batched = True)
    gts = gts.map(get_full_text, batched=True)
    
    results = []
    
    for gt,sample in tqdm(zip(gts, data_samples), total = len(gts)):
        sample_res = {}
        # ~~tokenize questions
        question = tokenizer(gt["input"].lstrip("<s> "))["input_ids"]
        # ~~tokenize prediction
        pred_ids = tokenizer(sample["generated_text"].lstrip("<s> "))["input_ids"]
        # ~~tokenize gt
        gt_ids = tokenizer(gt["text"].lstrip("<s> "))["input_ids"]
        last_common_idx = find_last_common_ids(question, pred_ids)
        rewards = []
        
        sample_res = {"question": sample["input"],"pred_output": sample["predicted_output"], "ground_truth": gt["output"], "is_correct": sample['test/accuracy'] }
        
        token_positions = []
        rewards = []
        decoded_tokens = []
        for i in range(last_common_idx+1, len(pred_ids)):
            sub_pred = pred_ids[:i]
            rewards.append(reward_fn(sub_pred,gt_ids ))
            token_positions.append(i)
            decoded_tokens.append(tokenizer.convert_ids_to_tokens(pred_ids[i-1]))

            
        sample_res = {**sample_res,**{"rewards": rewards, "token_positions": token_positions, "added_tokens": decoded_tokens}}
        results.append(sample_res)
    del model
    del tokenizer
    empty_cache()
    save_json(results, output_dir, model_name)
    return results
            

In [5]:
force_overwrite = True
debug_n = None
all_rewards_train = {}
all_rewards_test = {}

for model_name in MODEL_NAMES:

    print(f"Computing Rewards For Train samples of {model_name}")
    r_train = results_train[model_name][:debug_n] if debug_n is not None else results_train[model_name]
    gts_train = train_samples.select(range(debug_n)) if debug_n is not None else train_samples

    res_train = compute_reward_after_each_token(
        data_samples=r_train,
        gts = gts_train,
        exp_name=train_exp_name,
        model_name=model_name,
        force_overwrite=force_overwrite
    )

    all_rewards_train[model_name] = res_train
    
    print(f"Computing Rewards For test samples of {model_name}")
    r_test = results_test[model_name][:debug_n] if debug_n is not None else results_test[model_name]
    gts_test = test_samples.select(range(debug_n)) if debug_n is not None else test_samples

    res_test = compute_reward_after_each_token(
        data_samples=r_test,
        gts = gts_test,
        exp_name=test_exp_name,
        model_name=model_name,
        force_overwrite=force_overwrite
    )
    
    all_rewards_test[model_name] = res_test

Computing Rewards For Train samples of sft_peft


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.38it/s]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.02it/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 794.41 examples/s]
100%|██████████| 20/20 [01:31<00:00,  4.57s/it]


Computing Rewards For test samples of sft_peft


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.14it/s]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 1975.23 examples/s]
100%|██████████| 20/20 [01:51<00:00,  5.56s/it]


Computing Rewards For Train samples of baseline-(model-w-out-pause-peft)-2-epoch


Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.97s/it]
100%|██████████| 20/20 [08:31<00:00, 25.57s/it]


Computing Rewards For test samples of baseline-(model-w-out-pause-peft)-2-epoch


Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.19s/it]
100%|██████████| 20/20 [09:49<00:00, 29.49s/it]


Computing Rewards For Train samples of offline_star_exp-no_pause_peft_temp_1.0_part2


Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.98s/it]
100%|██████████| 20/20 [06:26<00:00, 19.34s/it]


Computing Rewards For test samples of offline_star_exp-no_pause_peft_temp_1.0_part2


Loading checkpoint shards: 100%|██████████| 3/3 [00:13<00:00,  4.45s/it]
100%|██████████| 20/20 [09:22<00:00, 28.12s/it]


Computing Rewards For Train samples of offline_star_exp-pause_temp_1.0_part2


Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.36s/it]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.53s/it]
100%|██████████| 20/20 [01:29<00:00,  4.50s/it]


Computing Rewards For test samples of offline_star_exp-pause_temp_1.0_part2


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.41it/s]


Setting control token temperature to 1.0


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.10it/s]
100%|██████████| 20/20 [02:00<00:00,  6.02s/it]


In [6]:
all_rewards_test["baseline-(model-w-out-pause-peft)-2-epoch"][1].keys()

dict_keys(['question', 'pred_output', 'ground_truth', 'is_correct', 'rewards', 'token_positions', 'added_tokens'])

In [7]:
def filter_rewards_on_token(data, token= "<0x0A>"):
    
    new_data = {}
    columns_to_filter = ["rewards","token_positions","added_tokens"]
    
    for model_name, reses in data.items():
        new_data[model_name] = []
        for res in reses:
            relevant_indx = list(
                map(lambda x: x[0], 
                    filter(lambda x: token == x[1], 
                        enumerate(res["added_tokens"])
                        )
                )
            )
            tmp_dict = {}
            for col in res.keys():                
                if col in columns_to_filter:
                    tmp_dict[col] = [res[col][i] for i in relevant_indx]
                else:
                    tmp_dict[col] = res[col]

            new_data[model_name].append(tmp_dict)
            
    return new_data

In [8]:
only_correct_rewards_train = {}
only_incorrect_rewards_train = {}
only_correct_rewards_test = {}
only_incorrect_rewards_test = {}
for name in all_rewards_train.keys():
    only_correct_rewards_train[name] = list(
        filter(lambda x: x["is_correct"], all_rewards_train[name])
    )
    only_incorrect_rewards_train[name] = list(
        filter(lambda x: not x["is_correct"], all_rewards_train[name])
    )
    
    only_correct_rewards_test[name] = list(
        filter(lambda x: x["is_correct"], all_rewards_test[name])
    )
    
    only_incorrect_rewards_test[name] = list(
        filter(lambda x: not x["is_correct"], all_rewards_test[name])
    )

In [9]:

for sample_type in ["correct", "incorrect"]:
    print(f"Number of {sample_type} Samples Per Model")
    for set in ["train", "test"]:
        print(f"    For {set} samples")
        var_name = "only_correct_rewards_" + set
        samples_dict = eval(var_name)
        for name in samples_dict.keys():
            print(f"        {name}: {len(samples_dict[name])} samples")

Number of correct Samples Per Model
    For train samples
        sft_peft: 14 samples
        baseline-(model-w-out-pause-peft)-2-epoch: 12 samples
        offline_star_exp-no_pause_peft_temp_1.0_part2: 16 samples
        offline_star_exp-pause_temp_1.0_part2: 14 samples
    For test samples
        sft_peft: 9 samples
        baseline-(model-w-out-pause-peft)-2-epoch: 11 samples
        offline_star_exp-no_pause_peft_temp_1.0_part2: 10 samples
        offline_star_exp-pause_temp_1.0_part2: 11 samples
Number of incorrect Samples Per Model
    For train samples
        sft_peft: 14 samples
        baseline-(model-w-out-pause-peft)-2-epoch: 12 samples
        offline_star_exp-no_pause_peft_temp_1.0_part2: 16 samples
        offline_star_exp-pause_temp_1.0_part2: 14 samples
    For test samples
        sft_peft: 9 samples
        baseline-(model-w-out-pause-peft)-2-epoch: 11 samples
        offline_star_exp-no_pause_peft_temp_1.0_part2: 10 samples
        offline_star_exp-pause_temp_1.0_

# Plotting Results

In [10]:
def plot_rewards(res, title, save_name ,n_samples = None, width = 800, height = 600):
    # Initialize the figure
    fig = go.Figure()

    # Function to insert line breaks into long text strings
    def add_line_breaks(text, max_line_length=50):
        words = text.split()
        lines = []
        current_line = ""
        
        for word in words:
            # Add word to the line if it doesn't exceed the max line length
            if len(current_line) + len(word) + 1 <= max_line_length:
                current_line += " " + word if current_line else word
            else:
                lines.append(current_line)
                current_line = word
        if current_line:
            lines.append(current_line)
        
        return "<br>".join(lines)

    # Loop over each sample in the dataset
    for i, sample in enumerate(res):
        # Prepare hover information text for each point with line breaks
        hover_text = [
            f"<b>Added Token:</b> {added_token}<br>"
            f"<b>Question:</b> {add_line_breaks(sample['question'])}<br>"
            f"<b>Pred Output:</b> {add_line_breaks(sample['pred_output'])}<br>"
            f"<b>Ground Truth:</b> {add_line_breaks(sample['ground_truth'])}<br>"
            f"<b>Reward:</b> {reward}"
            for reward, added_token in zip(sample['rewards'], sample['added_tokens'])
        ]

        # Add a line for each sample
        fig.add_trace(go.Scatter(
            x=sample['token_positions'],
            y=sample['rewards'],
            mode='lines+markers',
            name=f"sample_{i}",
            hovertext=hover_text,
            hoverinfo="text"
        ))

        if n_samples is not None and i >= n_samples:
            break

    # Update layout with figure size and hover label styling
    fig.update_layout(
        title=title,
        xaxis_title="Token Position",
        yaxis_title="Reward",
        hovermode="x unified",
        font=dict(size=10),
        margin=dict(l=0, r=0, t=50, b=0),
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=9,
            font_family="Arial",
            align="left"
        )
    )

    output_dir = os.path.join(".", "plots")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        
    path = os.path.join(output_dir, save_name)
    fig.write_html(path, full_html=True)
    

# Making the Plots

In [11]:
name_to_title_name = {
    "sft_peft": "Warmed Up Pause Model",
    "baseline-(model-w-out-pause-peft)-2-epoch": "Warmed Up Baseline Model (No Pause)",
    "offline_star_exp-no_pause_peft_temp_1.0_part2": "Offline STaR on Pause Model",
    "offline_star_exp-pause_temp_1.0_part2": "Offline STaR on Baseline Model (No Pause)"
}

In [12]:
n_samples = 10
for name in only_correct_rewards_train.keys():
    title_name = name_to_title_name[name]
    plot_rewards(only_correct_rewards_train[name], title = f"Token Position vs Reward On Train samples of {title_name} for Correct Samples", save_name=f"train_correct_samples_{title_name}_{n_samples}.html" ,n_samples=n_samples)

In [13]:
n_samples = 10
for name in only_correct_rewards_test.keys():
    title_name = name_to_title_name[name]
    plot_rewards(only_correct_rewards_test[name], title = f"Token Position vs Reward On Test samples of {title_name} for Correct Samples", save_name=f"test_correct_samples_{title_name}_{n_samples}.html" ,n_samples=n_samples)

In [14]:
n_samples = 10
for name in only_incorrect_rewards_train.keys():
    plot_rewards(only_incorrect_rewards_train[name], title = f"Token Position vs Reward On Train samples of {name} for Incorrect Samples", save_name=f"train_inccorrect_samples_{title_name}_{n_samples}.html" ,n_samples=n_samples)

In [15]:
n_samples = 10
for name in only_incorrect_rewards_test.keys():
    title_name = name_to_title_name[name]
    plot_rewards(only_incorrect_rewards_test[name], title = f"Token Position vs Reward On Test samples of {title_name} for Incorrect Samples", save_name=f"test_inccorrect_samples_{title_name}_{n_samples}.html" ,n_samples=n_samples)

In [16]:
only_incorrect_rewards_test["sft_peft"][0].keys()

dict_keys(['question', 'pred_output', 'ground_truth', 'is_correct', 'rewards', 'token_positions', 'added_tokens'])

In [17]:
index = 4
print("Question")
print(only_correct_rewards_test["baseline-(model-w-out-pause-peft)-2-epoch"][index]["question"])
print("pred_output")
print(only_correct_rewards_test["baseline-(model-w-out-pause-peft)-2-epoch"][index]["pred_output"])
print("ground_truth")
print(only_correct_rewards_test["baseline-(model-w-out-pause-peft)-2-epoch"][index]["ground_truth"])


Question
Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?
pred_output
Eliza's regular hourly rate is $10/hour.
Her overtime hourly rate is $10/hour x 1.2 = $<<10*1.2=12>>12/hour.
Eliza worked for 45 hours this week.
Her regular earnings for this week is $10/hour x 40 hours = $<<10*40=400>>400.
Her overtime earnings for this week is $12/hour x 5 hours = $<<12*5=60>>60.
Her total earnings for this week is $400 + $60 = $<<400+60=460>>460.
#### 460</s>
ground_truth
Eliza is entitled to 45 -40 = <<45-40=5>>5 hours overtime pay.
Her hourly rate for the overtime pay is $10 x 1.2 = $<<10*1.2=12>>12.
So, Eliza will receive $12 x 5 =$<<12*5=60>>60 for overtime pay.
Her regular weekly earning is $10 x 40 = $<<10*40=400>>400.
Thus, Eliza will receive a total of $400 + $60 = $<<400+60=460>>460 for this week's work.
#### 460<

# Plots With only rewards after "\n" in the CoT


In [18]:
only_correct_rewards_train = filter_rewards_on_token(only_correct_rewards_train)
only_correct_rewards_test = filter_rewards_on_token(only_correct_rewards_test)
only_incorrect_rewards_train = filter_rewards_on_token(only_incorrect_rewards_train)
only_incorrect_rewards_test = filter_rewards_on_token(only_incorrect_rewards_test)

n_samples = 10
for name in only_correct_rewards_train.keys():
    title_name = name_to_title_name[name]
    plot_rewards(
        only_correct_rewards_train[name],
        title = f"Token Position vs Reward On Train samples of {title_name} for Correct Samples",
        save_name=f"only_line_return_train_correct_samples_{title_name}_{n_samples}.html",
        n_samples=n_samples
    )
    
    n_samples = 10
for name in only_correct_rewards_test.keys():
    title_name = name_to_title_name[name]
    plot_rewards(
        only_correct_rewards_test[name],
        title = f"Token Position vs Reward On Test samples of {title_name} for Correct Samples",
        save_name=f"only_line_return_test_correct_samples_{title_name}_{n_samples}.html",
        n_samples=n_samples
    )
    
n_samples = 10
for name in only_incorrect_rewards_train.keys():
    plot_rewards(
        only_incorrect_rewards_train[name],
        title = f"Token Position vs Reward On Train samples of {name} for Incorrect Samples",
        save_name=f"only_line_return_train_inccorrect_samples_{title_name}_{n_samples}.html",
        n_samples=n_samples
    )

n_samples = 10
for name in only_incorrect_rewards_test.keys():
    title_name = name_to_title_name[name]
    plot_rewards(
        only_incorrect_rewards_test[name],
        title = f"Token Position vs Reward On Test samples of {title_name} for Incorrect Samples",
        save_name=f"only_line_return_test_inccorrect_samples_{title_name}_{n_samples}.html",
        n_samples=n_samples
    )